From 6477d7156272a6f1fe126c781958369bb10970c6 Mon Sep 17 00:00:00 2001
From: zhou zhou <3272660260@qq.com>
Date: Sat, 21 Mar 2026 11:15:50 +0800
Subject: [PATCH] #ai 思维链
---
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java | 185 ++++++++++++++++++++++++++++++++++++++++------
1 files changed, 161 insertions(+), 24 deletions(-)
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
index 0430123..e7d842e 100644
--- a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
@@ -14,6 +14,7 @@
import com.vincent.rsf.server.ai.dto.AiChatSessionDto;
import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest;
import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest;
+import com.vincent.rsf.server.ai.dto.AiChatThinkingEventDto;
import com.vincent.rsf.server.ai.dto.AiChatToolEventDto;
import com.vincent.rsf.server.ai.dto.AiResolvedConfig;
import com.vincent.rsf.server.ai.entity.AiCallLog;
@@ -51,12 +52,9 @@
import org.springframework.ai.util.json.schema.SchemaType;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.http.MediaType;
-import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
-import org.springframework.web.client.RestClient;
-import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import reactor.core.publisher.Flux;
@@ -176,6 +174,7 @@
Long sessionId = request.getSessionId();
Long callLogId = null;
String model = null;
+ ThinkingTraceEmitter thinkingTraceEmitter = null;
try {
ensureIdentity(userId, tenantId);
AiResolvedConfig config = resolveConfig(request, tenantId);
@@ -224,10 +223,13 @@
.build());
log.info("AI chat started, requestId={}, userId={}, tenantId={}, sessionId={}, model={}",
requestId, userId, tenantId, session.getId(), resolvedModel);
+ thinkingTraceEmitter = new ThinkingTraceEmitter(emitter, requestId, session.getId());
+ thinkingTraceEmitter.startAnalyze();
+ ThinkingTraceEmitter activeThinkingTraceEmitter = thinkingTraceEmitter;
ToolCallback[] observableToolCallbacks = wrapToolCallbacks(
runtime.getToolCallbacks(), emitter, requestId, session.getId(), toolCallSequence,
- toolSuccessCount, toolFailureCount, callLogId, userId, tenantId
+ toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, activeThinkingTraceEmitter
);
Prompt prompt = new Prompt(
buildPromptMessages(memory, mergedMessages, config.getPrompt(), request.getMetadata()),
@@ -240,9 +242,10 @@
String content = extractContent(response);
aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), content);
if (StringUtils.hasText(content)) {
- markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt);
+ markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt, activeThinkingTraceEmitter);
emitStrict(emitter, "delta", buildMessagePayload("requestId", requestId, "content", content));
}
+ activeThinkingTraceEmitter.completeCurrentPhase();
emitDone(emitter, requestId, response.getMetadata(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get());
emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get()));
aiCallLogService.completeCallLog(
@@ -270,7 +273,7 @@
lastMetadata.set(response.getMetadata());
String content = extractContent(response);
if (StringUtils.hasText(content)) {
- markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt);
+ markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt, activeThinkingTraceEmitter);
assistantContent.append(content);
emitStrict(emitter, "delta", buildMessagePayload("requestId", requestId, "content", content));
}
@@ -281,6 +284,7 @@
e == null ? "AI 妯″瀷娴佸紡璋冪敤澶辫触" : e.getMessage(), e);
}
aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), assistantContent.toString());
+ activeThinkingTraceEmitter.completeCurrentPhase();
emitDone(emitter, requestId, lastMetadata.get(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get());
emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get()));
aiCallLogService.completeCallLog(
@@ -300,12 +304,12 @@
}
} catch (AiChatException e) {
handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e,
- callLogId, toolSuccessCount.get(), toolFailureCount.get());
+ callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter);
} catch (Exception e) {
handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(),
buildAiException("AI_INTERNAL_ERROR", AiErrorCategory.INTERNAL, "INTERNAL",
e == null ? "AI 瀵硅瘽澶辫触" : e.getMessage(), e),
- callLogId, toolSuccessCount.get(), toolFailureCount.get());
+ callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter);
} finally {
log.debug("AI chat stream finished, requestId={}", requestId);
}
@@ -378,9 +382,13 @@
}
}
- private void markFirstToken(AtomicReference<Long> firstTokenAtRef, SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt) {
+ private void markFirstToken(AtomicReference<Long> firstTokenAtRef, SseEmitter emitter, String requestId,
+ Long sessionId, String model, long startedAt, ThinkingTraceEmitter thinkingTraceEmitter) {
if (!firstTokenAtRef.compareAndSet(null, System.currentTimeMillis())) {
return;
+ }
+ if (thinkingTraceEmitter != null) {
+ thinkingTraceEmitter.startAnswer();
}
emitSafely(emitter, "status", AiChatStatusDto.builder()
.requestId(requestId)
@@ -411,10 +419,14 @@
private void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt,
Long firstTokenAt, AiChatException exception, Long callLogId,
- long toolSuccessCount, long toolFailureCount) {
+ long toolSuccessCount, long toolFailureCount,
+ ThinkingTraceEmitter thinkingTraceEmitter) {
if (isClientAbortException(exception)) {
log.warn("AI chat aborted by client, requestId={}, sessionId={}, stage={}, message={}",
requestId, sessionId, exception.getStage(), exception.getMessage());
+ if (thinkingTraceEmitter != null) {
+ thinkingTraceEmitter.markTerminated("ABORTED");
+ }
emitSafely(emitter, "status", buildTerminalStatus(requestId, sessionId, "ABORTED", model, startedAt, firstTokenAt));
aiCallLogService.failCallLog(
callLogId,
@@ -432,6 +444,9 @@
}
log.error("AI chat failed, requestId={}, sessionId={}, category={}, stage={}, message={}",
requestId, sessionId, exception.getCategory(), exception.getStage(), exception.getMessage(), exception);
+ if (thinkingTraceEmitter != null) {
+ thinkingTraceEmitter.markTerminated("FAILED");
+ }
emitSafely(emitter, "status", buildTerminalStatus(requestId, sessionId, "FAILED", model, startedAt, firstTokenAt));
emitSafely(emitter, "error", AiChatErrorDto.builder()
.requestId(requestId)
@@ -479,17 +494,7 @@
}
private OpenAiApi buildOpenAiApi(AiParam aiParam) {
- int timeoutMs = aiParam.getTimeoutMs() == null ? AiDefaults.DEFAULT_TIMEOUT_MS : aiParam.getTimeoutMs();
- SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
- requestFactory.setConnectTimeout(timeoutMs);
- requestFactory.setReadTimeout(timeoutMs);
-
- return OpenAiApi.builder()
- .baseUrl(aiParam.getBaseUrl())
- .apiKey(aiParam.getApiKey())
- .restClientBuilder(RestClient.builder().requestFactory(requestFactory))
- .webClientBuilder(WebClient.builder())
- .build();
+ return AiOpenAiApiSupport.buildOpenAiApi(aiParam); // NOTE(review): the removed code applied aiParam.getTimeoutMs() (or AiDefaults.DEFAULT_TIMEOUT_MS) as connect/read timeout - confirm AiOpenAiApiSupport still does.
}
private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Long tenantId,
@@ -534,7 +539,8 @@
private ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, SseEmitter emitter, String requestId,
Long sessionId, AtomicLong toolCallSequence,
AtomicLong toolSuccessCount, AtomicLong toolFailureCount,
- Long callLogId, Long userId, Long tenantId) {
+ Long callLogId, Long userId, Long tenantId,
+ ThinkingTraceEmitter thinkingTraceEmitter) {
/** 缁欐墍鏈夊伐鍏峰洖璋冨涓婁竴灞傚彲瑙傛祴鍖呰锛岀敤浜庡疄鏃� SSE 杞ㄨ抗鍜屽璁℃棩蹇楄惤搴撱�� */
if (Cools.isEmpty(toolCallbacks)) {
return toolCallbacks;
@@ -545,7 +551,7 @@
continue;
}
wrappedCallbacks.add(new ObservableToolCallback(callback, emitter, requestId, sessionId, toolCallSequence,
- toolSuccessCount, toolFailureCount, callLogId, userId, tenantId));
+ toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, thinkingTraceEmitter));
}
return wrappedCallbacks.toArray(new ToolCallback[0]);
}
@@ -738,6 +744,125 @@
return false;
}
+ private class ThinkingTraceEmitter {
+ // Tracks the current phase (ANALYZE / TOOL_CALL / ANSWER) and its status for one request, and publishes every transition as a "thinking" SSE event.
+ private final SseEmitter emitter;
+ private final String requestId;
+ private final Long sessionId;
+ private String currentPhase;
+ private String currentStatus;
+ // NOTE(review): currentPhase/currentStatus are plain mutable fields with no synchronization; confirm the chat flow and ObservableToolCallback never call in concurrently.
+ private ThinkingTraceEmitter(SseEmitter emitter, String requestId, Long sessionId) {
+ this.emitter = emitter;
+ this.requestId = requestId;
+ this.sessionId = sessionId;
+ }
+ /** Opens the ANALYZE phase with a STARTED event; does nothing when a phase has already been set. */
+ private void startAnalyze() {
+ if (currentPhase != null) {
+ return;
+ }
+ currentPhase = "ANALYZE";
+ currentStatus = "STARTED";
+ emitThinkingEvent("ANALYZE", "STARTED", "姝e湪鍒嗘瀽闂",
+ "宸叉帴鏀朵綘鐨勯棶棰橈紝姝e湪鐞嗚В鎰忓浘骞跺垽鏂槸鍚﹂渶瑕佽皟鐢ㄥ伐鍏枫��", null);
+ }
+ /** Switches to TOOL_CALL via switchPhase (STARTED event), then emits an UPDATED event whose content includes the tool name and which carries toolCallId. */
+ private void onToolStart(String toolName, String toolCallId) {
+ switchPhase("TOOL_CALL", "STARTED", "姝e湪璋冪敤宸ュ叿", "宸插垽鏂渶瑕佽皟鐢ㄥ伐鍏凤紝姝e湪鏌ヨ鐩稿叧淇℃伅銆�", null);
+ currentStatus = "UPDATED";
+ emitThinkingEvent("TOOL_CALL", "UPDATED", "姝e湪璋冪敤宸ュ叿",
+ "姝e湪璋冪敤宸ュ叿 " + safeLabel(toolName, "鏈煡宸ュ叿") + " 鑾峰彇鎵�闇�淇℃伅銆�", toolCallId);
+ }
+ /** Emits a TOOL_CALL event with status UPDATED, or FAILED when failed is true. FAILED is terminal per isTerminalStatus, so completeCurrentPhase/markTerminated are no-ops until the status is reassigned. */
+ private void onToolResult(String toolName, String toolCallId, boolean failed) {
+ currentPhase = "TOOL_CALL";
+ currentStatus = failed ? "FAILED" : "UPDATED";
+ emitThinkingEvent("TOOL_CALL", failed ? "FAILED" : "UPDATED",
+ failed ? "宸ュ叿璋冪敤澶辫触" : "宸ュ叿璋冪敤瀹屾垚",
+ failed
+ ? "宸ュ叿 " + safeLabel(toolName, "鏈煡宸ュ叿") + " 璋冪敤澶辫触锛屾鍦ㄨ瘎浼板け璐ュ奖鍝嶅苟鏁寸悊鍙敤淇℃伅銆�"
+ : "宸ュ叿 " + safeLabel(toolName, "鏈煡宸ュ叿") + " 宸茶繑鍥炵粨鏋滐紝姝e湪缁х画鍒嗘瀽骞舵彁鐐煎叧閿俊鎭��",
+ toolCallId);
+ }
+ /** Switches to the ANSWER phase with a STARTED event; a different, still-open previous phase is completed first by switchPhase. */
+ private void startAnswer() {
+ switchPhase("ANSWER", "STARTED", "姝e湪鏁寸悊绛旀", "宸插畬鎴愬垎鏋愶紝姝e湪缁勭粐鏈�缁堝洖澶嶅唴瀹广��", null);
+ }
+ /** Emits COMPLETED for the current phase; no-op when no phase is set or the current status is already terminal. */
+ private void completeCurrentPhase() {
+ if (!StringUtils.hasText(currentPhase) || isTerminalStatus(currentStatus)) {
+ return;
+ }
+ currentStatus = "COMPLETED";
+ emitThinkingEvent(currentPhase, "COMPLETED", resolveCompleteTitle(currentPhase),
+ resolveCompleteContent(currentPhase), null);
+ }
+ /** Closes the current phase with terminalStatus; "ABORTED" selects one title/content pair, any other value the other pair. No-op when no phase is set or the status is already terminal. */
+ private void markTerminated(String terminalStatus) {
+ if (!StringUtils.hasText(currentPhase) || isTerminalStatus(currentStatus)) {
+ return;
+ }
+ currentStatus = terminalStatus;
+ emitThinkingEvent(currentPhase, terminalStatus,
+ "ABORTED".equals(terminalStatus) ? "鎬濊�冨凡涓" : "鎬濊�冨け璐�",
+ "ABORTED".equals(terminalStatus)
+ ? "鏈疆瀵硅瘽宸茶涓锛屾�濊�冭繃绋嬫彁鍓嶇粨鏉熴��"
+ : "鏈疆瀵硅瘽鍦ㄧ敓鎴愮瓟妗堝墠澶辫触锛屽綋鍓嶆�濊�冭繃绋嬪凡鍋滄銆�",
+ null);
+ }
+ /** Completes the current phase when it differs from nextPhase, then records nextPhase/nextStatus and emits the matching event (also when the phase is unchanged). */
+ private void switchPhase(String nextPhase, String nextStatus, String title, String content, String toolCallId) {
+ if (!Objects.equals(currentPhase, nextPhase)) {
+ completeCurrentPhase();
+ }
+ currentPhase = nextPhase;
+ currentStatus = nextStatus;
+ emitThinkingEvent(nextPhase, nextStatus, title, content, toolCallId);
+ }
+ /** Builds an AiChatThinkingEventDto stamped with the current epoch milliseconds and hands it to emitSafely as a "thinking" event. */
+ private void emitThinkingEvent(String phase, String status, String title, String content, String toolCallId) {
+ emitSafely(emitter, "thinking", AiChatThinkingEventDto.builder()
+ .requestId(requestId)
+ .sessionId(sessionId)
+ .phase(phase)
+ .status(status)
+ .title(title)
+ .content(content)
+ .toolCallId(toolCallId)
+ .timestamp(Instant.now().toEpochMilli())
+ .build());
+ }
+ /** Returns true when status is COMPLETED, FAILED or ABORTED. */
+ private boolean isTerminalStatus(String status) {
+ return "COMPLETED".equals(status) || "FAILED".equals(status) || "ABORTED".equals(status);
+ }
+ /** Picks the completion title for a phase: one text for ANSWER, one for TOOL_CALL, and a fallback for any other phase. */
+ private String resolveCompleteTitle(String phase) {
+ if ("ANSWER".equals(phase)) {
+ return "绛旀鏁寸悊瀹屾垚";
+ }
+ if ("TOOL_CALL".equals(phase)) {
+ return "宸ュ叿鍒嗘瀽瀹屾垚";
+ }
+ return "闂鍒嗘瀽瀹屾垚";
+ }
+ /** Picks the completion body text for a phase, using the same ANSWER / TOOL_CALL / fallback selection as the title. */
+ private String resolveCompleteContent(String phase) {
+ if ("ANSWER".equals(phase)) {
+ return "鏈�缁堢瓟澶嶅凡鐢熸垚瀹屾垚銆�";
+ }
+ if ("TOOL_CALL".equals(phase)) {
+ return "宸ュ叿璋冪敤闃舵宸茬粨鏉燂紝鐩稿叧淇℃伅宸叉暣鐞嗗畬姣曘��";
+ }
+ return "闂鎰忓浘鍜屽鐞嗘柟鍚戝凡鍒嗘瀽瀹屾垚銆�";
+ }
+ /** Returns value when it contains text (StringUtils.hasText), otherwise fallback. */
+ private String safeLabel(String value, String fallback) {
+ return StringUtils.hasText(value) ? value : fallback;
+ }
+ }
+
private class ObservableToolCallback implements ToolCallback {
private final ToolCallback delegate;
@@ -750,11 +875,13 @@
private final Long callLogId;
private final Long userId;
private final Long tenantId;
+ private final ThinkingTraceEmitter thinkingTraceEmitter;
private ObservableToolCallback(ToolCallback delegate, SseEmitter emitter, String requestId,
Long sessionId, AtomicLong toolCallSequence,
AtomicLong toolSuccessCount, AtomicLong toolFailureCount,
- Long callLogId, Long userId, Long tenantId) {
+ Long callLogId, Long userId, Long tenantId,
+ ThinkingTraceEmitter thinkingTraceEmitter) {
this.delegate = delegate;
this.emitter = emitter;
this.requestId = requestId;
@@ -765,6 +892,7 @@
this.callLogId = callLogId;
this.userId = userId;
this.tenantId = tenantId;
+ this.thinkingTraceEmitter = thinkingTraceEmitter;
}
@Override
@@ -793,6 +921,9 @@
String mountName = delegate instanceof MountedToolCallback ? ((MountedToolCallback) delegate).getMountName() : null;
String toolCallId = requestId + "-tool-" + toolCallSequence.incrementAndGet();
long startedAt = System.currentTimeMillis();
+ if (thinkingTraceEmitter != null) {
+ thinkingTraceEmitter.onToolStart(toolName, toolCallId);
+ }
emitSafely(emitter, "tool_start", AiChatToolEventDto.builder()
.requestId(requestId)
.sessionId(sessionId)
@@ -818,6 +949,9 @@
.durationMs(durationMs)
.timestamp(System.currentTimeMillis())
.build());
+ if (thinkingTraceEmitter != null) {
+ thinkingTraceEmitter.onToolResult(toolName, toolCallId, false);
+ }
toolSuccessCount.incrementAndGet();
aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
"COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(output, 600),
@@ -837,6 +971,9 @@
.durationMs(durationMs)
.timestamp(System.currentTimeMillis())
.build());
+ if (thinkingTraceEmitter != null) {
+ thinkingTraceEmitter.onToolResult(toolName, toolCallId, true);
+ }
toolFailureCount.incrementAndGet();
aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
"FAILED", summarizeToolPayload(toolInput, 400), null, e.getMessage(),
--
Gitblit v1.9.1