zhou zhou
14 小时以前 d5884d0974d17d96225a5d80e432de33a5ee6552
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
@@ -16,14 +16,17 @@
import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest;
import com.vincent.rsf.server.ai.dto.AiChatToolEventDto;
import com.vincent.rsf.server.ai.dto.AiResolvedConfig;
import com.vincent.rsf.server.ai.entity.AiCallLog;
import com.vincent.rsf.server.ai.entity.AiParam;
import com.vincent.rsf.server.ai.entity.AiPrompt;
import com.vincent.rsf.server.ai.entity.AiChatSession;
import com.vincent.rsf.server.ai.enums.AiErrorCategory;
import com.vincent.rsf.server.ai.exception.AiChatException;
import com.vincent.rsf.server.ai.service.AiCallLogService;
import com.vincent.rsf.server.ai.service.AiChatService;
import com.vincent.rsf.server.ai.service.AiChatMemoryService;
import com.vincent.rsf.server.ai.service.AiConfigResolverService;
import com.vincent.rsf.server.ai.service.MountedToolCallback;
import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
import io.micrometer.observation.ObservationRegistry;
import lombok.RequiredArgsConstructor;
@@ -78,6 +81,7 @@
    // NOTE(review): if this class uses Lombok @RequiredArgsConstructor (imported in this file), the
    // declaration order of these final fields defines the generated constructor's parameter order.
    private final AiConfigResolverService aiConfigResolverService;
    // Chat memory store; saveRound is called with the session, identity, request messages and assistant output.
    private final AiChatMemoryService aiChatMemoryService;
    private final McpMountRuntimeFactory mcpMountRuntimeFactory;
    // Records the chat call lifecycle (startCallLog / completeCallLog / failCallLog) and each
    // MCP tool invocation (saveMcpCallLog).
    private final AiCallLogService aiCallLogService;
    private final GenericApplicationContext applicationContext;
    private final ObservationRegistry observationRegistry;
    private final ObjectMapper objectMapper;
@@ -148,7 +152,10 @@
        long startedAt = System.currentTimeMillis();
        AtomicReference<Long> firstTokenAtRef = new AtomicReference<>();
        AtomicLong toolCallSequence = new AtomicLong(0);
        AtomicLong toolSuccessCount = new AtomicLong(0);
        AtomicLong toolFailureCount = new AtomicLong(0);
        Long sessionId = request.getSessionId();
        Long callLogId = null;
        String model = null;
        try {
            ensureIdentity(userId, tenantId);
@@ -159,6 +166,19 @@
            sessionId = session.getId();
            AiChatMemoryDto memory = loadMemory(userId, tenantId, config.getPromptCode(), session.getId());
            List<AiChatMessageDto> mergedMessages = mergeMessages(memory.getShortMemoryMessages(), request.getMessages());
            AiCallLog callLog = aiCallLogService.startCallLog(
                    requestId,
                    session.getId(),
                    userId,
                    tenantId,
                    config.getPromptCode(),
                    config.getPrompt().getName(),
                    config.getAiParam().getModel(),
                    config.getMcpMounts().size(),
                    config.getMcpMounts().size(),
                    config.getMcpMounts().stream().map(item -> item.getName()).toList()
            );
            callLogId = callLog.getId();
            try (McpMountRuntimeFactory.McpMountRuntime runtime = createRuntime(config, userId)) {
                emitStrict(emitter, "start", AiChatRuntimeDto.builder()
                        .requestId(requestId)
@@ -187,7 +207,8 @@
                        requestId, userId, tenantId, session.getId(), resolvedModel);
                ToolCallback[] observableToolCallbacks = wrapToolCallbacks(
                        runtime.getToolCallbacks(), emitter, requestId, session.getId(), toolCallSequence
                        runtime.getToolCallbacks(), emitter, requestId, session.getId(), toolCallSequence,
                        toolSuccessCount, toolFailureCount, callLogId, userId, tenantId
                );
                Prompt prompt = new Prompt(
                        buildPromptMessages(memory, mergedMessages, config.getPrompt(), request.getMetadata()),
@@ -205,6 +226,17 @@
                    }
                    emitDone(emitter, requestId, response.getMetadata(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get());
                    emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get()));
                    aiCallLogService.completeCallLog(
                            callLogId,
                            "COMPLETED",
                            System.currentTimeMillis() - startedAt,
                            resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()),
                            response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getPromptTokens(),
                            response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getCompletionTokens(),
                            response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getTotalTokens(),
                            toolSuccessCount.get(),
                            toolFailureCount.get()
                    );
                    log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}",
                            requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()));
                    emitter.complete();
