zhou zhou
20 小时以前 4954d3978cf1967729a5a2d5b90f6baef18974da
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
@@ -8,12 +8,14 @@
import com.vincent.rsf.server.ai.dto.AiChatErrorDto;
import com.vincent.rsf.server.ai.dto.AiChatMemoryDto;
import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
import com.vincent.rsf.server.ai.dto.AiChatModelOptionDto;
import com.vincent.rsf.server.ai.dto.AiChatRequest;
import com.vincent.rsf.server.ai.dto.AiChatRuntimeDto;
import com.vincent.rsf.server.ai.dto.AiChatStatusDto;
import com.vincent.rsf.server.ai.dto.AiChatSessionDto;
import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest;
import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest;
import com.vincent.rsf.server.ai.dto.AiChatThinkingEventDto;
import com.vincent.rsf.server.ai.dto.AiChatToolEventDto;
import com.vincent.rsf.server.ai.dto.AiResolvedConfig;
import com.vincent.rsf.server.ai.entity.AiCallLog;
@@ -26,6 +28,7 @@
import com.vincent.rsf.server.ai.service.AiChatService;
import com.vincent.rsf.server.ai.service.AiChatMemoryService;
import com.vincent.rsf.server.ai.service.AiConfigResolverService;
import com.vincent.rsf.server.ai.service.AiParamService;
import com.vincent.rsf.server.ai.service.MountedToolCallback;
import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
import io.micrometer.observation.ObservationRegistry;
@@ -51,12 +54,9 @@
import org.springframework.ai.util.json.schema.SchemaType;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.http.MediaType;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import reactor.core.publisher.Flux;
@@ -80,8 +80,10 @@
    private final AiConfigResolverService aiConfigResolverService;
    private final AiChatMemoryService aiChatMemoryService;
    private final AiParamService aiParamService;
    private final McpMountRuntimeFactory mcpMountRuntimeFactory;
    private final AiCallLogService aiCallLogService;
    private final AiRedisSupport aiRedisSupport;
    private final GenericApplicationContext applicationContext;
    private final ObservationRegistry observationRegistry;
    private final ObjectMapper objectMapper;
