zhou zhou
16 小时以前 80a6d9236ade191a5de0975abe4de5a6e7e63915
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
@@ -16,14 +16,17 @@
import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest;
import com.vincent.rsf.server.ai.dto.AiChatToolEventDto;
import com.vincent.rsf.server.ai.dto.AiResolvedConfig;
import com.vincent.rsf.server.ai.entity.AiCallLog;
import com.vincent.rsf.server.ai.entity.AiParam;
import com.vincent.rsf.server.ai.entity.AiPrompt;
import com.vincent.rsf.server.ai.entity.AiChatSession;
import com.vincent.rsf.server.ai.enums.AiErrorCategory;
import com.vincent.rsf.server.ai.exception.AiChatException;
import com.vincent.rsf.server.ai.service.AiCallLogService;
import com.vincent.rsf.server.ai.service.AiChatService;
import com.vincent.rsf.server.ai.service.AiChatMemoryService;
import com.vincent.rsf.server.ai.service.AiConfigResolverService;
import com.vincent.rsf.server.ai.service.MountedToolCallback;
import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
import io.micrometer.observation.ObservationRegistry;
import lombok.RequiredArgsConstructor;
@@ -78,12 +81,17 @@
    /** Resolves a prompt code for a tenant into an executable AiResolvedConfig. */
    private final AiConfigResolverService aiConfigResolverService;
    /** Session resolution/creation, memory loading, round persistence and history trimming. */
    private final AiChatMemoryService aiChatMemoryService;
    /** Builds the per-conversation MCP tool runtime from the configured mounts. */
    private final McpMountRuntimeFactory mcpMountRuntimeFactory;
    /** Starts, completes and fails the AI call audit log, and records per-tool MCP call logs. */
    private final AiCallLogService aiCallLogService;
    private final GenericApplicationContext applicationContext;
    private final ObservationRegistry observationRegistry;
    /** Serializes SSE event payloads to JSON before they are sent. */
    private final ObjectMapper objectMapper;
    // Executor for the asynchronous model call / tool execution started by stream().
    // NOTE(review): if the constructor is generated by Lombok's @RequiredArgsConstructor (imported above),
    // a field-level @Qualifier only reaches the constructor parameter when lombok.config contains
    // "lombok.copyableAnnotations += org.springframework.beans.factory.annotation.Qualifier".
    // Otherwise injection falls back to type/parameter-name matching; confirm the config exists.
    @Qualifier("aiChatTaskExecutor")
    private final Executor aiChatTaskExecutor;
    /**
     * 获取当前对话抽屉初始化所需的运行时数据。
     * 该方法不会触发模型调用,而是把配置解析结果和会话记忆聚合成前端一次渲染所需的快照。
     */
    @Override
    public AiChatRuntimeDto getRuntime(String promptCode, Long sessionId, Long userId, Long tenantId) {
        AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId);
@@ -105,6 +113,9 @@
                .build();
    }
    /**
     * 查询指定 Prompt 场景下的历史会话摘要列表。
     */
    @Override
    public List<AiChatSessionDto> listSessions(String promptCode, String keyword, Long userId, Long tenantId) {
        AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId);