@@ -232,16 +264,29 @@
                aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), assistantContent.toString());
                emitDone(emitter, requestId, lastMetadata.get(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get());
                emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get()));
                aiCallLogService.completeCallLog(
                        callLogId,
                        "COMPLETED",
                        System.currentTimeMillis() - startedAt,
                        resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()),
                        lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getPromptTokens(),
                        lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getCompletionTokens(),
                        lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getTotalTokens(),
                        toolSuccessCount.get(),
                        toolFailureCount.get()
                );
                log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}",
                        requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()));
                emitter.complete();
            }
        } catch (AiChatException e) {
            handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e);
            handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e,
                    callLogId, toolSuccessCount.get(), toolFailureCount.get());
        } catch (Exception e) {
            handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(),
                    buildAiException("AI_INTERNAL_ERROR", AiErrorCategory.INTERNAL, "INTERNAL",
                            e == null ? "AI 对话失败" : e.getMessage(), e));
                            e == null ? "AI 对话失败" : e.getMessage(), e),
                    callLogId, toolSuccessCount.get(), toolFailureCount.get());
        } finally {
            log.debug("AI chat stream finished, requestId={}", requestId);
        }
@@ -341,11 +386,24 @@
        return firstTokenAt == null ? null : Math.max(0L, firstTokenAt - startedAt);
    }
    /**
     * Terminates the SSE stream after an aborted or failed chat and records the outcome in the call log.
     * <p>
     * Client aborts (per isClientAbortException) emit an "ABORTED" terminal status and fail the call log
     * with status "ABORTED". All other failures emit an event built from the exception message (the start
     * of that emit call is outside this diff view) and fail the call log with status "FAILED". In both
     * cases the tool success/failure counts and latencies are passed to failCallLog, and the emitter is
     * completed with the exception.
     * <p>
     * NOTE(review): this is a diff view. The first signature line is the removed pre-change signature;
     * the three lines after it form the current signature.
     * NOTE(review): callLogId is still null when the failure happens before startCallLog returns (the
     * caller initialises it to null and assigns it only after startCallLog). Confirm that
     * AiCallLogService.failCallLog tolerates a null id, or guard these calls.
     * NOTE(review): if failCallLog can throw, emitter.completeWithError is never reached and the SSE
     * stream is left open. Consider making the log write best-effort.
     */
    private void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt, Long firstTokenAt, AiChatException exception) {
    private void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt,
                                     Long firstTokenAt, AiChatException exception, Long callLogId,
                                     long toolSuccessCount, long toolFailureCount) {
        if (isClientAbortException(exception)) {
            log.warn("AI chat aborted by client, requestId={}, sessionId={}, stage={}, message={}",
                    requestId, sessionId, exception.getStage(), exception.getMessage());
            emitSafely(emitter, "status", buildTerminalStatus(requestId, sessionId, "ABORTED", model, startedAt, firstTokenAt));
            aiCallLogService.failCallLog(
                    callLogId,
                    "ABORTED",
                    exception.getCategory().name(),
                    exception.getStage(),
                    exception.getMessage(),
                    System.currentTimeMillis() - startedAt,
                    resolveFirstTokenLatency(startedAt, firstTokenAt),
                    toolSuccessCount,
                    toolFailureCount
            );
            emitter.completeWithError(exception);
            return;
        }
@@ -361,6 +419,17 @@
                .message(exception.getMessage())
                .timestamp(Instant.now().toEpochMilli())
                .build());
        aiCallLogService.failCallLog(
                callLogId,
                "FAILED",
                exception.getCategory().name(),
                exception.getStage(),
                exception.getMessage(),
                System.currentTimeMillis() - startedAt,
                resolveFirstTokenLatency(startedAt, firstTokenAt),
                toolSuccessCount,
                toolFailureCount
        );
        emitter.completeWithError(exception);
    }