@@ -93,24 +95,31 @@
     * 该方法不会触发模型调用,而是把配置解析结果和会话记忆聚合成前端一次渲染所需的快照。
     */
    @Override
    public AiChatRuntimeDto getRuntime(String promptCode, Long sessionId, Long userId, Long tenantId) {
        AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId);
    public AiChatRuntimeDto getRuntime(String promptCode, Long sessionId, Long aiParamId, Long userId, Long tenantId) {
        AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId, aiParamId);
        Long runtimeCacheAiParamId = aiParamId;
        // runtime 是配置快照和会话记忆的聚合视图,单独缓存能减少一次页面进入时的重复拼装。
        AiChatRuntimeDto cached = aiRedisSupport.getRuntime(tenantId, userId, config.getPromptCode(), sessionId, runtimeCacheAiParamId);
        if (cached != null) {
            return cached;
        }
        AiChatMemoryDto memory = aiChatMemoryService.getMemory(userId, tenantId, config.getPromptCode(), sessionId);
        return AiChatRuntimeDto.builder()
                .requestId(null)
                .sessionId(memory.getSessionId())
                .promptCode(config.getPromptCode())
                .promptName(config.getPrompt().getName())
                .model(config.getAiParam().getModel())
                .configuredMcpCount(config.getMcpMounts().size())
                .mountedMcpCount(config.getMcpMounts().size())
                .mountedMcpNames(config.getMcpMounts().stream().map(item -> item.getName()).toList())
                .mountErrors(List.of())
                .memorySummary(memory.getMemorySummary())
                .memoryFacts(memory.getMemoryFacts())
                .recentMessageCount(memory.getRecentMessageCount())
                .persistedMessages(memory.getPersistedMessages())
                .build();
        List<AiChatModelOptionDto> modelOptions = aiParamService.listChatModelOptions(tenantId);
        AiChatRuntimeDto runtime = buildRuntimeSnapshot(
                null,
                memory.getSessionId(),
                config,
                modelOptions,
                config.getMcpMounts().size(),
                config.getMcpMounts().stream().map(item -> item.getName()).toList(),
                List.of(),
                memory
        );
        aiRedisSupport.cacheRuntime(tenantId, userId, config.getPromptCode(), sessionId, runtimeCacheAiParamId, runtime);
        if (memory.getSessionId() != null && !Objects.equals(memory.getSessionId(), sessionId)) {
            aiRedisSupport.cacheRuntime(tenantId, userId, config.getPromptCode(), memory.getSessionId(), runtimeCacheAiParamId, runtime);
        }
        return runtime;
    }
    /**
@@ -176,13 +185,23 @@
        Long sessionId = request.getSessionId();
        Long callLogId = null;
        String model = null;
        String resolvedPromptCode = request.getPromptCode();
        ThinkingTraceEmitter thinkingTraceEmitter = null;
        try {
            ensureIdentity(userId, tenantId);
            AiResolvedConfig config = resolveConfig(request, tenantId);
            List<AiChatModelOptionDto> modelOptions = aiParamService.listChatModelOptions(tenantId);
            resolvedPromptCode = config.getPromptCode();
            if (!aiRedisSupport.allowChatRequest(tenantId, userId, config.getPromptCode())) {
                throw buildAiException("AI_RATE_LIMITED", AiErrorCategory.REQUEST, "RATE_LIMIT",
                        "当前提问过于频繁,请稍后再试", null);
            }
            final String resolvedModel = config.getAiParam().getModel();
            model = resolvedModel;
            AiChatSession session = resolveSession(request, userId, tenantId, config.getPromptCode());
            sessionId = session.getId();
            // 流状态落 Redis 的目标是给多实例和后续运维查询留统一入口,不替代数据库日志。
            aiRedisSupport.markStreamState(requestId, tenantId, userId, sessionId, config.getPromptCode(), "RUNNING", null);
            AiChatMemoryDto memory = loadMemory(userId, tenantId, config.getPromptCode(), session.getId());
            List<AiChatMessageDto> mergedMessages = mergeMessages(memory.getShortMemoryMessages(), request.getMessages());
            AiCallLog callLog = aiCallLogService.startCallLog(
@@ -199,21 +218,16 @@
            );
            callLogId = callLog.getId();
            try (McpMountRuntimeFactory.McpMountRuntime runtime = createRuntime(config, userId)) {
                emitStrict(emitter, "start", AiChatRuntimeDto.builder()
                        .requestId(requestId)
                        .sessionId(session.getId())
                        .promptCode(config.getPromptCode())
                        .promptName(config.getPrompt().getName())
                        .model(config.getAiParam().getModel())
                        .configuredMcpCount(config.getMcpMounts().size())
                        .mountedMcpCount(runtime.getMountedCount())
                        .mountedMcpNames(runtime.getMountedNames())
                        .mountErrors(runtime.getErrors())
                        .memorySummary(memory.getMemorySummary())
                        .memoryFacts(memory.getMemoryFacts())
                        .recentMessageCount(memory.getRecentMessageCount())
                        .persistedMessages(memory.getPersistedMessages())
                        .build());
                emitStrict(emitter, "start", buildRuntimeSnapshot(
                        requestId,
                        session.getId(),
                        config,
                        modelOptions,
                        runtime.getMountedCount(),
                        runtime.getMountedNames(),
                        runtime.getErrors(),
                        memory
                ));
                emitSafely(emitter, "status", AiChatStatusDto.builder()
                        .requestId(requestId)
                        .sessionId(session.getId())
@@ -224,10 +238,13 @@
                        .build());
                log.info("AI chat started, requestId={}, userId={}, tenantId={}, sessionId={}, model={}",
                        requestId, userId, tenantId, session.getId(), resolvedModel);
                thinkingTraceEmitter = new ThinkingTraceEmitter(emitter, requestId, session.getId());
                thinkingTraceEmitter.startAnalyze();
                ThinkingTraceEmitter activeThinkingTraceEmitter = thinkingTraceEmitter;
                ToolCallback[] observableToolCallbacks = wrapToolCallbacks(
                        runtime.getToolCallbacks(), emitter, requestId, session.getId(), toolCallSequence,
                        toolSuccessCount, toolFailureCount, callLogId, userId, tenantId
                        toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, activeThinkingTraceEmitter
                );
                Prompt prompt = new Prompt(
                        buildPromptMessages(memory, mergedMessages, config.getPrompt(), request.getMetadata()),
@@ -240,9 +257,10 @@
                    String content = extractContent(response);
                    aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), content);
                    if (StringUtils.hasText(content)) {
                        markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt);
                        markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt, activeThinkingTraceEmitter);
                        emitStrict(emitter, "delta", buildMessagePayload("requestId", requestId, "content", content));
                    }
                    activeThinkingTraceEmitter.completeCurrentPhase();
                    emitDone(emitter, requestId, response.getMetadata(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get());
                    emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get()));
                    aiCallLogService.completeCallLog(
@@ -256,6 +274,7 @@
                            toolSuccessCount.get(),
                            toolFailureCount.get()
                    );
                    aiRedisSupport.markStreamState(requestId, tenantId, userId, session.getId(), config.getPromptCode(), "COMPLETED", null);
                    log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}",
                            requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()));
                    emitter.complete();
@@ -270,7 +289,7 @@
                            lastMetadata.set(response.getMetadata());
                            String content = extractContent(response);
                            if (StringUtils.hasText(content)) {
                                markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt);
                                markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt, activeThinkingTraceEmitter);
                                assistantContent.append(content);
                                emitStrict(emitter, "delta", buildMessagePayload("requestId", requestId, "content", content));
                            }
@@ -281,6 +300,7 @@
                            e == null ? "AI 模型流式调用失败" : e.getMessage(), e);
                }
                aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), assistantContent.toString());
                activeThinkingTraceEmitter.completeCurrentPhase();
                emitDone(emitter, requestId, lastMetadata.get(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get());
                emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get()));
                aiCallLogService.completeCallLog(
@@ -294,18 +314,21 @@
                        toolSuccessCount.get(),
                        toolFailureCount.get()
                );
                aiRedisSupport.markStreamState(requestId, tenantId, userId, session.getId(), config.getPromptCode(), "COMPLETED", null);
                log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}",
                        requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()));
                emitter.complete();
            }
        } catch (AiChatException e) {
            handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e,
                    callLogId, toolSuccessCount.get(), toolFailureCount.get());
                    callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter,
                    tenantId, userId, resolvedPromptCode);
        } catch (Exception e) {
            handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(),
                    buildAiException("AI_INTERNAL_ERROR", AiErrorCategory.INTERNAL, "INTERNAL",
                            e == null ? "AI 对话失败" : e.getMessage(), e),
                    callLogId, toolSuccessCount.get(), toolFailureCount.get());
                    callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter,
                    tenantId, userId, resolvedPromptCode);
        } finally {
            log.debug("AI chat stream finished, requestId={}", requestId);
        }
@@ -323,11 +346,34 @@
    private AiResolvedConfig resolveConfig(AiChatRequest request, Long tenantId) {
        /** 把请求里的 Prompt 场景解析成一份可直接执行的 AI 配置。 */
        try {
            return aiConfigResolverService.resolve(request.getPromptCode(), tenantId);
            return aiConfigResolverService.resolve(request.getPromptCode(), tenantId, request.getAiParamId());
        } catch (Exception e) {
            throw buildAiException("AI_CONFIG_RESOLVE_ERROR", AiErrorCategory.CONFIG, "CONFIG_RESOLVE",
                    e == null ? "AI 配置解析失败" : e.getMessage(), e);
        }
    }
    private AiChatRuntimeDto buildRuntimeSnapshot(String requestId, Long sessionId, AiResolvedConfig config,
                                                  List<AiChatModelOptionDto> modelOptions, Integer mountedMcpCount,
                                                  List<String> mountedMcpNames, List<String> mountErrors,
                                                  AiChatMemoryDto memory) {
        return AiChatRuntimeDto.builder()
                .requestId(requestId)
                .sessionId(sessionId)
                .aiParamId(config.getAiParam().getId())
                .promptCode(config.getPromptCode())
                .promptName(config.getPrompt().getName())
                .model(config.getAiParam().getModel())
                .modelOptions(modelOptions)
                .configuredMcpCount(config.getMcpMounts().size())
                .mountedMcpCount(mountedMcpCount)
                .mountedMcpNames(mountedMcpNames)
                .mountErrors(mountErrors)
                .memorySummary(memory.getMemorySummary())
                .memoryFacts(memory.getMemoryFacts())
                .recentMessageCount(memory.getRecentMessageCount())
                .persistedMessages(memory.getPersistedMessages())
                .build();
    }
    private AiChatSession resolveSession(AiChatRequest request, Long userId, Long tenantId, String promptCode) {
@@ -378,9 +424,13 @@
        }
    }
    private void markFirstToken(AtomicReference<Long> firstTokenAtRef, SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt) {
    private void markFirstToken(AtomicReference<Long> firstTokenAtRef, SseEmitter emitter, String requestId,
                                Long sessionId, String model, long startedAt, ThinkingTraceEmitter thinkingTraceEmitter) {
        if (!firstTokenAtRef.compareAndSet(null, System.currentTimeMillis())) {
            return;
        }
        if (thinkingTraceEmitter != null) {
            thinkingTraceEmitter.startAnswer();
        }
        emitSafely(emitter, "status", AiChatStatusDto.builder()
                .requestId(requestId)
@@ -411,10 +461,15 @@
    private void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt,
                                     Long firstTokenAt, AiChatException exception, Long callLogId,
                                     long toolSuccessCount, long toolFailureCount) {
                                     long toolSuccessCount, long toolFailureCount,
                                     ThinkingTraceEmitter thinkingTraceEmitter,
                                     Long tenantId, Long userId, String promptCode) {
        if (isClientAbortException(exception)) {
            log.warn("AI chat aborted by client, requestId={}, sessionId={}, stage={}, message={}",
                    requestId, sessionId, exception.getStage(), exception.getMessage());
            if (thinkingTraceEmitter != null) {
                thinkingTraceEmitter.markTerminated("ABORTED");
            }
            emitSafely(emitter, "status", buildTerminalStatus(requestId, sessionId, "ABORTED", model, startedAt, firstTokenAt));
            aiCallLogService.failCallLog(
                    callLogId,
@@ -427,11 +482,15 @@
                    toolSuccessCount,
                    toolFailureCount
            );
            aiRedisSupport.markStreamState(requestId, tenantId, userId, sessionId, promptCode, "ABORTED", exception.getMessage());
            emitter.completeWithError(exception);
            return;
        }
        log.error("AI chat failed, requestId={}, sessionId={}, category={}, stage={}, message={}",
                requestId, sessionId, exception.getCategory(), exception.getStage(), exception.getMessage(), exception);
        if (thinkingTraceEmitter != null) {
            thinkingTraceEmitter.markTerminated("FAILED");
        }
        emitSafely(emitter, "status", buildTerminalStatus(requestId, sessionId, "FAILED", model, startedAt, firstTokenAt));
        emitSafely(emitter, "error", AiChatErrorDto.builder()
                .requestId(requestId)
@@ -453,6 +512,7 @@
                toolSuccessCount,
                toolFailureCount
        );
        aiRedisSupport.markStreamState(requestId, tenantId, userId, sessionId, promptCode, "FAILED", exception.getMessage());
        emitter.completeWithError(exception);
    }
@@ -479,17 +539,7 @@
    }
    private OpenAiApi buildOpenAiApi(AiParam aiParam) {
        int timeoutMs = aiParam.getTimeoutMs() == null ? AiDefaults.DEFAULT_TIMEOUT_MS : aiParam.getTimeoutMs();
        SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
        requestFactory.setConnectTimeout(timeoutMs);
        requestFactory.setReadTimeout(timeoutMs);
        return OpenAiApi.builder()
                .baseUrl(aiParam.getBaseUrl())
                .apiKey(aiParam.getApiKey())
                .restClientBuilder(RestClient.builder().requestFactory(requestFactory))
                .webClientBuilder(WebClient.builder())
                .build();
        return AiOpenAiApiSupport.buildOpenAiApi(aiParam);
    }
    private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Long tenantId,
@@ -534,7 +584,8 @@
    private ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, SseEmitter emitter, String requestId,
                                             Long sessionId, AtomicLong toolCallSequence,
                                             AtomicLong toolSuccessCount, AtomicLong toolFailureCount,
                                             Long callLogId, Long userId, Long tenantId) {
                                             Long callLogId, Long userId, Long tenantId,
                                             ThinkingTraceEmitter thinkingTraceEmitter) {
        /** 给所有工具回调套上一层可观测包装,用于实时 SSE 轨迹和审计日志落库。 */
        if (Cools.isEmpty(toolCallbacks)) {
            return toolCallbacks;
@@ -545,7 +596,7 @@
                continue;
            }
            wrappedCallbacks.add(new ObservableToolCallback(callback, emitter, requestId, sessionId, toolCallSequence,
                    toolSuccessCount, toolFailureCount, callLogId, userId, tenantId));
                    toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, thinkingTraceEmitter));
        }
        return wrappedCallbacks.toArray(new ToolCallback[0]);
    }