@@ -136,6 +147,10 @@
        aiChatMemoryService.retainLatestRound(userId, tenantId, sessionId);
    }
    /**
     * 启动一次新的 SSE 对话流。
     * 控制线程立即返回 emitter,真正的模型调用与工具执行交给 AI 专用线程池异步处理。
     */
    @Override
    public SseEmitter stream(AiChatRequest request, Long userId, Long tenantId) {
        SseEmitter emitter = new SseEmitter(AiDefaults.SSE_TIMEOUT_MS);
@@ -144,11 +159,22 @@
    }
    private void doStream(AiChatRequest request, Long userId, Long tenantId, SseEmitter emitter) {
        /**
         * AI 对话的核心执行链路:
         * 1. 校验身份和解析租户配置
         * 2. 解析或创建会话,加载记忆
         * 3. 动态挂载 MCP 工具
         * 4. 发起模型流式/非流式调用
         * 5. 持久化本轮消息,输出 SSE 事件并记录审计日志
         */
        String requestId = request.getRequestId();
        long startedAt = System.currentTimeMillis();
        AtomicReference<Long> firstTokenAtRef = new AtomicReference<>();
        AtomicLong toolCallSequence = new AtomicLong(0);
        AtomicLong toolSuccessCount = new AtomicLong(0);
        AtomicLong toolFailureCount = new AtomicLong(0);
        Long sessionId = request.getSessionId();
        Long callLogId = null;
        String model = null;
        try {
            ensureIdentity(userId, tenantId);
@@ -159,6 +185,19 @@
            sessionId = session.getId();
            AiChatMemoryDto memory = loadMemory(userId, tenantId, config.getPromptCode(), session.getId());
            List<AiChatMessageDto> mergedMessages = mergeMessages(memory.getShortMemoryMessages(), request.getMessages());
            AiCallLog callLog = aiCallLogService.startCallLog(
                    requestId,
                    session.getId(),
                    userId,
                    tenantId,
                    config.getPromptCode(),
                    config.getPrompt().getName(),
                    config.getAiParam().getModel(),
                    config.getMcpMounts().size(),
                    config.getMcpMounts().size(),
                    config.getMcpMounts().stream().map(item -> item.getName()).toList()
            );
            callLogId = callLog.getId();
            try (McpMountRuntimeFactory.McpMountRuntime runtime = createRuntime(config, userId)) {
                emitStrict(emitter, "start", AiChatRuntimeDto.builder()
                        .requestId(requestId)
@@ -187,7 +226,8 @@
                        requestId, userId, tenantId, session.getId(), resolvedModel);
                ToolCallback[] observableToolCallbacks = wrapToolCallbacks(
                        runtime.getToolCallbacks(), emitter, requestId, session.getId(), toolCallSequence
                        runtime.getToolCallbacks(), emitter, requestId, session.getId(), toolCallSequence,
                        toolSuccessCount, toolFailureCount, callLogId, userId, tenantId
                );
                Prompt prompt = new Prompt(
                        buildPromptMessages(memory, mergedMessages, config.getPrompt(), request.getMetadata()),
@@ -205,6 +245,17 @@
                    }
                    emitDone(emitter, requestId, response.getMetadata(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get());
                    emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get()));
                    aiCallLogService.completeCallLog(
                            callLogId,
                            "COMPLETED",
                            System.currentTimeMillis() - startedAt,
                            resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()),
                            response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getPromptTokens(),
                            response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getCompletionTokens(),
                            response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getTotalTokens(),
                            toolSuccessCount.get(),
                            toolFailureCount.get()
                    );
                    log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}",
                            requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()));
                    emitter.complete();
@@ -232,16 +283,29 @@
                aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), assistantContent.toString());
                emitDone(emitter, requestId, lastMetadata.get(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get());
                emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get()));
                aiCallLogService.completeCallLog(
                        callLogId,
                        "COMPLETED",
                        System.currentTimeMillis() - startedAt,
                        resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()),
                        lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getPromptTokens(),
                        lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getCompletionTokens(),
                        lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getTotalTokens(),
                        toolSuccessCount.get(),
                        toolFailureCount.get()
                );
                log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}",
                        requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()));
                emitter.complete();
            }
        } catch (AiChatException e) {
            handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e);
            handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e,
                    callLogId, toolSuccessCount.get(), toolFailureCount.get());
        } catch (Exception e) {
            handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(),
                    buildAiException("AI_INTERNAL_ERROR", AiErrorCategory.INTERNAL, "INTERNAL",
                            e == null ? "AI 对话失败" : e.getMessage(), e));
                            e == null ? "AI 对话失败" : e.getMessage(), e),
                    callLogId, toolSuccessCount.get(), toolFailureCount.get());
        } finally {
            log.debug("AI chat stream finished, requestId={}", requestId);
        }