@@ -436,7 +505,9 @@
    }
    /**
     * Wraps each tool callback in an ObservableToolCallback. Each tool invocation then streams
     * tool_start / tool_result / tool_error events to the emitter, updates the shared success/failure
     * counters and writes an MCP call log entry linked to callLogId.
     * <p>
     * The input array is returned unchanged when Cools.isEmpty reports it empty (presumably null or
     * zero-length; confirm). Null entries are skipped.
     * <p>
     * NOTE(review): this is a diff view. The first parameter continuation line and the single-line
     * wrappedCallbacks.add(...) call are the removed pre-change lines; the lines following each of
     * them are the current versions.
     */
    private ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, SseEmitter emitter, String requestId,
                                             Long sessionId, AtomicLong toolCallSequence) {
                                             Long sessionId, AtomicLong toolCallSequence,
                                             AtomicLong toolSuccessCount, AtomicLong toolFailureCount,
                                             Long callLogId, Long userId, Long tenantId) {
        if (Cools.isEmpty(toolCallbacks)) {
            return toolCallbacks;
        }
@@ -445,7 +516,8 @@
            if (callback == null) {
                continue;
            }
            wrappedCallbacks.add(new ObservableToolCallback(callback, emitter, requestId, sessionId, toolCallSequence));
            wrappedCallbacks.add(new ObservableToolCallback(callback, emitter, requestId, sessionId, toolCallSequence,
                    toolSuccessCount, toolFailureCount, callLogId, userId, tenantId));
        }
        return wrappedCallbacks.toArray(new ToolCallback[0]);
    }
@@ -637,14 +709,26 @@
        private final String requestId;
        private final Long sessionId;
        // Shared per-request sequence used to build tool call ids of the form requestId + "-tool-" + n.
        private final AtomicLong toolCallSequence;
        // Counters shared with the enclosing chat call. They are incremented once per tool outcome and
        // later passed to completeCallLog / failCallLog for the whole call.
        private final AtomicLong toolSuccessCount;
        private final AtomicLong toolFailureCount;
        // Id of the parent chat call log; attached to every MCP tool call log entry.
        private final Long callLogId;
        // Caller identity recorded on every MCP tool call log entry.
        private final Long userId;
        private final Long tenantId;
        /**
         * Creates a wrapper around {@code delegate} that reports each invocation to {@code emitter} and
         * the call log.
         * <p>
         * The counters and sequence are shared by all wrappers of the same request, so the instances
         * passed in are stored as-is rather than copied.
         * <p>
         * NOTE(review): this is a diff view. The parameter continuation line ending in
         * {@code toolCallSequence) {} is the removed pre-change line; the three lines after it form
         * the current parameter list.
         */
        private ObservableToolCallback(ToolCallback delegate, SseEmitter emitter, String requestId,
                                       Long sessionId, AtomicLong toolCallSequence) {
                                       Long sessionId, AtomicLong toolCallSequence,
                                       AtomicLong toolSuccessCount, AtomicLong toolFailureCount,
                                       Long callLogId, Long userId, Long tenantId) {
            this.delegate = delegate;
            this.emitter = emitter;
            this.requestId = requestId;
            this.sessionId = sessionId;
            this.toolCallSequence = toolCallSequence;
            this.toolSuccessCount = toolSuccessCount;
            this.toolFailureCount = toolFailureCount;
            this.callLogId = callLogId;
            this.userId = userId;
            this.tenantId = tenantId;
        }
        @Override