@@ -738,6 +789,125 @@
        return false;
    }
    /**
     * Emits SSE "thinking" trace events so the frontend can render the request's
     * coarse reasoning phases (ANALYZE -> TOOL_CALL -> ANSWER) in real time.
     * Intentionally a non-static inner class: it reuses the enclosing service's
     * emitSafely(...) to write to the SSE stream.
     * NOTE(review): has no internal synchronization — presumably all phase
     * transitions for one request happen on a single stream thread; confirm
     * against the callers before sharing an instance across threads.
     */
    private class ThinkingTraceEmitter {
        private final SseEmitter emitter;
        private final String requestId;
        private final Long sessionId;
        // Current phase ("ANALYZE" / "TOOL_CALL" / "ANSWER"); null until startAnalyze().
        private String currentPhase;
        // Current status within the phase ("STARTED" / "UPDATED" / terminal states).
        private String currentStatus;
        private ThinkingTraceEmitter(SseEmitter emitter, String requestId, Long sessionId) {
            this.emitter = emitter;
            this.requestId = requestId;
            this.sessionId = sessionId;
        }
        /** Opens the ANALYZE phase; no-op if any phase has already started (idempotent). */
        private void startAnalyze() {
            if (currentPhase != null) {
                return;
            }
            currentPhase = "ANALYZE";
            currentStatus = "STARTED";
            emitThinkingEvent("ANALYZE", "STARTED", "正在分析问题",
                    "已接收你的问题,正在理解意图并判断是否需要调用工具。", null);
        }
        /**
         * Records a tool invocation: switches into TOOL_CALL (closing the previous
         * phase if different), then emits an UPDATED event naming the tool.
         */
        private void onToolStart(String toolName, String toolCallId) {
            switchPhase("TOOL_CALL", "STARTED", "正在调用工具", "已判断需要调用工具,正在查询相关信息。", null);
            currentStatus = "UPDATED";
            emitThinkingEvent("TOOL_CALL", "UPDATED", "正在调用工具",
                    "正在调用工具 " + safeLabel(toolName, "未知工具") + " 获取所需信息。", toolCallId);
        }
        /**
         * Records a tool result. Forces the phase to TOOL_CALL (a result may arrive
         * without a preceding onToolStart, e.g. for cached tool results) and emits
         * UPDATED on success or FAILED on failure.
         */
        private void onToolResult(String toolName, String toolCallId, boolean failed) {
            currentPhase = "TOOL_CALL";
            currentStatus = failed ? "FAILED" : "UPDATED";
            emitThinkingEvent("TOOL_CALL", failed ? "FAILED" : "UPDATED",
                    failed ? "工具调用失败" : "工具调用完成",
                    failed
                            ? "工具 " + safeLabel(toolName, "未知工具") + " 调用失败,正在评估失败影响并整理可用信息。"
                            : "工具 " + safeLabel(toolName, "未知工具") + " 已返回结果,正在继续分析并提炼关键信息。",
                    toolCallId);
        }
        /** Enters the ANSWER phase (first token observed), closing the previous phase. */
        private void startAnswer() {
            switchPhase("ANSWER", "STARTED", "正在整理答案", "已完成分析,正在组织最终回复内容。", null);
        }
        /** Marks the current phase COMPLETED; no-op if no phase is open or already terminal. */
        private void completeCurrentPhase() {
            if (!StringUtils.hasText(currentPhase) || isTerminalStatus(currentStatus)) {
                return;
            }
            currentStatus = "COMPLETED";
            emitThinkingEvent(currentPhase, "COMPLETED", resolveCompleteTitle(currentPhase),
                    resolveCompleteContent(currentPhase), null);
        }
        /**
         * Terminates the trace early on ABORTED/FAILED streams; no-op if no phase
         * is open or the trace already reached a terminal status.
         */
        private void markTerminated(String terminalStatus) {
            if (!StringUtils.hasText(currentPhase) || isTerminalStatus(currentStatus)) {
                return;
            }
            currentStatus = terminalStatus;
            emitThinkingEvent(currentPhase, terminalStatus,
                    "ABORTED".equals(terminalStatus) ? "思考已中止" : "思考失败",
                    "ABORTED".equals(terminalStatus)
                            ? "本轮对话已被中止,思考过程提前结束。"
                            : "本轮对话在生成答案前失败,当前思考过程已停止。",
                    null);
        }
        /**
         * Transitions to nextPhase: completes the current phase first when it differs,
         * then records and emits the new phase/status event.
         */
        private void switchPhase(String nextPhase, String nextStatus, String title, String content, String toolCallId) {
            if (!Objects.equals(currentPhase, nextPhase)) {
                completeCurrentPhase();
            }
            currentPhase = nextPhase;
            currentStatus = nextStatus;
            emitThinkingEvent(nextPhase, nextStatus, title, content, toolCallId);
        }
        /** Builds the thinking-event DTO and pushes it on the SSE "thinking" channel (best-effort). */
        private void emitThinkingEvent(String phase, String status, String title, String content, String toolCallId) {
            emitSafely(emitter, "thinking", AiChatThinkingEventDto.builder()
                    .requestId(requestId)
                    .sessionId(sessionId)
                    .phase(phase)
                    .status(status)
                    .title(title)
                    .content(content)
                    .toolCallId(toolCallId)
                    .timestamp(Instant.now().toEpochMilli())
                    .build());
        }
        /** A terminal status freezes the trace: no further events for the phase. */
        private boolean isTerminalStatus(String status) {
            return "COMPLETED".equals(status) || "FAILED".equals(status) || "ABORTED".equals(status);
        }
        /** Phase-specific completion title shown to the user. */
        private String resolveCompleteTitle(String phase) {
            if ("ANSWER".equals(phase)) {
                return "答案整理完成";
            }
            if ("TOOL_CALL".equals(phase)) {
                return "工具分析完成";
            }
            return "问题分析完成";
        }
        /** Phase-specific completion body text shown to the user. */
        private String resolveCompleteContent(String phase) {
            if ("ANSWER".equals(phase)) {
                return "最终答复已生成完成。";
            }
            if ("TOOL_CALL".equals(phase)) {
                return "工具调用阶段已结束,相关信息已整理完毕。";
            }
            return "问题意图和处理方向已分析完成。";
        }
        /** Returns value when non-blank, otherwise the fallback label. */
        private String safeLabel(String value, String fallback) {
            return StringUtils.hasText(value) ? value : fallback;
        }
    }
    private class ObservableToolCallback implements ToolCallback {
        private final ToolCallback delegate;
@@ -750,11 +920,13 @@
        private final Long callLogId;
        private final Long userId;
        private final Long tenantId;
        private final ThinkingTraceEmitter thinkingTraceEmitter;
        private ObservableToolCallback(ToolCallback delegate, SseEmitter emitter, String requestId,
                                       Long sessionId, AtomicLong toolCallSequence,
                                       AtomicLong toolSuccessCount, AtomicLong toolFailureCount,
                                       Long callLogId, Long userId, Long tenantId) {
                                       Long callLogId, Long userId, Long tenantId,
                                       ThinkingTraceEmitter thinkingTraceEmitter) {
            this.delegate = delegate;
            this.emitter = emitter;
            this.requestId = requestId;
@@ -765,6 +937,7 @@
            this.callLogId = callLogId;
            this.userId = userId;
            this.tenantId = tenantId;
            this.thinkingTraceEmitter = thinkingTraceEmitter;
        }
        @Override
@@ -793,6 +966,41 @@
            String mountName = delegate instanceof MountedToolCallback ? ((MountedToolCallback) delegate).getMountName() : null;
            String toolCallId = requestId + "-tool-" + toolCallSequence.incrementAndGet();
            long startedAt = System.currentTimeMillis();
            // 这里只对同一 request 内的重复工具调用做短期复用,避免把跨请求结果误当成通用缓存。
            AiRedisSupport.CachedToolResult cachedToolResult = aiRedisSupport.getToolResult(tenantId, requestId, toolName, toolInput);
            if (cachedToolResult != null) {
                emitSafely(emitter, "tool_result", AiChatToolEventDto.builder()
                        .requestId(requestId)
                        .sessionId(sessionId)
                        .toolCallId(toolCallId)
                        .toolName(toolName)
                        .mountName(mountName)
                        .status(cachedToolResult.isSuccess() ? "COMPLETED" : "FAILED")
                        .inputSummary(summarizeToolPayload(toolInput, 400))
                        .outputSummary(summarizeToolPayload(cachedToolResult.getOutput(), 600))
                        .errorMessage(cachedToolResult.getErrorMessage())
                        .durationMs(0L)
                        .timestamp(System.currentTimeMillis())
                        .build());
                if (thinkingTraceEmitter != null) {
                    thinkingTraceEmitter.onToolResult(toolName, toolCallId, !cachedToolResult.isSuccess());
                }
                if (cachedToolResult.isSuccess()) {
                    toolSuccessCount.incrementAndGet();
                    aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
                            "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(cachedToolResult.getOutput(), 600),
                            null, 0L, userId, tenantId);
                    return cachedToolResult.getOutput();
                }
                toolFailureCount.incrementAndGet();
                aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
                        "FAILED", summarizeToolPayload(toolInput, 400), null, cachedToolResult.getErrorMessage(),
                        0L, userId, tenantId);
                throw new CoolException(cachedToolResult.getErrorMessage());
            }
            if (thinkingTraceEmitter != null) {
                thinkingTraceEmitter.onToolStart(toolName, toolCallId);
            }
            emitSafely(emitter, "tool_start", AiChatToolEventDto.builder()
                    .requestId(requestId)
                    .sessionId(sessionId)
@@ -818,6 +1026,10 @@
                        .durationMs(durationMs)
                        .timestamp(System.currentTimeMillis())
                        .build());
                if (thinkingTraceEmitter != null) {
                    thinkingTraceEmitter.onToolResult(toolName, toolCallId, false);
                }
                aiRedisSupport.cacheToolResult(tenantId, requestId, toolName, toolInput, true, output, null);
                toolSuccessCount.incrementAndGet();
                aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
                        "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(output, 600),
@@ -837,6 +1049,10 @@
                        .durationMs(durationMs)
                        .timestamp(System.currentTimeMillis())
                        .build());
                if (thinkingTraceEmitter != null) {
                    thinkingTraceEmitter.onToolResult(toolName, toolCallId, true);
                }
                aiRedisSupport.cacheToolResult(tenantId, requestId, toolName, toolInput, false, null, e.getMessage());
                toolFailureCount.incrementAndGet();
                aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
                        "FAILED", summarizeToolPayload(toolInput, 400), null, e.getMessage(),