@@ -257,6 +321,7 @@
    }
    private AiResolvedConfig resolveConfig(AiChatRequest request, Long tenantId) {
        /** 把请求里的 Prompt 场景解析成一份可直接执行的 AI 配置。 */
        try {
            return aiConfigResolverService.resolve(request.getPromptCode(), tenantId);
        } catch (Exception e) {
@@ -266,6 +331,7 @@
    }
    private AiChatSession resolveSession(AiChatRequest request, Long userId, Long tenantId, String promptCode) {
        /** 根据 sessionId 复用历史会话,或在首次提问时创建新会话。 */
        try {
            return aiChatMemoryService.resolveSession(userId, tenantId, promptCode, request.getSessionId(), resolveTitleSeed(request.getMessages()));
        } catch (Exception e) {
@@ -275,6 +341,7 @@
    }
    private AiChatMemoryDto loadMemory(Long userId, Long tenantId, String promptCode, Long sessionId) {
        /** 读取会话的短期记忆、摘要记忆和事实记忆,供模型组装上下文。 */
        try {
            return aiChatMemoryService.getMemory(userId, tenantId, promptCode, sessionId);
        } catch (Exception e) {
@@ -284,6 +351,7 @@
    }
    private McpMountRuntimeFactory.McpMountRuntime createRuntime(AiResolvedConfig config, Long userId) {
        /** 按配置中的 MCP 挂载记录构造本轮对话专属的工具运行时。 */
        try {
            return mcpMountRuntimeFactory.create(config.getMcpMounts(), userId);
        } catch (Exception e) {
@@ -341,11 +409,24 @@
        return firstTokenAt == null ? null : Math.max(0L, firstTokenAt - startedAt);
    }
    private void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt, Long firstTokenAt, AiChatException exception) {
    private void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt,
                                     Long firstTokenAt, AiChatException exception, Long callLogId,
                                     long toolSuccessCount, long toolFailureCount) {
        if (isClientAbortException(exception)) {
            log.warn("AI chat aborted by client, requestId={}, sessionId={}, stage={}, message={}",
                    requestId, sessionId, exception.getStage(), exception.getMessage());
            emitSafely(emitter, "status", buildTerminalStatus(requestId, sessionId, "ABORTED", model, startedAt, firstTokenAt));
            aiCallLogService.failCallLog(
                    callLogId,
                    "ABORTED",
                    exception.getCategory().name(),
                    exception.getStage(),
                    exception.getMessage(),
                    System.currentTimeMillis() - startedAt,
                    resolveFirstTokenLatency(startedAt, firstTokenAt),
                    toolSuccessCount,
                    toolFailureCount
            );
            emitter.completeWithError(exception);
            return;
        }
@@ -361,6 +442,17 @@
                .message(exception.getMessage())
                .timestamp(Instant.now().toEpochMilli())
                .build());
        aiCallLogService.failCallLog(
                callLogId,
                "FAILED",
                exception.getCategory().name(),
                exception.getStage(),
                exception.getMessage(),
                System.currentTimeMillis() - startedAt,
                resolveFirstTokenLatency(startedAt, firstTokenAt),
                toolSuccessCount,
                toolFailureCount
        );
        emitter.completeWithError(exception);
    }
@@ -402,6 +494,10 @@
    private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Long tenantId,
                                               String requestId, Long sessionId, Map<String, Object> metadata) {
        /**
         * 组装一次聊天调用的全部模型选项和 Tool Context。
         * Tool Context 会透传给内置工具和外部 MCP,保证工具在租户和会话范围内执行。
         */
        if (userId == null) {
            throw buildAiException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "OPTIONS_BUILD", "当前登录用户不存在", null);
        }
@@ -436,7 +532,10 @@
    }
    private ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, SseEmitter emitter, String requestId,
                                             Long sessionId, AtomicLong toolCallSequence) {
                                             Long sessionId, AtomicLong toolCallSequence,
                                             AtomicLong toolSuccessCount, AtomicLong toolFailureCount,
                                             Long callLogId, Long userId, Long tenantId) {
        /** 给所有工具回调套上一层可观测包装,用于实时 SSE 轨迹和审计日志落库。 */
        if (Cools.isEmpty(toolCallbacks)) {
            return toolCallbacks;
        }
@@ -445,12 +544,17 @@
            if (callback == null) {
                continue;
            }
            wrappedCallbacks.add(new ObservableToolCallback(callback, emitter, requestId, sessionId, toolCallSequence));
            wrappedCallbacks.add(new ObservableToolCallback(callback, emitter, requestId, sessionId, toolCallSequence,
                    toolSuccessCount, toolFailureCount, callLogId, userId, tenantId));
        }
        return wrappedCallbacks.toArray(new ToolCallback[0]);
    }
    private List<Message> buildPromptMessages(AiChatMemoryDto memory, List<AiChatMessageDto> sourceMessages, AiPrompt aiPrompt, Map<String, Object> metadata) {
        /**
         * 组装最终提交给模型的消息列表。
         * 顺序上始终是:系统 Prompt -> 历史摘要 -> 关键事实 -> 最近对话 -> 当前用户输入。
         */
        if (Cools.isEmpty(sourceMessages)) {
            throw new CoolException("对话消息不能为空");
        }
@@ -497,6 +601,7 @@
    }
    private List<AiChatMessageDto> mergeMessages(List<AiChatMessageDto> persistedMessages, List<AiChatMessageDto> memoryMessages) {
        /** 把落库历史与本轮前端内存增量合并成模型可消费的完整上下文。 */
        List<AiChatMessageDto> merged = new ArrayList<>();
        if (!Cools.isEmpty(persistedMessages)) {
            merged.addAll(persistedMessages);
@@ -562,6 +667,7 @@
    }
    private void emitDone(SseEmitter emitter, String requestId, ChatResponseMetadata metadata, String fallbackModel, Long sessionId, long startedAt, Long firstTokenAt) {
        /** 输出对话完成事件,统一封装耗时、首包延迟和 token 用量。 */
        Usage usage = metadata == null ? null : metadata.getUsage();
        emitStrict(emitter, "done", AiChatDoneDto.builder()
                .requestId(requestId)
@@ -590,6 +696,7 @@
    }
    private void emitStrict(SseEmitter emitter, String eventName, Object payload) {
        /** 严格发送 SSE 事件;一旦发送失败,直接上抛为流式输出异常。 */
        try {
            String data = objectMapper.writeValueAsString(payload);
            emitter.send(SseEmitter.event()
@@ -601,6 +708,7 @@
    }
    private void emitSafely(SseEmitter emitter, String eventName, Object payload) {
        /** 尝试发送非关键事件,发送失败只记录日志,不打断主对话流程。 */
        try {
            emitStrict(emitter, eventName, payload);
        } catch (Exception e) {
@@ -637,14 +745,26 @@
        private final String requestId;
        private final Long sessionId;
        private final AtomicLong toolCallSequence;
        private final AtomicLong toolSuccessCount;
        private final AtomicLong toolFailureCount;
        private final Long callLogId;
        private final Long userId;
        private final Long tenantId;
        private ObservableToolCallback(ToolCallback delegate, SseEmitter emitter, String requestId,
                                       Long sessionId, AtomicLong toolCallSequence) {
                                       Long sessionId, AtomicLong toolCallSequence,
                                       AtomicLong toolSuccessCount, AtomicLong toolFailureCount,
                                       Long callLogId, Long userId, Long tenantId) {
            this.delegate = delegate;
            this.emitter = emitter;
            this.requestId = requestId;
            this.sessionId = sessionId;
            this.toolCallSequence = toolCallSequence;
            this.toolSuccessCount = toolSuccessCount;
            this.toolFailureCount = toolFailureCount;
            this.callLogId = callLogId;
            this.userId = userId;
            this.tenantId = tenantId;
        }
        @Override
@@ -664,7 +784,13 @@
        @Override
        public String call(String toolInput, ToolContext toolContext) {
            /**
             * 工具执行观测包装器。
             * 在真实调用前后分别发送 tool_start / tool_result / tool_error,
             * 同时把调用摘要写入 MCP 调用日志表。
             */
            String toolName = delegate.getToolDefinition() == null ? "unknown" : delegate.getToolDefinition().name();
            String mountName = delegate instanceof MountedToolCallback ? ((MountedToolCallback) delegate).getMountName() : null;
            String toolCallId = requestId + "-tool-" + toolCallSequence.incrementAndGet();
            long startedAt = System.currentTimeMillis();
            emitSafely(emitter, "tool_start", AiChatToolEventDto.builder()
@@ -672,36 +798,49 @@
                    .sessionId(sessionId)
                    .toolCallId(toolCallId)
                    .toolName(toolName)
                    .mountName(mountName)
                    .status("STARTED")
                    .inputSummary(summarizeToolPayload(toolInput, 400))
                    .timestamp(startedAt)
                    .build());
            try {
                String output = toolContext == null ? delegate.call(toolInput) : delegate.call(toolInput, toolContext);
                long durationMs = System.currentTimeMillis() - startedAt;
                emitSafely(emitter, "tool_result", AiChatToolEventDto.builder()
                        .requestId(requestId)
                        .sessionId(sessionId)
                        .toolCallId(toolCallId)
                        .toolName(toolName)
                        .mountName(mountName)
                        .status("COMPLETED")
                        .inputSummary(summarizeToolPayload(toolInput, 400))
                        .outputSummary(summarizeToolPayload(output, 600))
                        .durationMs(System.currentTimeMillis() - startedAt)
                        .durationMs(durationMs)
                        .timestamp(System.currentTimeMillis())
                        .build());
                toolSuccessCount.incrementAndGet();
                aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
                        "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(output, 600),
                        null, durationMs, userId, tenantId);
                return output;
            } catch (RuntimeException e) {
                long durationMs = System.currentTimeMillis() - startedAt;
                emitSafely(emitter, "tool_error", AiChatToolEventDto.builder()
                        .requestId(requestId)
                        .sessionId(sessionId)
                        .toolCallId(toolCallId)
                        .toolName(toolName)
                        .mountName(mountName)
                        .status("FAILED")
                        .inputSummary(summarizeToolPayload(toolInput, 400))
                        .errorMessage(e.getMessage())
                        .durationMs(System.currentTimeMillis() - startedAt)
                        .durationMs(durationMs)
                        .timestamp(System.currentTimeMillis())
                        .build());
                toolFailureCount.incrementAndGet();
                aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
                        "FAILED", summarizeToolPayload(toolInput, 400), null, e.getMessage(),
                        durationMs, userId, tenantId);
                throw e;
            }
        }