@@ -665,6 +749,7 @@
        /**
         * Invokes the wrapped tool while streaming its lifecycle to the client and recording it.
         * <p>
         * Emits "tool_start" before the call. If toolContext is null it calls delegate.call(toolInput);
         * otherwise it calls delegate.call(toolInput, toolContext). On success it emits "tool_result"
         * (status COMPLETED), increments the shared success counter and saves a COMPLETED MCP call log
         * entry. On a RuntimeException it emits "tool_error" (status FAILED), increments the shared
         * failure counter, saves a FAILED entry and rethrows the exception.
         * <p>
         * NOTE(review): on the success path, toolSuccessCount.incrementAndGet() and saveMcpCallLog run
         * inside the try block. If persisting the log throws a RuntimeException, the catch below runs
         * for a tool that actually succeeded:
         * <ul>
         *   <li>a tool_error is emitted after the tool_result;</li>
         *   <li>both counters are incremented and a FAILED log entry is attempted;</li>
         *   <li>the exception is rethrown into the model call.</li>
         * </ul>
         * Consider moving the bookkeeping out of the try, or making it best-effort.
         * NOTE(review): this is a diff view. In each pair of adjacent .durationMs(...) lines, the
         * System.currentTimeMillis() - startedAt line is the removed version; .durationMs(durationMs)
         * is current.
         */
        @Override
        public String call(String toolInput, ToolContext toolContext) {
            String toolName = delegate.getToolDefinition() == null ? "unknown" : delegate.getToolDefinition().name();
            String mountName = delegate instanceof MountedToolCallback ? ((MountedToolCallback) delegate).getMountName() : null;
            String toolCallId = requestId + "-tool-" + toolCallSequence.incrementAndGet();
            long startedAt = System.currentTimeMillis();
            emitSafely(emitter, "tool_start", AiChatToolEventDto.builder()
@@ -672,36 +757,49 @@
                    .sessionId(sessionId)
                    .toolCallId(toolCallId)
                    .toolName(toolName)
                    .mountName(mountName)
                    .status("STARTED")
                    .inputSummary(summarizeToolPayload(toolInput, 400))
                    .timestamp(startedAt)
                    .build());
            try {
                String output = toolContext == null ? delegate.call(toolInput) : delegate.call(toolInput, toolContext);
                // Measured once so the streamed event and the persisted log report the same duration.
                long durationMs = System.currentTimeMillis() - startedAt;
                emitSafely(emitter, "tool_result", AiChatToolEventDto.builder()
                        .requestId(requestId)
                        .sessionId(sessionId)
                        .toolCallId(toolCallId)
                        .toolName(toolName)
                        .mountName(mountName)
                        .status("COMPLETED")
                        .inputSummary(summarizeToolPayload(toolInput, 400))
                        .outputSummary(summarizeToolPayload(output, 600))
                        .durationMs(System.currentTimeMillis() - startedAt)
                        .durationMs(durationMs)
                        .timestamp(System.currentTimeMillis())
                        .build());
                // NOTE(review): a failure in the bookkeeping below is caught by the catch clause and
                // misreported as a tool failure (see method doc).
                toolSuccessCount.incrementAndGet();
                aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
                        "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(output, 600),
                        null, durationMs, userId, tenantId);
                return output;
            } catch (RuntimeException e) {
                long durationMs = System.currentTimeMillis() - startedAt;
                emitSafely(emitter, "tool_error", AiChatToolEventDto.builder()
                        .requestId(requestId)
                        .sessionId(sessionId)
                        .toolCallId(toolCallId)
                        .toolName(toolName)
                        .mountName(mountName)
                        .status("FAILED")
                        .inputSummary(summarizeToolPayload(toolInput, 400))
                        .errorMessage(e.getMessage())
                        .durationMs(System.currentTimeMillis() - startedAt)
                        .durationMs(durationMs)
                        .timestamp(System.currentTimeMillis())
                        .build());
                toolFailureCount.incrementAndGet();
                // NOTE(review): if saveMcpCallLog throws here, its exception replaces the original tool
                // exception `e`.
                aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
                        "FAILED", summarizeToolPayload(toolInput, 400), null, e.getMessage(),
                        durationMs, userId, tenantId);
                throw e;
            }
        }