2个文件已删除
3个文件已添加
1 文件已重命名
7个文件已修改
| | |
| | | renameAction: "Rename session", |
| | | deleteAction: "Delete session", |
| | | toolTrace: "Tool Trace", |
| | | activityTrace: "Thinking & Tool Trace", |
| | | thinkingProcess: "Thinking Process", |
| | | thinkingExpand: "Show Thinking Process", |
| | | thinkingCollapse: "Hide Thinking Process", |
| | |
| | | thinkingStatusFailed: "Failed", |
| | | thinkingStatusAborted: "Aborted", |
| | | noToolTrace: "No tool call was triggered in this round", |
| | | noActivityTrace: "No thinking or tool trace in this round", |
| | | unknownTool: "Unknown tool", |
| | | traceTypeThinking: "Thinking", |
| | | traceTypeTool: "Tool", |
| | | toolStatusFailed: "Failed", |
| | | toolStatusCompleted: "Completed", |
| | | toolStatusRunning: "Running", |
| | |
| | | renameAction: "重命名会话", |
| | | deleteAction: "删除会话", |
| | | toolTrace: "工具调用轨迹", |
| | | activityTrace: "思维链与工具轨迹", |
| | | thinkingProcess: "思考过程", |
| | | thinkingExpand: "展开思考过程", |
| | | thinkingCollapse: "收起思考过程", |
| | |
| | | thinkingStatusFailed: "失败", |
| | | thinkingStatusAborted: "已中止", |
| | | noToolTrace: "当前轮未触发工具调用", |
| | | noActivityTrace: "当前轮尚无思考或工具轨迹", |
| | | unknownTool: "未知工具", |
| | | traceTypeThinking: "思维链", |
| | | traceTypeTool: "工具", |
| | | toolStatusFailed: "失败", |
| | | toolStatusCompleted: "完成", |
| | | toolStatusRunning: "执行中", |
| | |
| | | const DEFAULT_PROMPT_CODE = "home.default"; |
| | | const AI_CHAT_DRAWER_Z_INDEX = 1400; |
| | | const AI_CHAT_DIALOG_Z_INDEX = AI_CHAT_DRAWER_Z_INDEX + 20; |
| | | const THINKING_PHASE_ORDER = { |
| | | ANALYZE: 0, |
| | | TOOL_CALL: 1, |
| | | ANSWER: 2, |
| | | }; |
| | | |
| | | const normalizeMarkdownContent = (content) => { |
| | | if (!content) { |
| | |
| | | const [sessions, setSessions] = useState([]); |
| | | const [persistedMessages, setPersistedMessages] = useState([]); |
| | | const [messages, setMessages] = useState([]); |
| | | const [toolEvents, setToolEvents] = useState([]); |
| | | const [expandedToolIds, setExpandedToolIds] = useState([]); |
| | | const [thinkingEvents, setThinkingEvents] = useState([]); |
| | | const [thinkingExpanded, setThinkingExpanded] = useState(true); |
| | | const [traceEvents, setTraceEvents] = useState([]); |
| | | const [expandedTraceIds, setExpandedTraceIds] = useState([]); |
| | | const [input, setInput] = useState(""); |
| | | const [loadingRuntime, setLoadingRuntime] = useState(false); |
| | | const [streaming, setStreaming] = useState(false); |
| | |
| | | }; |
| | | }, [runtime]); |
| | | |
| | | const currentThinkingMessageIndex = useMemo(() => { |
| | | if (!thinkingEvents.length || !messages.length) { |
| | | return -1; |
| | | } |
| | | for (let i = messages.length - 1; i >= 0; i -= 1) { |
| | | if (messages[i]?.role === "assistant") { |
| | | return i; |
| | | } |
| | | } |
| | | return -1; |
| | | }, [messages, thinkingEvents]); |
| | | |
| | | useEffect(() => { |
| | | if (open) { |
| | | setRuntimePanelExpanded(false); |
| | |
| | | }, [open, messages, streaming]); |
| | | |
| | | const initializeDrawer = async (targetSessionId = null) => { |
| | | setToolEvents([]); |
| | | setExpandedToolIds([]); |
| | | setThinkingEvents([]); |
| | | setThinkingExpanded(true); |
| | | setTraceEvents([]); |
| | | setExpandedTraceIds([]); |
| | | await Promise.all([ |
| | | loadRuntime(targetSessionId), |
| | | loadSessions(sessionKeyword), |
| | |
| | | setSessionId(null); |
| | | setPersistedMessages([]); |
| | | setMessages([]); |
| | | setToolEvents([]); |
| | | setExpandedToolIds([]); |
| | | setThinkingEvents([]); |
| | | setThinkingExpanded(true); |
| | | setTraceEvents([]); |
| | | setExpandedTraceIds([]); |
| | | setUsage(null); |
| | | setDrawerError(""); |
| | | }; |
| | |
| | | return; |
| | | } |
| | | setUsage(null); |
| | | setToolEvents([]); |
| | | setExpandedToolIds([]); |
| | | setThinkingEvents([]); |
| | | setThinkingExpanded(true); |
| | | setTraceEvents([]); |
| | | setExpandedTraceIds([]); |
| | | await loadRuntime(targetSessionId); |
| | | }; |
| | | |
| | |
| | | return next; |
| | | }; |
| | | |
| | | const upsertToolEvent = (payload) => { |
| | | if (!payload?.toolCallId) { |
| | | const appendTraceEvent = (payload) => { |
| | | if (!payload?.traceId) { |
| | | return; |
| | | } |
| | | setToolEvents((prev) => { |
| | | const index = prev.findIndex((item) => item.toolCallId === payload.toolCallId); |
| | | if (index < 0) { |
| | | return [...prev, payload]; |
| | | } |
| | | const next = [...prev]; |
| | | next[index] = { ...next[index], ...payload }; |
| | | return next; |
| | | }); |
| | | }; |
| | | |
| | | const toggleToolEventExpanded = (toolCallId) => { |
| | | if (!toolCallId) { |
| | | return; |
| | | } |
| | | setExpandedToolIds((prev) => ( |
| | | prev.includes(toolCallId) |
| | | ? prev.filter((item) => item !== toolCallId) |
| | | : [...prev, toolCallId] |
| | | )); |
| | | }; |
| | | |
| | | const upsertThinkingEvent = (payload) => { |
| | | if (!payload?.phase) { |
| | | return; |
| | | } |
| | | setThinkingEvents((prev) => { |
| | | const index = prev.findIndex((item) => item.phase === payload.phase); |
| | | setTraceEvents((prev) => { |
| | | const index = prev.findIndex((item) => item.traceId === payload.traceId); |
| | | if (index < 0) { |
| | | return [...prev, payload].sort((left, right) => ( |
| | | (THINKING_PHASE_ORDER[left.phase] ?? Number.MAX_SAFE_INTEGER) |
| | | - (THINKING_PHASE_ORDER[right.phase] ?? Number.MAX_SAFE_INTEGER) |
| | | (left?.sequence ?? 0) - (right?.sequence ?? 0) |
| | | )); |
| | | } |
| | | const next = [...prev]; |
| | |
| | | }); |
| | | }; |
| | | |
| | | const toggleThinkingExpanded = () => { |
| | | setThinkingExpanded((prev) => !prev); |
| | | const toggleTraceEventExpanded = (traceId) => { |
| | | if (!traceId) { |
| | | return; |
| | | } |
| | | setExpandedTraceIds((prev) => ( |
| | | prev.includes(traceId) |
| | | ? prev.filter((item) => item !== traceId) |
| | | : [...prev, traceId] |
| | | )); |
| | | }; |
| | | |
| | | const getThinkingStatusLabel = (status) => { |
| | |
| | | return translate("ai.drawer.thinkingStatusStarted"); |
| | | }; |
| | | |
| | | const getToolStatusLabel = (status) => { |
| | | if (status === "FAILED") { |
| | | return translate("ai.drawer.toolStatusFailed"); |
| | | } |
| | | if (status === "COMPLETED") { |
| | | return translate("ai.drawer.toolStatusCompleted"); |
| | | } |
| | | return translate("ai.drawer.toolStatusRunning"); |
| | | }; |
| | | |
| | | const handleSend = async () => { |
| | | const content = input.trim(); |
| | | if (!content || streaming) { |
| | |
| | | setInput(""); |
| | | setUsage(null); |
| | | setDrawerError(""); |
| | | setToolEvents([]); |
| | | setExpandedToolIds([]); |
| | | setThinkingEvents([]); |
| | | setThinkingExpanded(true); |
| | | setTraceEvents([]); |
| | | setExpandedTraceIds([]); |
| | | setMessages(ensureAssistantPlaceholder(nextMessages)); |
| | | setStreaming(true); |
| | | |
| | |
| | | if (eventName === "delta") { |
| | | appendAssistantDelta(payload?.content || ""); |
| | | } |
| | | if (eventName === "tool_start" || eventName === "tool_result" || eventName === "tool_error") { |
| | | upsertToolEvent(payload); |
| | | } |
| | | if (eventName === "thinking") { |
| | | upsertThinkingEvent(payload); |
| | | if (eventName === "trace") { |
| | | appendTraceEvent(payload); |
| | | } |
| | | if (eventName === "done") { |
| | | setUsage(payload); |
| | |
| | | > |
| | | <Box px={2} py={1.5} display="flex" flexDirection="column" minHeight={0}> |
| | | <Typography variant="subtitle2" mb={1}> |
| | | {translate("ai.drawer.toolTrace")} |
| | | {translate("ai.drawer.activityTrace")} |
| | | </Typography> |
| | | <Paper variant="outlined" sx={{ flex: 1, minHeight: { xs: 140, md: 0 }, overflow: "hidden", bgcolor: "grey.50" }}> |
| | | {!toolEvents.length ? ( |
| | | {!traceEvents.length ? ( |
| | | <Box px={1.5} py={1.25}> |
| | | <Typography variant="body2" color="text.secondary"> |
| | | {translate("ai.drawer.noToolTrace")} |
| | | {translate("ai.drawer.noActivityTrace")} |
| | | </Typography> |
| | | </Box> |
| | | ) : ( |
| | | <Stack spacing={1} sx={{ p: 1.25, maxHeight: { xs: 220, md: "calc(100vh - 180px)" }, overflow: "auto" }}> |
| | | {toolEvents.map((item) => ( |
| | | {traceEvents.map((item) => ( |
| | | <Paper |
| | | key={item.toolCallId} |
| | | key={item.traceId} |
| | | variant="outlined" |
| | | sx={{ |
| | | p: 1.25, |
| | | bgcolor: item.status === "FAILED" ? "error.lighter" : "common.white", |
| | | borderColor: item.status === "FAILED" ? "error.light" : "divider", |
| | | bgcolor: item.status === "FAILED" |
| | | ? "error.lighter" |
| | | : item.traceType === "thinking" |
| | | ? "info.lighter" |
| | | : "common.white", |
| | | borderColor: item.status === "FAILED" |
| | | ? "error.light" |
| | | : item.traceType === "thinking" |
| | | ? "info.light" |
| | | : "divider", |
| | | }} |
| | | > |
| | | <Stack direction="row" spacing={1} alignItems="center" flexWrap="wrap" useFlexGap> |
| | | <Chip |
| | | size="small" |
| | | variant="outlined" |
| | | color={item.traceType === "thinking" ? "info" : "primary"} |
| | | label={translate(item.traceType === "thinking" ? "ai.drawer.traceTypeThinking" : "ai.drawer.traceTypeTool")} |
| | | /> |
| | | <Typography variant="body2" fontWeight={700}> |
| | | {item.toolName || translate("ai.drawer.unknownTool")} |
| | | {item.traceType === "thinking" |
| | | ? (item.title || translate("ai.drawer.thinkingProcess")) |
| | | : (item.toolName || item.title || translate("ai.drawer.unknownTool"))} |
| | | </Typography> |
| | | <Chip |
| | | size="small" |
| | | color={item.status === "FAILED" ? "error" : item.status === "COMPLETED" ? "success" : "info"} |
| | | label={translate(item.status === "FAILED" ? "ai.drawer.toolStatusFailed" : item.status === "COMPLETED" ? "ai.drawer.toolStatusCompleted" : "ai.drawer.toolStatusRunning")} |
| | | color={item.status === "FAILED" |
| | | ? "error" |
| | | : item.status === "COMPLETED" |
| | | ? "success" |
| | | : item.status === "ABORTED" |
| | | ? "warning" |
| | | : "info"} |
| | | label={item.traceType === "thinking" |
| | | ? getThinkingStatusLabel(item.status) |
| | | : getToolStatusLabel(item.status)} |
| | | /> |
| | | {item.durationMs != null && ( |
| | | <Typography variant="caption" color="text.secondary"> |
| | | {item.durationMs} ms |
| | | </Typography> |
| | | )} |
| | | {(item.inputSummary || item.outputSummary || item.errorMessage) && ( |
| | | {item.traceType === "tool" && (item.inputSummary || item.outputSummary || item.errorMessage) && ( |
| | | <Button |
| | | size="small" |
| | | onClick={() => toggleToolEventExpanded(item.toolCallId)} |
| | | endIcon={expandedToolIds.includes(item.toolCallId) |
| | | onClick={() => toggleTraceEventExpanded(item.traceId)} |
| | | endIcon={expandedTraceIds.includes(item.traceId) |
| | | ? <ExpandLessOutlinedIcon fontSize="small" /> |
| | | : <ExpandMoreOutlinedIcon fontSize="small" />} |
| | | sx={{ ml: "auto", minWidth: "auto", px: 0.5 }} |
| | | > |
| | | {expandedToolIds.includes(item.toolCallId) ? translate("ai.drawer.collapseDetail") : translate("ai.drawer.viewDetail")} |
| | | {expandedTraceIds.includes(item.traceId) ? translate("ai.drawer.collapseDetail") : translate("ai.drawer.viewDetail")} |
| | | </Button> |
| | | )} |
| | | </Stack> |
| | | <Collapse in={expandedToolIds.includes(item.toolCallId)} timeout="auto" unmountOnExit> |
| | | {!!item.inputSummary && ( |
| | | <Typography variant="caption" display="block" sx={{ mt: 0.75, whiteSpace: "pre-wrap" }}> |
| | | {translate("ai.drawer.toolInput", { value: item.inputSummary })} |
| | | </Typography> |
| | | )} |
| | | {!!item.outputSummary && ( |
| | | <Typography variant="caption" display="block" sx={{ mt: 0.75, whiteSpace: "pre-wrap" }}> |
| | | {translate("ai.drawer.toolOutput", { value: item.outputSummary })} |
| | | </Typography> |
| | | )} |
| | | {!!item.errorMessage && ( |
| | | <Typography variant="caption" color="error.main" display="block" sx={{ mt: 0.75, whiteSpace: "pre-wrap" }}> |
| | | {translate("ai.drawer.toolError", { value: item.errorMessage })} |
| | | </Typography> |
| | | )} |
| | | </Collapse> |
| | | {item.traceType === "thinking" ? ( |
| | | <Typography variant="caption" display="block" color="text.secondary" sx={{ mt: 0.75, whiteSpace: "pre-wrap" }}> |
| | | {item.content || translate("ai.drawer.thinkingEmpty")} |
| | | </Typography> |
| | | ) : ( |
| | | <Collapse in={expandedTraceIds.includes(item.traceId)} timeout="auto" unmountOnExit> |
| | | {!!item.title && ( |
| | | <Typography variant="caption" display="block" color="text.secondary" sx={{ mt: 0.75, whiteSpace: "pre-wrap" }}> |
| | | {item.title} |
| | | </Typography> |
| | | )} |
| | | {!!item.inputSummary && ( |
| | | <Typography variant="caption" display="block" sx={{ mt: 0.75, whiteSpace: "pre-wrap" }}> |
| | | {translate("ai.drawer.toolInput", { value: item.inputSummary })} |
| | | </Typography> |
| | | )} |
| | | {!!item.outputSummary && ( |
| | | <Typography variant="caption" display="block" sx={{ mt: 0.75, whiteSpace: "pre-wrap" }}> |
| | | {translate("ai.drawer.toolOutput", { value: item.outputSummary })} |
| | | </Typography> |
| | | )} |
| | | {!!item.errorMessage && ( |
| | | <Typography variant="caption" color="error.main" display="block" sx={{ mt: 0.75, whiteSpace: "pre-wrap" }}> |
| | | {translate("ai.drawer.toolError", { value: item.errorMessage })} |
| | | </Typography> |
| | | )} |
| | | </Collapse> |
| | | )} |
| | | </Paper> |
| | | ))} |
| | | </Stack> |
| | |
| | | justifyContent={message.role === "user" ? "flex-end" : "flex-start"} |
| | | > |
| | | <Stack spacing={1} sx={{ maxWidth: "85%", width: "100%" }} alignItems={message.role === "user" ? "flex-end" : "flex-start"}> |
| | | {message.role === "assistant" && index === currentThinkingMessageIndex && !!thinkingEvents.length && ( |
| | | <Paper |
| | | variant="outlined" |
| | | sx={{ |
| | | width: "100%", |
| | | borderRadius: 2, |
| | | overflow: "hidden", |
| | | bgcolor: "grey.50", |
| | | }} |
| | | > |
| | | <Button |
| | | fullWidth |
| | | size="small" |
| | | onClick={toggleThinkingExpanded} |
| | | endIcon={thinkingExpanded |
| | | ? <ExpandLessOutlinedIcon fontSize="small" /> |
| | | : <ExpandMoreOutlinedIcon fontSize="small" />} |
| | | sx={{ |
| | | justifyContent: "space-between", |
| | | px: 1.25, |
| | | py: 0.75, |
| | | color: "text.primary", |
| | | }} |
| | | > |
| | | {thinkingExpanded ? translate("ai.drawer.thinkingCollapse") : translate("ai.drawer.thinkingExpand")} |
| | | </Button> |
| | | <Collapse in={thinkingExpanded} timeout="auto" unmountOnExit> |
| | | <Stack spacing={1} sx={{ px: 1.25, pb: 1.25 }}> |
| | | {thinkingEvents.map((item) => ( |
| | | <Paper key={item.phase} variant="outlined" sx={{ px: 1, py: 0.9, bgcolor: "common.white" }}> |
| | | <Stack direction="row" spacing={1} alignItems="center" flexWrap="wrap" useFlexGap> |
| | | <Typography variant="body2" fontWeight={700}> |
| | | {item.title || translate("ai.drawer.thinkingProcess")} |
| | | </Typography> |
| | | <Chip |
| | | size="small" |
| | | color={item.status === "FAILED" |
| | | ? "error" |
| | | : item.status === "COMPLETED" |
| | | ? "success" |
| | | : item.status === "ABORTED" |
| | | ? "warning" |
| | | : "info"} |
| | | label={getThinkingStatusLabel(item.status)} |
| | | /> |
| | | </Stack> |
| | | <Typography variant="caption" display="block" color="text.secondary" sx={{ mt: 0.75, whiteSpace: "pre-wrap" }}> |
| | | {item.content || translate("ai.drawer.thinkingEmpty")} |
| | | </Typography> |
| | | </Paper> |
| | | ))} |
| | | </Stack> |
| | | </Collapse> |
| | | </Paper> |
| | | )} |
| | | <Paper |
| | | elevation={0} |
| | | sx={{ |
| File was renamed from rsf-server/src/main/java/com/vincent/rsf/server/ai/dto/AiChatToolEventDto.java |
| | |
| | | |
| | | @Data |
| | | @Builder |
| | | public class AiChatToolEventDto { |
| | | public class AiChatTraceEventDto { |
| | | |
| | | private String requestId; |
| | | |
| | | private Long sessionId; |
| | | |
| | | private String traceId; |
| | | |
| | | private Long sequence; |
| | | |
| | | private String traceType; |
| | | |
| | | private String phase; |
| | | |
| | | private String status; |
| | | |
| | | private String title; |
| | | |
| | | private String content; |
| | | |
| | | private String toolCallId; |
| | | |
| | | private String toolName; |
| | | |
| | | private String mountName; |
| | | |
| | | private String status; |
| | | |
| | | private String inputSummary; |
| | | |
| | |
| | | public void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt, |
| | | Long firstTokenAt, AiChatException exception, Long callLogId, |
| | | long toolSuccessCount, long toolFailureCount, |
| | | AiThinkingTraceEmitter thinkingTraceEmitter, |
| | | AiChatTraceEmitter traceEmitter, |
| | | Long tenantId, Long userId, String promptCode) { |
| | | if (isClientAbortException(exception)) { |
| | | log.warn("AI chat aborted by client, requestId={}, sessionId={}, stage={}, message={}", |
| | | requestId, sessionId, exception.getStage(), exception.getMessage()); |
| | | if (thinkingTraceEmitter != null) { |
| | | thinkingTraceEmitter.markTerminated("ABORTED"); |
| | | if (traceEmitter != null) { |
| | | traceEmitter.markTerminated("ABORTED"); |
| | | } |
| | | aiSseEventPublisher.emitSafely(emitter, "status", |
| | | aiSseEventPublisher.buildTerminalStatus(requestId, sessionId, "ABORTED", model, startedAt, firstTokenAt)); |
| | |
| | | } |
| | | log.error("AI chat failed, requestId={}, sessionId={}, category={}, stage={}, message={}", |
| | | requestId, sessionId, exception.getCategory(), exception.getStage(), exception.getMessage(), exception); |
| | | if (thinkingTraceEmitter != null) { |
| | | thinkingTraceEmitter.markTerminated("FAILED"); |
| | | if (traceEmitter != null) { |
| | | traceEmitter.markTerminated("FAILED"); |
| | | } |
| | | aiSseEventPublisher.emitSafely(emitter, "status", |
| | | aiSseEventPublisher.buildTerminalStatus(requestId, sessionId, "FAILED", model, startedAt, firstTokenAt)); |
| | |
| | | String requestId = request.getRequestId(); |
| | | long startedAt = System.currentTimeMillis(); |
| | | AtomicReference<Long> firstTokenAtRef = new AtomicReference<>(); |
| | | AtomicLong traceSequence = new AtomicLong(0); |
| | | AtomicLong toolCallSequence = new AtomicLong(0); |
| | | AtomicLong toolSuccessCount = new AtomicLong(0); |
| | | AtomicLong toolFailureCount = new AtomicLong(0); |
| | |
| | | Long callLogId = null; |
| | | String model = null; |
| | | String resolvedPromptCode = request.getPromptCode(); |
| | | AiThinkingTraceEmitter thinkingTraceEmitter = null; |
| | | AiChatTraceEmitter traceEmitter = null; |
| | | try { |
| | | ensureIdentity(userId, tenantId); |
| | | AiResolvedConfig config = resolveConfig(request, tenantId); |
| | |
| | | .build()); |
| | | log.info("AI chat started, requestId={}, userId={}, tenantId={}, sessionId={}, model={}", |
| | | requestId, userId, tenantId, session.getId(), resolvedModel); |
| | | thinkingTraceEmitter = new AiThinkingTraceEmitter(aiSseEventPublisher, emitter, requestId, session.getId()); |
| | | thinkingTraceEmitter.startAnalyze(); |
| | | AiThinkingTraceEmitter activeThinkingTraceEmitter = thinkingTraceEmitter; |
| | | traceEmitter = new AiChatTraceEmitter(aiSseEventPublisher, emitter, requestId, session.getId(), traceSequence); |
| | | traceEmitter.startAnalyze(); |
| | | AiChatTraceEmitter activeTraceEmitter = traceEmitter; |
| | | |
| | | ToolCallback[] observableToolCallbacks = aiToolObservationService.wrapToolCallbacks( |
| | | runtime.getToolCallbacks(), emitter, requestId, session.getId(), toolCallSequence, |
| | | toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, activeThinkingTraceEmitter |
| | | runtime.getToolCallbacks(), requestId, session.getId(), toolCallSequence, |
| | | toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, activeTraceEmitter |
| | | ); |
| | | Prompt prompt = new Prompt( |
| | | aiPromptMessageBuilder.buildPromptMessages(memory, mergedMessages, config.getPrompt(), request.getMetadata()), |
| | |
| | | String content = extractContent(response); |
| | | aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), content); |
| | | if (StringUtils.hasText(content)) { |
| | | aiSseEventPublisher.markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt, activeThinkingTraceEmitter); |
| | | aiSseEventPublisher.markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt, activeTraceEmitter); |
| | | aiSseEventPublisher.emitStrict(emitter, "delta", aiSseEventPublisher.buildMessagePayload("requestId", requestId, "content", content)); |
| | | } |
| | | activeThinkingTraceEmitter.completeCurrentPhase(); |
| | | activeTraceEmitter.completeCurrentPhase(); |
| | | aiSseEventPublisher.emitDone(emitter, requestId, response.getMetadata(), config.getAiParam().getModel(), |
| | | session.getId(), startedAt, firstTokenAtRef.get()); |
| | | aiSseEventPublisher.emitSafely(emitter, "status", |
| | |
| | | lastMetadata.set(response.getMetadata()); |
| | | String content = extractContent(response); |
| | | if (StringUtils.hasText(content)) { |
| | | aiSseEventPublisher.markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt, activeThinkingTraceEmitter); |
| | | aiSseEventPublisher.markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt, activeTraceEmitter); |
| | | assistantContent.append(content); |
| | | aiSseEventPublisher.emitStrict(emitter, "delta", |
| | | aiSseEventPublisher.buildMessagePayload("requestId", requestId, "content", content)); |
| | |
| | | e == null ? "AI 模型流式调用失败" : e.getMessage(), e); |
| | | } |
| | | aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), assistantContent.toString()); |
| | | activeThinkingTraceEmitter.completeCurrentPhase(); |
| | | activeTraceEmitter.completeCurrentPhase(); |
| | | aiSseEventPublisher.emitDone(emitter, requestId, lastMetadata.get(), config.getAiParam().getModel(), |
| | | session.getId(), startedAt, firstTokenAtRef.get()); |
| | | aiSseEventPublisher.emitSafely(emitter, "status", |
| | |
| | | } |
| | | } catch (AiChatException e) { |
| | | aiChatFailureHandler.handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e, |
| | | callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter, |
| | | callLogId, toolSuccessCount.get(), toolFailureCount.get(), traceEmitter, |
| | | tenantId, userId, resolvedPromptCode); |
| | | } catch (Exception e) { |
| | | aiChatFailureHandler.handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), |
| | | aiChatFailureHandler.buildAiException("AI_INTERNAL_ERROR", AiErrorCategory.INTERNAL, "INTERNAL", |
| | | e == null ? "AI 对话失败" : e.getMessage(), e), |
| | | callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter, |
| | | callLogId, toolSuccessCount.get(), toolFailureCount.get(), traceEmitter, |
| | | tenantId, userId, resolvedPromptCode); |
| | | } finally { |
| | | log.debug("AI chat stream finished, requestId={}", requestId); |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.service.impl.chat; |
| | | |
| | | import com.vincent.rsf.server.ai.dto.AiChatTraceEventDto; |
| | | import org.springframework.util.StringUtils; |
| | | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; |
| | | |
| | | import java.time.Instant; |
| | | import java.util.Map; |
| | | import java.util.Objects; |
| | | import java.util.concurrent.ConcurrentHashMap; |
| | | import java.util.concurrent.atomic.AtomicLong; |
| | | |
| | | public class AiChatTraceEmitter { |
| | | |
| | | private static final String TRACE_EVENT_NAME = "trace"; |
| | | |
| | | private final AiSseEventPublisher aiSseEventPublisher; |
| | | private final SseEmitter emitter; |
| | | private final String requestId; |
| | | private final Long sessionId; |
| | | private final AtomicLong traceSequence; |
| | | private final Map<String, Long> traceOrderMap = new ConcurrentHashMap<>(); |
| | | private String currentPhase; |
| | | private String currentStatus; |
| | | |
| | | public AiChatTraceEmitter(AiSseEventPublisher aiSseEventPublisher, SseEmitter emitter, String requestId, |
| | | Long sessionId, AtomicLong traceSequence) { |
| | | this.aiSseEventPublisher = aiSseEventPublisher; |
| | | this.emitter = emitter; |
| | | this.requestId = requestId; |
| | | this.sessionId = sessionId; |
| | | this.traceSequence = traceSequence; |
| | | } |
| | | |
| | | public void startAnalyze() { |
| | | if (currentPhase != null) { |
| | | return; |
| | | } |
| | | currentPhase = "ANALYZE"; |
| | | currentStatus = "STARTED"; |
| | | emitThinkingEvent("ANALYZE", "STARTED", "正在分析问题", |
| | | "已接收你的问题,正在理解意图并判断是否需要调用工具。", null); |
| | | } |
| | | |
| | | public void onToolStart(String toolName, String mountName, String toolCallId, String inputSummary, long timestamp) { |
| | | completeCurrentPhase(); |
| | | emitToolEvent("STARTED", "开始调用工具", null, toolCallId, toolName, mountName, inputSummary, |
| | | null, null, null, timestamp); |
| | | } |
| | | |
| | | public void onToolResult(String toolName, String mountName, String toolCallId, String inputSummary, |
| | | String outputSummary, String errorMessage, Long durationMs, long timestamp, |
| | | boolean failed) { |
| | | emitToolEvent(failed ? "FAILED" : "COMPLETED", |
| | | failed ? "工具调用失败" : "工具调用完成", |
| | | null, |
| | | toolCallId, |
| | | toolName, |
| | | mountName, |
| | | inputSummary, |
| | | outputSummary, |
| | | errorMessage, |
| | | durationMs, |
| | | timestamp); |
| | | } |
| | | |
| | | public void startAnswer() { |
| | | switchPhase("ANSWER", "STARTED", "正在整理答案", "已完成分析,正在组织最终回复内容。", null); |
| | | } |
| | | |
| | | public void completeCurrentPhase() { |
| | | if (!StringUtils.hasText(currentPhase) || isTerminalStatus(currentStatus)) { |
| | | return; |
| | | } |
| | | currentStatus = "COMPLETED"; |
| | | emitThinkingEvent(currentPhase, "COMPLETED", resolveCompleteTitle(currentPhase), |
| | | resolveCompleteContent(currentPhase), null); |
| | | } |
| | | |
| | | public void markTerminated(String terminalStatus) { |
| | | if (!StringUtils.hasText(currentPhase) || isTerminalStatus(currentStatus)) { |
| | | return; |
| | | } |
| | | currentStatus = terminalStatus; |
| | | emitThinkingEvent(currentPhase, terminalStatus, |
| | | "ABORTED".equals(terminalStatus) ? "思考已中止" : "思考失败", |
| | | "ABORTED".equals(terminalStatus) |
| | | ? "本轮对话已被中止,思考过程提前结束。" |
| | | : "本轮对话在生成答案前失败,当前思考过程已停止。", |
| | | null); |
| | | } |
| | | |
| | | private void switchPhase(String nextPhase, String nextStatus, String title, String content, String toolCallId) { |
| | | if (!Objects.equals(currentPhase, nextPhase)) { |
| | | completeCurrentPhase(); |
| | | } |
| | | currentPhase = nextPhase; |
| | | currentStatus = nextStatus; |
| | | emitThinkingEvent(nextPhase, nextStatus, title, content, toolCallId); |
| | | } |
| | | |
| | | private void emitThinkingEvent(String phase, String status, String title, String content, String toolCallId) { |
| | | emitTraceEvent(AiChatTraceEventDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .traceType("thinking") |
| | | .phase(phase) |
| | | .status(status) |
| | | .title(title) |
| | | .content(content) |
| | | .toolCallId(toolCallId) |
| | | .timestamp(Instant.now().toEpochMilli()) |
| | | .build(), buildThinkingTraceId(phase)); |
| | | } |
| | | |
| | | private void emitToolEvent(String status, String title, String content, String toolCallId, String toolName, |
| | | String mountName, String inputSummary, String outputSummary, String errorMessage, |
| | | Long durationMs, long timestamp) { |
| | | emitTraceEvent(AiChatTraceEventDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .traceType("tool") |
| | | .status(status) |
| | | .title(title) |
| | | .content(content) |
| | | .toolCallId(toolCallId) |
| | | .toolName(toolName) |
| | | .mountName(mountName) |
| | | .inputSummary(inputSummary) |
| | | .outputSummary(outputSummary) |
| | | .errorMessage(errorMessage) |
| | | .durationMs(durationMs) |
| | | .timestamp(timestamp) |
| | | .build(), buildToolTraceId(toolCallId)); |
| | | } |
| | | |
| | | private void emitTraceEvent(AiChatTraceEventDto payload, String traceId) { |
| | | long sequence = traceOrderMap.computeIfAbsent(traceId, ignored -> traceSequence.incrementAndGet()); |
| | | payload.setSequence(sequence); |
| | | payload.setTraceId(traceId); |
| | | aiSseEventPublisher.emitSafely(emitter, TRACE_EVENT_NAME, payload); |
| | | } |
| | | |
| | | private String buildThinkingTraceId(String phase) { |
| | | return requestId + "-thinking-" + phase; |
| | | } |
| | | |
| | | private String buildToolTraceId(String toolCallId) { |
| | | return toolCallId; |
| | | } |
| | | |
| | | private boolean isTerminalStatus(String status) { |
| | | return "COMPLETED".equals(status) || "FAILED".equals(status) || "ABORTED".equals(status); |
| | | } |
| | | |
| | | private String resolveCompleteTitle(String phase) { |
| | | if ("ANSWER".equals(phase)) { |
| | | return "答案整理完成"; |
| | | } |
| | | if ("TOOL_CALL".equals(phase)) { |
| | | return "工具分析完成"; |
| | | } |
| | | return "问题分析完成"; |
| | | } |
| | | |
| | | private String resolveCompleteContent(String phase) { |
| | | if ("ANSWER".equals(phase)) { |
| | | return "最终答复已生成完成。"; |
| | | } |
| | | if ("TOOL_CALL".equals(phase)) { |
| | | return "工具调用阶段已结束,相关信息已整理完毕。"; |
| | | } |
| | | return "问题意图和处理方向已分析完成。"; |
| | | } |
| | | |
| | | private String safeLabel(String value, String fallback) { |
| | | return StringUtils.hasText(value) ? value : fallback; |
| | | } |
| | | } |
| | |
| | | private final ObjectMapper objectMapper; |
| | | |
| | | public void markFirstToken(AtomicReference<Long> firstTokenAtRef, SseEmitter emitter, String requestId, |
| | | Long sessionId, String model, long startedAt, AiThinkingTraceEmitter thinkingTraceEmitter) { |
| | | Long sessionId, String model, long startedAt, AiChatTraceEmitter traceEmitter) { |
| | | if (!firstTokenAtRef.compareAndSet(null, System.currentTimeMillis())) { |
| | | return; |
| | | } |
| | | if (thinkingTraceEmitter != null) { |
| | | thinkingTraceEmitter.startAnswer(); |
| | | if (traceEmitter != null) { |
| | | traceEmitter.startAnswer(); |
| | | } |
| | | emitSafely(emitter, "status", AiChatStatusDto.builder() |
| | | .requestId(requestId) |
| | |
| | | package com.vincent.rsf.server.ai.service.impl.chat; |
| | | |
| | | import com.vincent.rsf.framework.exception.CoolException; |
| | | import com.vincent.rsf.server.ai.dto.AiChatToolEventDto; |
| | | import com.vincent.rsf.server.ai.service.AiCallLogService; |
| | | import com.vincent.rsf.server.ai.service.MountedToolCallback; |
| | | import com.vincent.rsf.server.ai.store.AiCachedToolResult; |
| | |
| | | import org.springframework.ai.tool.ToolCallback; |
| | | import org.springframework.stereotype.Component; |
| | | import org.springframework.util.StringUtils; |
| | | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; |
| | | |
| | | import java.util.ArrayList; |
| | | import java.util.List; |
| | |
| | | @RequiredArgsConstructor |
| | | public class AiToolObservationService { |
| | | |
| | | private final AiSseEventPublisher aiSseEventPublisher; |
| | | private final AiToolResultStore aiToolResultStore; |
| | | private final AiCallLogService aiCallLogService; |
| | | |
| | | public ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, SseEmitter emitter, String requestId, |
| | | public ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, String requestId, |
| | | Long sessionId, AtomicLong toolCallSequence, |
| | | AtomicLong toolSuccessCount, AtomicLong toolFailureCount, |
| | | Long callLogId, Long userId, Long tenantId, |
| | | AiThinkingTraceEmitter thinkingTraceEmitter) { |
| | | AiChatTraceEmitter traceEmitter) { |
| | | if (toolCallbacks == null || toolCallbacks.length == 0) { |
| | | return toolCallbacks; |
| | | } |
| | |
| | | if (callback == null) { |
| | | continue; |
| | | } |
| | | wrappedCallbacks.add(new ObservableToolCallback(callback, emitter, requestId, sessionId, toolCallSequence, |
| | | toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, thinkingTraceEmitter)); |
| | | wrappedCallbacks.add(new ObservableToolCallback(callback, requestId, sessionId, toolCallSequence, |
| | | toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, traceEmitter)); |
| | | } |
| | | return wrappedCallbacks.toArray(new ToolCallback[0]); |
| | | } |
| | |
| | | private class ObservableToolCallback implements ToolCallback { |
| | | |
| | | private final ToolCallback delegate; |
| | | private final SseEmitter emitter; |
| | | private final String requestId; |
| | | private final Long sessionId; |
| | | private final AtomicLong toolCallSequence; |
| | |
| | | private final Long callLogId; |
| | | private final Long userId; |
| | | private final Long tenantId; |
| | | private final AiThinkingTraceEmitter thinkingTraceEmitter; |
| | | private final AiChatTraceEmitter traceEmitter; |
| | | |
| | | private ObservableToolCallback(ToolCallback delegate, SseEmitter emitter, String requestId, |
| | | private ObservableToolCallback(ToolCallback delegate, String requestId, |
| | | Long sessionId, AtomicLong toolCallSequence, |
| | | AtomicLong toolSuccessCount, AtomicLong toolFailureCount, |
| | | Long callLogId, Long userId, Long tenantId, |
| | | AiThinkingTraceEmitter thinkingTraceEmitter) { |
| | | AiChatTraceEmitter traceEmitter) { |
| | | this.delegate = delegate; |
| | | this.emitter = emitter; |
| | | this.requestId = requestId; |
| | | this.sessionId = sessionId; |
| | | this.toolCallSequence = toolCallSequence; |
| | |
| | | this.callLogId = callLogId; |
| | | this.userId = userId; |
| | | this.tenantId = tenantId; |
| | | this.thinkingTraceEmitter = thinkingTraceEmitter; |
| | | this.traceEmitter = traceEmitter; |
| | | } |
| | | |
| | | @Override |
| | |
| | | String mountName = delegate instanceof MountedToolCallback ? ((MountedToolCallback) delegate).getMountName() : null; |
| | | String toolCallId = requestId + "-tool-" + toolCallSequence.incrementAndGet(); |
| | | long startedAt = System.currentTimeMillis(); |
| | | String inputSummary = summarizeToolPayload(toolInput, 400); |
| | | AiCachedToolResult cachedToolResult = aiToolResultStore.getToolResult(tenantId, requestId, toolName, toolInput); |
| | | if (cachedToolResult != null) { |
| | | aiSseEventPublisher.emitSafely(emitter, "tool_result", AiChatToolEventDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .toolCallId(toolCallId) |
| | | .toolName(toolName) |
| | | .mountName(mountName) |
| | | .status(cachedToolResult.isSuccess() ? "COMPLETED" : "FAILED") |
| | | .inputSummary(summarizeToolPayload(toolInput, 400)) |
| | | .outputSummary(summarizeToolPayload(cachedToolResult.getOutput(), 600)) |
| | | .errorMessage(cachedToolResult.getErrorMessage()) |
| | | .durationMs(0L) |
| | | .timestamp(System.currentTimeMillis()) |
| | | .build()); |
| | | if (thinkingTraceEmitter != null) { |
| | | thinkingTraceEmitter.onToolResult(toolName, toolCallId, !cachedToolResult.isSuccess()); |
| | | String outputSummary = summarizeToolPayload(cachedToolResult.getOutput(), 600); |
| | | String errorMessage = cachedToolResult.getErrorMessage(); |
| | | if (traceEmitter != null) { |
| | | traceEmitter.onToolResult(toolName, mountName, toolCallId, inputSummary, outputSummary, |
| | | errorMessage, 0L, System.currentTimeMillis(), !cachedToolResult.isSuccess()); |
| | | } |
| | | if (cachedToolResult.isSuccess()) { |
| | | toolSuccessCount.incrementAndGet(); |
| | | aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, |
| | | "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(cachedToolResult.getOutput(), 600), |
| | | "COMPLETED", inputSummary, outputSummary, |
| | | null, 0L, userId, tenantId); |
| | | return cachedToolResult.getOutput(); |
| | | } |
| | | toolFailureCount.incrementAndGet(); |
| | | aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, |
| | | "FAILED", summarizeToolPayload(toolInput, 400), null, cachedToolResult.getErrorMessage(), |
| | | "FAILED", inputSummary, null, errorMessage, |
| | | 0L, userId, tenantId); |
| | | throw new CoolException(cachedToolResult.getErrorMessage()); |
| | | throw new CoolException(errorMessage); |
| | | } |
| | | if (thinkingTraceEmitter != null) { |
| | | thinkingTraceEmitter.onToolStart(toolName, toolCallId); |
| | | if (traceEmitter != null) { |
| | | traceEmitter.onToolStart(toolName, mountName, toolCallId, inputSummary, startedAt); |
| | | } |
| | | aiSseEventPublisher.emitSafely(emitter, "tool_start", AiChatToolEventDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .toolCallId(toolCallId) |
| | | .toolName(toolName) |
| | | .mountName(mountName) |
| | | .status("STARTED") |
| | | .inputSummary(summarizeToolPayload(toolInput, 400)) |
| | | .timestamp(startedAt) |
| | | .build()); |
| | | try { |
| | | String output = toolContext == null ? delegate.call(toolInput) : delegate.call(toolInput, toolContext); |
| | | long durationMs = System.currentTimeMillis() - startedAt; |
| | | aiSseEventPublisher.emitSafely(emitter, "tool_result", AiChatToolEventDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .toolCallId(toolCallId) |
| | | .toolName(toolName) |
| | | .mountName(mountName) |
| | | .status("COMPLETED") |
| | | .inputSummary(summarizeToolPayload(toolInput, 400)) |
| | | .outputSummary(summarizeToolPayload(output, 600)) |
| | | .durationMs(durationMs) |
| | | .timestamp(System.currentTimeMillis()) |
| | | .build()); |
| | | if (thinkingTraceEmitter != null) { |
| | | thinkingTraceEmitter.onToolResult(toolName, toolCallId, false); |
| | | String outputSummary = summarizeToolPayload(output, 600); |
| | | if (traceEmitter != null) { |
| | | traceEmitter.onToolResult(toolName, mountName, toolCallId, inputSummary, outputSummary, |
| | | null, durationMs, System.currentTimeMillis(), false); |
| | | } |
| | | aiToolResultStore.cacheToolResult(tenantId, requestId, toolName, toolInput, true, output, null); |
| | | toolSuccessCount.incrementAndGet(); |
| | | aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, |
| | | "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(output, 600), |
| | | "COMPLETED", inputSummary, outputSummary, |
| | | null, durationMs, userId, tenantId); |
| | | return output; |
| | | } catch (RuntimeException e) { |
| | | long durationMs = System.currentTimeMillis() - startedAt; |
| | | aiSseEventPublisher.emitSafely(emitter, "tool_error", AiChatToolEventDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .toolCallId(toolCallId) |
| | | .toolName(toolName) |
| | | .mountName(mountName) |
| | | .status("FAILED") |
| | | .inputSummary(summarizeToolPayload(toolInput, 400)) |
| | | .errorMessage(e.getMessage()) |
| | | .durationMs(durationMs) |
| | | .timestamp(System.currentTimeMillis()) |
| | | .build()); |
| | | if (thinkingTraceEmitter != null) { |
| | | thinkingTraceEmitter.onToolResult(toolName, toolCallId, true); |
| | | String errorMessage = e.getMessage(); |
| | | if (traceEmitter != null) { |
| | | traceEmitter.onToolResult(toolName, mountName, toolCallId, inputSummary, null, |
| | | errorMessage, durationMs, System.currentTimeMillis(), true); |
| | | } |
| | | aiToolResultStore.cacheToolResult(tenantId, requestId, toolName, toolInput, false, null, e.getMessage()); |
| | | aiToolResultStore.cacheToolResult(tenantId, requestId, toolName, toolInput, false, null, errorMessage); |
| | | toolFailureCount.incrementAndGet(); |
| | | aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, |
| | | "FAILED", summarizeToolPayload(toolInput, 400), null, e.getMessage(), |
| | | "FAILED", inputSummary, null, errorMessage, |
| | | durationMs, userId, tenantId); |
| | | throw e; |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.service.impl.chat; |
| | | |
| | | import com.vincent.rsf.server.ai.dto.AiChatTraceEventDto; |
| | | import org.junit.jupiter.api.Test; |
| | | import org.junit.jupiter.api.extension.ExtendWith; |
| | | import org.mockito.ArgumentCaptor; |
| | | import org.mockito.Captor; |
| | | import org.mockito.Mock; |
| | | import org.mockito.junit.jupiter.MockitoExtension; |
| | | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; |
| | | |
| | | import java.util.List; |
| | | import java.util.concurrent.atomic.AtomicLong; |
| | | |
| | | import static org.assertj.core.api.Assertions.assertThat; |
| | | import static org.mockito.ArgumentMatchers.any; |
| | | import static org.mockito.ArgumentMatchers.eq; |
| | | import static org.mockito.Mockito.times; |
| | | import static org.mockito.Mockito.verify; |
| | | |
| | | @ExtendWith(MockitoExtension.class) |
| | | class AiChatTraceEmitterTest { |
| | | |
| | | @Mock |
| | | private AiSseEventPublisher aiSseEventPublisher; |
| | | |
| | | @Captor |
| | | private ArgumentCaptor<AiChatTraceEventDto> payloadCaptor; |
| | | |
| | | @Test |
| | | void shouldReuseTraceIdentityForLogicalTraceCards() { |
| | | AiChatTraceEmitter traceEmitter = new AiChatTraceEmitter( |
| | | aiSseEventPublisher, |
| | | new SseEmitter(1000L), |
| | | "req-1", |
| | | 11L, |
| | | new AtomicLong(0) |
| | | ); |
| | | |
| | | traceEmitter.startAnalyze(); |
| | | traceEmitter.onToolStart("inventory.lookup", "builtin-stock", "tool-1", "{\"code\":\"A01\"}", 100L); |
| | | traceEmitter.onToolResult("inventory.lookup", "builtin-stock", "tool-1", "{\"code\":\"A01\"}", |
| | | "{\"stock\":12}", null, 64L, 164L, false); |
| | | |
| | | verify(aiSseEventPublisher, times(4)).emitSafely(any(), eq("trace"), payloadCaptor.capture()); |
| | | |
| | | List<AiChatTraceEventDto> payloads = payloadCaptor.getAllValues(); |
| | | assertThat(payloads).extracting(AiChatTraceEventDto::getSequence) |
| | | .containsExactly(1L, 1L, 2L, 2L); |
| | | assertThat(payloads).extracting(AiChatTraceEventDto::getTraceId) |
| | | .containsExactly("req-1-thinking-ANALYZE", "req-1-thinking-ANALYZE", "tool-1", "tool-1"); |
| | | assertThat(payloads).extracting(AiChatTraceEventDto::getTraceType) |
| | | .containsExactly("thinking", "thinking", "tool", "tool"); |
| | | assertThat(payloads.get(1).getStatus()).isEqualTo("COMPLETED"); |
| | | assertThat(payloads.get(3).getStatus()).isEqualTo("COMPLETED"); |
| | | assertThat(payloads.get(3).getToolCallId()).isEqualTo("tool-1"); |
| | | assertThat(payloads.get(3).getInputSummary()).isEqualTo("{\"code\":\"A01\"}"); |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.service.impl.chat; |
| | | |
| | | import com.vincent.rsf.server.ai.service.AiCallLogService; |
| | | import com.vincent.rsf.server.ai.service.MountedToolCallback; |
| | | import com.vincent.rsf.server.ai.store.AiCachedToolResult; |
| | | import com.vincent.rsf.server.ai.store.AiToolResultStore; |
| | | import org.junit.jupiter.api.Test; |
| | | import org.junit.jupiter.api.extension.ExtendWith; |
| | | import org.mockito.Mock; |
| | | import org.mockito.junit.jupiter.MockitoExtension; |
| | | import org.springframework.ai.tool.ToolCallback; |
| | | import org.springframework.ai.tool.definition.ToolDefinition; |
| | | |
| | | import java.util.concurrent.atomic.AtomicLong; |
| | | |
| | | import static org.assertj.core.api.Assertions.assertThat; |
| | | import static org.mockito.ArgumentMatchers.anyLong; |
| | | import static org.mockito.ArgumentMatchers.eq; |
| | | import static org.mockito.Mockito.never; |
| | | import static org.mockito.Mockito.verify; |
| | | import static org.mockito.Mockito.when; |
| | | |
| | | @ExtendWith(MockitoExtension.class) |
| | | class AiToolObservationServiceTest { |
| | | |
| | | @Mock |
| | | private AiToolResultStore aiToolResultStore; |
| | | @Mock |
| | | private AiCallLogService aiCallLogService; |
| | | @Mock |
| | | private AiChatTraceEmitter traceEmitter; |
| | | @Mock |
| | | private MountedToolCallback mountedToolCallback; |
| | | @Mock |
| | | private ToolDefinition toolDefinition; |
| | | |
| | | @Test |
| | | void shouldSkipStartTraceWhenToolResultComesFromCache() { |
| | | AiToolObservationService aiToolObservationService = new AiToolObservationService(aiToolResultStore, aiCallLogService); |
| | | when(mountedToolCallback.getToolDefinition()).thenReturn(toolDefinition); |
| | | when(toolDefinition.name()).thenReturn("inventory.lookup"); |
| | | when(mountedToolCallback.getMountName()).thenReturn("builtin-stock"); |
| | | when(aiToolResultStore.getToolResult(1L, "req-1", "inventory.lookup", "{\"code\":\"A01\"}")) |
| | | .thenReturn(AiCachedToolResult.builder() |
| | | .success(true) |
| | | .output("{\"stock\":12}") |
| | | .build()); |
| | | |
| | | ToolCallback[] callbacks = aiToolObservationService.wrapToolCallbacks( |
| | | new ToolCallback[]{mountedToolCallback}, |
| | | "req-1", |
| | | 11L, |
| | | new AtomicLong(0), |
| | | new AtomicLong(0), |
| | | new AtomicLong(0), |
| | | 21L, |
| | | 31L, |
| | | 1L, |
| | | traceEmitter |
| | | ); |
| | | |
| | | String output = callbacks[0].call("{\"code\":\"A01\"}"); |
| | | |
| | | assertThat(output).isEqualTo("{\"stock\":12}"); |
| | | verify(traceEmitter, never()).onToolStart(eq("inventory.lookup"), eq("builtin-stock"), eq("req-1-tool-1"), eq("{\"code\":\"A01\"}"), anyLong()); |
| | | verify(traceEmitter).onToolResult(eq("inventory.lookup"), eq("builtin-stock"), eq("req-1-tool-1"), |
| | | eq("{\"code\":\"A01\"}"), eq("{\"stock\":12}"), eq(null), eq(0L), anyLong(), eq(false)); |
| | | verify(mountedToolCallback, never()).call("{\"code\":\"A01\"}"); |
| | | } |
| | | } |