From 6477d7156272a6f1fe126c781958369bb10970c6 Mon Sep 17 00:00:00 2001
From: zhou zhou <3272660260@qq.com>
Date: 星期六, 21 三月 2026 11:15:50 +0800
Subject: [PATCH] #ai 思维链

---
 rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java |  170 +++++++++++++++++++++++++++++++++++++++++++++++++++++---
 1 files changed, 160 insertions(+), 10 deletions(-)

diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
index 40d5594..e7d842e 100644
--- a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
@@ -14,6 +14,7 @@
 import com.vincent.rsf.server.ai.dto.AiChatSessionDto;
 import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest;
 import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest;
+import com.vincent.rsf.server.ai.dto.AiChatThinkingEventDto;
 import com.vincent.rsf.server.ai.dto.AiChatToolEventDto;
 import com.vincent.rsf.server.ai.dto.AiResolvedConfig;
 import com.vincent.rsf.server.ai.entity.AiCallLog;
@@ -173,6 +174,7 @@
         Long sessionId = request.getSessionId();
         Long callLogId = null;
         String model = null;
+        ThinkingTraceEmitter thinkingTraceEmitter = null;
         try {
             ensureIdentity(userId, tenantId);
             AiResolvedConfig config = resolveConfig(request, tenantId);
@@ -221,10 +223,13 @@
                         .build());
                 log.info("AI chat started, requestId={}, userId={}, tenantId={}, sessionId={}, model={}",
                         requestId, userId, tenantId, session.getId(), resolvedModel);
+                thinkingTraceEmitter = new ThinkingTraceEmitter(emitter, requestId, session.getId());
+                thinkingTraceEmitter.startAnalyze();
 
+                ThinkingTraceEmitter activeThinkingTraceEmitter = thinkingTraceEmitter;
                 ToolCallback[] observableToolCallbacks = wrapToolCallbacks(
                         runtime.getToolCallbacks(), emitter, requestId, session.getId(), toolCallSequence,
-                        toolSuccessCount, toolFailureCount, callLogId, userId, tenantId
+                        toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, activeThinkingTraceEmitter
                 );
                 Prompt prompt = new Prompt(
                         buildPromptMessages(memory, mergedMessages, config.getPrompt(), request.getMetadata()),
@@ -237,9 +242,10 @@
                     String content = extractContent(response);
                     aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), content);
                     if (StringUtils.hasText(content)) {
-                        markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt);
+                        markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt, activeThinkingTraceEmitter);
                         emitStrict(emitter, "delta", buildMessagePayload("requestId", requestId, "content", content));
                     }
+                    activeThinkingTraceEmitter.completeCurrentPhase();
                     emitDone(emitter, requestId, response.getMetadata(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get());
                     emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get()));
                     aiCallLogService.completeCallLog(
@@ -267,7 +273,7 @@
                             lastMetadata.set(response.getMetadata());
                             String content = extractContent(response);
                             if (StringUtils.hasText(content)) {
-                                markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt);
+                                markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt, activeThinkingTraceEmitter);
                                 assistantContent.append(content);
                                 emitStrict(emitter, "delta", buildMessagePayload("requestId", requestId, "content", content));
                             }
@@ -278,6 +284,7 @@
                             e == null ? "AI 妯″瀷娴佸紡璋冪敤澶辫触" : e.getMessage(), e);
                 }
                 aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), assistantContent.toString());
+                activeThinkingTraceEmitter.completeCurrentPhase();
                 emitDone(emitter, requestId, lastMetadata.get(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get());
                 emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get()));
                 aiCallLogService.completeCallLog(
@@ -297,12 +304,12 @@
             }
         } catch (AiChatException e) {
             handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e,
-                    callLogId, toolSuccessCount.get(), toolFailureCount.get());
+                    callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter);
         } catch (Exception e) {
             handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(),
                     buildAiException("AI_INTERNAL_ERROR", AiErrorCategory.INTERNAL, "INTERNAL",
                             e == null ? "AI 瀵硅瘽澶辫触" : e.getMessage(), e),
-                    callLogId, toolSuccessCount.get(), toolFailureCount.get());
+                    callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter);
         } finally {
             log.debug("AI chat stream finished, requestId={}", requestId);
         }
@@ -375,9 +382,13 @@
         }
     }
 
-    private void markFirstToken(AtomicReference<Long> firstTokenAtRef, SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt) {
+    private void markFirstToken(AtomicReference<Long> firstTokenAtRef, SseEmitter emitter, String requestId,
+                                Long sessionId, String model, long startedAt, ThinkingTraceEmitter thinkingTraceEmitter) {
         if (!firstTokenAtRef.compareAndSet(null, System.currentTimeMillis())) {
             return;
+        }
+        if (thinkingTraceEmitter != null) {
+            thinkingTraceEmitter.startAnswer();
         }
         emitSafely(emitter, "status", AiChatStatusDto.builder()
                 .requestId(requestId)
@@ -408,10 +419,14 @@
 
     private void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt,
                                      Long firstTokenAt, AiChatException exception, Long callLogId,
-                                     long toolSuccessCount, long toolFailureCount) {
+                                     long toolSuccessCount, long toolFailureCount,
+                                     ThinkingTraceEmitter thinkingTraceEmitter) {
         if (isClientAbortException(exception)) {
             log.warn("AI chat aborted by client, requestId={}, sessionId={}, stage={}, message={}",
                     requestId, sessionId, exception.getStage(), exception.getMessage());
+            if (thinkingTraceEmitter != null) {
+                thinkingTraceEmitter.markTerminated("ABORTED");
+            }
             emitSafely(emitter, "status", buildTerminalStatus(requestId, sessionId, "ABORTED", model, startedAt, firstTokenAt));
             aiCallLogService.failCallLog(
                     callLogId,
@@ -429,6 +444,9 @@
         }
         log.error("AI chat failed, requestId={}, sessionId={}, category={}, stage={}, message={}",
                 requestId, sessionId, exception.getCategory(), exception.getStage(), exception.getMessage(), exception);
+        if (thinkingTraceEmitter != null) {
+            thinkingTraceEmitter.markTerminated("FAILED");
+        }
         emitSafely(emitter, "status", buildTerminalStatus(requestId, sessionId, "FAILED", model, startedAt, firstTokenAt));
         emitSafely(emitter, "error", AiChatErrorDto.builder()
                 .requestId(requestId)
@@ -521,7 +539,8 @@
     private ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, SseEmitter emitter, String requestId,
                                              Long sessionId, AtomicLong toolCallSequence,
                                              AtomicLong toolSuccessCount, AtomicLong toolFailureCount,
-                                             Long callLogId, Long userId, Long tenantId) {
+                                             Long callLogId, Long userId, Long tenantId,
+                                             ThinkingTraceEmitter thinkingTraceEmitter) {
         /** 缁欐墍鏈夊伐鍏峰洖璋冨涓婁竴灞傚彲瑙傛祴鍖呰锛岀敤浜庡疄鏃� SSE 杞ㄨ抗鍜屽璁℃棩蹇楄惤搴撱�� */
         if (Cools.isEmpty(toolCallbacks)) {
             return toolCallbacks;
@@ -532,7 +551,7 @@
                 continue;
             }
             wrappedCallbacks.add(new ObservableToolCallback(callback, emitter, requestId, sessionId, toolCallSequence,
-                    toolSuccessCount, toolFailureCount, callLogId, userId, tenantId));
+                    toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, thinkingTraceEmitter));
         }
         return wrappedCallbacks.toArray(new ToolCallback[0]);
     }
@@ -725,6 +744,125 @@
         return false;
     }
 
+    private class ThinkingTraceEmitter {
+
+        private final SseEmitter emitter;
+        private final String requestId;
+        private final Long sessionId;
+        private String currentPhase;
+        private String currentStatus;
+
+        private ThinkingTraceEmitter(SseEmitter emitter, String requestId, Long sessionId) {
+            this.emitter = emitter;
+            this.requestId = requestId;
+            this.sessionId = sessionId;
+        }
+
+        private void startAnalyze() {
+            if (currentPhase != null) {
+                return;
+            }
+            currentPhase = "ANALYZE";
+            currentStatus = "STARTED";
+            emitThinkingEvent("ANALYZE", "STARTED", "姝e湪鍒嗘瀽闂",
+                    "宸叉帴鏀朵綘鐨勯棶棰橈紝姝e湪鐞嗚В鎰忓浘骞跺垽鏂槸鍚﹂渶瑕佽皟鐢ㄥ伐鍏枫��", null);
+        }
+
+        private void onToolStart(String toolName, String toolCallId) {
+            switchPhase("TOOL_CALL", "STARTED", "姝e湪璋冪敤宸ュ叿", "宸插垽鏂渶瑕佽皟鐢ㄥ伐鍏凤紝姝e湪鏌ヨ鐩稿叧淇℃伅銆�", null);
+            currentStatus = "UPDATED";
+            emitThinkingEvent("TOOL_CALL", "UPDATED", "姝e湪璋冪敤宸ュ叿",
+                    "姝e湪璋冪敤宸ュ叿 " + safeLabel(toolName, "鏈煡宸ュ叿") + " 鑾峰彇鎵�闇�淇℃伅銆�", toolCallId);
+        }
+
+        private void onToolResult(String toolName, String toolCallId, boolean failed) {
+            currentPhase = "TOOL_CALL";
+            currentStatus = failed ? "FAILED" : "UPDATED";
+            emitThinkingEvent("TOOL_CALL", failed ? "FAILED" : "UPDATED",
+                    failed ? "宸ュ叿璋冪敤澶辫触" : "宸ュ叿璋冪敤瀹屾垚",
+                    failed
+                            ? "宸ュ叿 " + safeLabel(toolName, "鏈煡宸ュ叿") + " 璋冪敤澶辫触锛屾鍦ㄨ瘎浼板け璐ュ奖鍝嶅苟鏁寸悊鍙敤淇℃伅銆�"
+                            : "宸ュ叿 " + safeLabel(toolName, "鏈煡宸ュ叿") + " 宸茶繑鍥炵粨鏋滐紝姝e湪缁х画鍒嗘瀽骞舵彁鐐煎叧閿俊鎭��",
+                    toolCallId);
+        }
+
+        private void startAnswer() {
+            switchPhase("ANSWER", "STARTED", "姝e湪鏁寸悊绛旀", "宸插畬鎴愬垎鏋愶紝姝e湪缁勭粐鏈�缁堝洖澶嶅唴瀹广��", null);
+        }
+
+        private void completeCurrentPhase() {
+            if (!StringUtils.hasText(currentPhase) || isTerminalStatus(currentStatus)) {
+                return;
+            }
+            currentStatus = "COMPLETED";
+            emitThinkingEvent(currentPhase, "COMPLETED", resolveCompleteTitle(currentPhase),
+                    resolveCompleteContent(currentPhase), null);
+        }
+
+        private void markTerminated(String terminalStatus) {
+            if (!StringUtils.hasText(currentPhase) || isTerminalStatus(currentStatus)) {
+                return;
+            }
+            currentStatus = terminalStatus;
+            emitThinkingEvent(currentPhase, terminalStatus,
+                    "ABORTED".equals(terminalStatus) ? "鎬濊�冨凡涓" : "鎬濊�冨け璐�",
+                    "ABORTED".equals(terminalStatus)
+                            ? "鏈疆瀵硅瘽宸茶涓锛屾�濊�冭繃绋嬫彁鍓嶇粨鏉熴��"
+                            : "鏈疆瀵硅瘽鍦ㄧ敓鎴愮瓟妗堝墠澶辫触锛屽綋鍓嶆�濊�冭繃绋嬪凡鍋滄銆�",
+                    null);
+        }
+
+        private void switchPhase(String nextPhase, String nextStatus, String title, String content, String toolCallId) {
+            if (!Objects.equals(currentPhase, nextPhase)) {
+                completeCurrentPhase();
+            }
+            currentPhase = nextPhase;
+            currentStatus = nextStatus;
+            emitThinkingEvent(nextPhase, nextStatus, title, content, toolCallId);
+        }
+
+        private void emitThinkingEvent(String phase, String status, String title, String content, String toolCallId) {
+            emitSafely(emitter, "thinking", AiChatThinkingEventDto.builder()
+                    .requestId(requestId)
+                    .sessionId(sessionId)
+                    .phase(phase)
+                    .status(status)
+                    .title(title)
+                    .content(content)
+                    .toolCallId(toolCallId)
+                    .timestamp(Instant.now().toEpochMilli())
+                    .build());
+        }
+
+        private boolean isTerminalStatus(String status) {
+            return "COMPLETED".equals(status) || "FAILED".equals(status) || "ABORTED".equals(status);
+        }
+
+        private String resolveCompleteTitle(String phase) {
+            if ("ANSWER".equals(phase)) {
+                return "绛旀鏁寸悊瀹屾垚";
+            }
+            if ("TOOL_CALL".equals(phase)) {
+                return "宸ュ叿鍒嗘瀽瀹屾垚";
+            }
+            return "闂鍒嗘瀽瀹屾垚";
+        }
+
+        private String resolveCompleteContent(String phase) {
+            if ("ANSWER".equals(phase)) {
+                return "鏈�缁堢瓟澶嶅凡鐢熸垚瀹屾垚銆�";
+            }
+            if ("TOOL_CALL".equals(phase)) {
+                return "宸ュ叿璋冪敤闃舵宸茬粨鏉燂紝鐩稿叧淇℃伅宸叉暣鐞嗗畬姣曘��";
+            }
+            return "闂鎰忓浘鍜屽鐞嗘柟鍚戝凡鍒嗘瀽瀹屾垚銆�";
+        }
+
+        private String safeLabel(String value, String fallback) {
+            return StringUtils.hasText(value) ? value : fallback;
+        }
+    }
+
     private class ObservableToolCallback implements ToolCallback {
 
         private final ToolCallback delegate;
@@ -737,11 +875,13 @@
         private final Long callLogId;
         private final Long userId;
         private final Long tenantId;
+        private final ThinkingTraceEmitter thinkingTraceEmitter;
 
         private ObservableToolCallback(ToolCallback delegate, SseEmitter emitter, String requestId,
                                        Long sessionId, AtomicLong toolCallSequence,
                                        AtomicLong toolSuccessCount, AtomicLong toolFailureCount,
-                                       Long callLogId, Long userId, Long tenantId) {
+                                       Long callLogId, Long userId, Long tenantId,
+                                       ThinkingTraceEmitter thinkingTraceEmitter) {
             this.delegate = delegate;
             this.emitter = emitter;
             this.requestId = requestId;
@@ -752,6 +892,7 @@
             this.callLogId = callLogId;
             this.userId = userId;
             this.tenantId = tenantId;
+            this.thinkingTraceEmitter = thinkingTraceEmitter;
         }
 
         @Override
@@ -780,6 +921,9 @@
             String mountName = delegate instanceof MountedToolCallback ? ((MountedToolCallback) delegate).getMountName() : null;
             String toolCallId = requestId + "-tool-" + toolCallSequence.incrementAndGet();
             long startedAt = System.currentTimeMillis();
+            if (thinkingTraceEmitter != null) {
+                thinkingTraceEmitter.onToolStart(toolName, toolCallId);
+            }
             emitSafely(emitter, "tool_start", AiChatToolEventDto.builder()
                     .requestId(requestId)
                     .sessionId(sessionId)
@@ -805,6 +949,9 @@
                         .durationMs(durationMs)
                         .timestamp(System.currentTimeMillis())
                         .build());
+                if (thinkingTraceEmitter != null) {
+                    thinkingTraceEmitter.onToolResult(toolName, toolCallId, false);
+                }
                 toolSuccessCount.incrementAndGet();
                 aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
                         "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(output, 600),
@@ -824,6 +971,9 @@
                         .durationMs(durationMs)
                         .timestamp(System.currentTimeMillis())
                         .build());
+                if (thinkingTraceEmitter != null) {
+                    thinkingTraceEmitter.onToolResult(toolName, toolCallId, true);
+                }
                 toolFailureCount.incrementAndGet();
                 aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
                         "FAILED", summarizeToolPayload(toolInput, 400), null, e.getMessage(),

--
Gitblit v1.9.1