zhou zhou
17 小时以前 4954d3978cf1967729a5a2d5b90f6baef18974da
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
@@ -8,6 +8,7 @@
import com.vincent.rsf.server.ai.dto.AiChatErrorDto;
import com.vincent.rsf.server.ai.dto.AiChatMemoryDto;
import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
import com.vincent.rsf.server.ai.dto.AiChatModelOptionDto;
import com.vincent.rsf.server.ai.dto.AiChatRequest;
import com.vincent.rsf.server.ai.dto.AiChatRuntimeDto;
import com.vincent.rsf.server.ai.dto.AiChatStatusDto;
@@ -27,6 +28,7 @@
import com.vincent.rsf.server.ai.service.AiChatService;
import com.vincent.rsf.server.ai.service.AiChatMemoryService;
import com.vincent.rsf.server.ai.service.AiConfigResolverService;
import com.vincent.rsf.server.ai.service.AiParamService;
import com.vincent.rsf.server.ai.service.MountedToolCallback;
import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
import io.micrometer.observation.ObservationRegistry;
@@ -78,8 +80,10 @@
    private final AiConfigResolverService aiConfigResolverService;
    private final AiChatMemoryService aiChatMemoryService;
    private final AiParamService aiParamService;
    private final McpMountRuntimeFactory mcpMountRuntimeFactory;
    private final AiCallLogService aiCallLogService;
    private final AiRedisSupport aiRedisSupport;
    private final GenericApplicationContext applicationContext;
    private final ObservationRegistry observationRegistry;
    private final ObjectMapper objectMapper;
@@ -91,24 +95,31 @@
     * 该方法不会触发模型调用,而是把配置解析结果和会话记忆聚合成前端一次渲染所需的快照。
     */
    @Override
    public AiChatRuntimeDto getRuntime(String promptCode, Long sessionId, Long userId, Long tenantId) {
        AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId);
    public AiChatRuntimeDto getRuntime(String promptCode, Long sessionId, Long aiParamId, Long userId, Long tenantId) {
        AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId, aiParamId);
        Long runtimeCacheAiParamId = aiParamId;
        // runtime 是配置快照和会话记忆的聚合视图,单独缓存能减少一次页面进入时的重复拼装。
        AiChatRuntimeDto cached = aiRedisSupport.getRuntime(tenantId, userId, config.getPromptCode(), sessionId, runtimeCacheAiParamId);
        if (cached != null) {
            return cached;
        }
        AiChatMemoryDto memory = aiChatMemoryService.getMemory(userId, tenantId, config.getPromptCode(), sessionId);
        return AiChatRuntimeDto.builder()
                .requestId(null)
                .sessionId(memory.getSessionId())
                .promptCode(config.getPromptCode())
                .promptName(config.getPrompt().getName())
                .model(config.getAiParam().getModel())
                .configuredMcpCount(config.getMcpMounts().size())
                .mountedMcpCount(config.getMcpMounts().size())
                .mountedMcpNames(config.getMcpMounts().stream().map(item -> item.getName()).toList())
                .mountErrors(List.of())
                .memorySummary(memory.getMemorySummary())
                .memoryFacts(memory.getMemoryFacts())
                .recentMessageCount(memory.getRecentMessageCount())
                .persistedMessages(memory.getPersistedMessages())
                .build();
        List<AiChatModelOptionDto> modelOptions = aiParamService.listChatModelOptions(tenantId);
        AiChatRuntimeDto runtime = buildRuntimeSnapshot(
                null,
                memory.getSessionId(),
                config,
                modelOptions,
                config.getMcpMounts().size(),
                config.getMcpMounts().stream().map(item -> item.getName()).toList(),
                List.of(),
                memory
        );
        aiRedisSupport.cacheRuntime(tenantId, userId, config.getPromptCode(), sessionId, runtimeCacheAiParamId, runtime);
        if (memory.getSessionId() != null && !Objects.equals(memory.getSessionId(), sessionId)) {
            aiRedisSupport.cacheRuntime(tenantId, userId, config.getPromptCode(), memory.getSessionId(), runtimeCacheAiParamId, runtime);
        }
        return runtime;
    }
    /**
@@ -174,14 +185,23 @@
        Long sessionId = request.getSessionId();
        Long callLogId = null;
        String model = null;
        String resolvedPromptCode = request.getPromptCode();
        ThinkingTraceEmitter thinkingTraceEmitter = null;
        try {
            ensureIdentity(userId, tenantId);
            AiResolvedConfig config = resolveConfig(request, tenantId);
            List<AiChatModelOptionDto> modelOptions = aiParamService.listChatModelOptions(tenantId);
            resolvedPromptCode = config.getPromptCode();
            if (!aiRedisSupport.allowChatRequest(tenantId, userId, config.getPromptCode())) {
                throw buildAiException("AI_RATE_LIMITED", AiErrorCategory.REQUEST, "RATE_LIMIT",
                        "当前提问过于频繁,请稍后再试", null);
            }
            final String resolvedModel = config.getAiParam().getModel();
            model = resolvedModel;
            AiChatSession session = resolveSession(request, userId, tenantId, config.getPromptCode());
            sessionId = session.getId();
            // 流状态落 Redis 的目标是给多实例和后续运维查询留统一入口,不替代数据库日志。
            aiRedisSupport.markStreamState(requestId, tenantId, userId, sessionId, config.getPromptCode(), "RUNNING", null);
            AiChatMemoryDto memory = loadMemory(userId, tenantId, config.getPromptCode(), session.getId());
            List<AiChatMessageDto> mergedMessages = mergeMessages(memory.getShortMemoryMessages(), request.getMessages());
            AiCallLog callLog = aiCallLogService.startCallLog(
@@ -198,21 +218,16 @@
            );
            callLogId = callLog.getId();
            try (McpMountRuntimeFactory.McpMountRuntime runtime = createRuntime(config, userId)) {
                emitStrict(emitter, "start", AiChatRuntimeDto.builder()
                        .requestId(requestId)
                        .sessionId(session.getId())
                        .promptCode(config.getPromptCode())
                        .promptName(config.getPrompt().getName())
                        .model(config.getAiParam().getModel())
                        .configuredMcpCount(config.getMcpMounts().size())
                        .mountedMcpCount(runtime.getMountedCount())
                        .mountedMcpNames(runtime.getMountedNames())
                        .mountErrors(runtime.getErrors())
                        .memorySummary(memory.getMemorySummary())
                        .memoryFacts(memory.getMemoryFacts())
                        .recentMessageCount(memory.getRecentMessageCount())
                        .persistedMessages(memory.getPersistedMessages())
                        .build());
                emitStrict(emitter, "start", buildRuntimeSnapshot(
                        requestId,
                        session.getId(),
                        config,
                        modelOptions,
                        runtime.getMountedCount(),
                        runtime.getMountedNames(),
                        runtime.getErrors(),
                        memory
                ));
                emitSafely(emitter, "status", AiChatStatusDto.builder()
                        .requestId(requestId)
                        .sessionId(session.getId())
@@ -259,6 +274,7 @@
                            toolSuccessCount.get(),
                            toolFailureCount.get()
                    );
                    aiRedisSupport.markStreamState(requestId, tenantId, userId, session.getId(), config.getPromptCode(), "COMPLETED", null);
                    log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}",
                            requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()));
                    emitter.complete();
@@ -298,18 +314,21 @@
                        toolSuccessCount.get(),
                        toolFailureCount.get()
                );
                aiRedisSupport.markStreamState(requestId, tenantId, userId, session.getId(), config.getPromptCode(), "COMPLETED", null);
                log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}",
                        requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()));
                emitter.complete();
            }
        } catch (AiChatException e) {
            handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e,
                    callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter);
                    callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter,
                    tenantId, userId, resolvedPromptCode);
        } catch (Exception e) {
            handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(),
                    buildAiException("AI_INTERNAL_ERROR", AiErrorCategory.INTERNAL, "INTERNAL",
                            e == null ? "AI 对话失败" : e.getMessage(), e),
                    callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter);
                    callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter,
                    tenantId, userId, resolvedPromptCode);
        } finally {
            log.debug("AI chat stream finished, requestId={}", requestId);
        }
@@ -327,11 +346,34 @@
    private AiResolvedConfig resolveConfig(AiChatRequest request, Long tenantId) {
        /** 把请求里的 Prompt 场景解析成一份可直接执行的 AI 配置。 */
        try {
            return aiConfigResolverService.resolve(request.getPromptCode(), tenantId);
            return aiConfigResolverService.resolve(request.getPromptCode(), tenantId, request.getAiParamId());
        } catch (Exception e) {
            throw buildAiException("AI_CONFIG_RESOLVE_ERROR", AiErrorCategory.CONFIG, "CONFIG_RESOLVE",
                    e == null ? "AI 配置解析失败" : e.getMessage(), e);
        }
    }
    /**
     * Assembles a {@link AiChatRuntimeDto} from the resolved config, the session memory and the
     * actual MCP mount outcome.
     *
     * <p>Shared by both the page-entry runtime query (no request yet, configured mounts reported
     * as mounted) and the SSE "start" event (real mount counts/errors from the live runtime).
     *
     * @param requestId       id of the in-flight chat request; null for the page-entry snapshot
     * @param sessionId       effective session id the snapshot describes
     * @param config          resolved AI configuration (prompt, AI param, configured MCP mounts)
     * @param modelOptions    chat model options selectable by the tenant
     * @param mountedMcpCount number of MCP mounts actually established
     * @param mountedMcpNames names of the MCP mounts actually established
     * @param mountErrors     mount failure messages; empty when nothing was attempted or all succeeded
     * @param memory          session memory view (summary, facts, recent/persisted messages)
     * @return fully populated runtime snapshot DTO
     */
    private AiChatRuntimeDto buildRuntimeSnapshot(String requestId, Long sessionId, AiResolvedConfig config,
                                                  List<AiChatModelOptionDto> modelOptions, Integer mountedMcpCount,
                                                  List<String> mountedMcpNames, List<String> mountErrors,
                                                  AiChatMemoryDto memory) {
        return AiChatRuntimeDto.builder()
                .requestId(requestId)
                .sessionId(sessionId)
                .aiParamId(config.getAiParam().getId())
                .promptCode(config.getPromptCode())
                .promptName(config.getPrompt().getName())
                .model(config.getAiParam().getModel())
                .modelOptions(modelOptions)
                // configured vs mounted can diverge when some MCP mounts fail at runtime.
                .configuredMcpCount(config.getMcpMounts().size())
                .mountedMcpCount(mountedMcpCount)
                .mountedMcpNames(mountedMcpNames)
                .mountErrors(mountErrors)
                .memorySummary(memory.getMemorySummary())
                .memoryFacts(memory.getMemoryFacts())
                .recentMessageCount(memory.getRecentMessageCount())
                .persistedMessages(memory.getPersistedMessages())
                .build();
    }
    private AiChatSession resolveSession(AiChatRequest request, Long userId, Long tenantId, String promptCode) {
@@ -420,7 +462,8 @@
    private void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt,
                                     Long firstTokenAt, AiChatException exception, Long callLogId,
                                     long toolSuccessCount, long toolFailureCount,
                                     ThinkingTraceEmitter thinkingTraceEmitter) {
                                     ThinkingTraceEmitter thinkingTraceEmitter,
                                     Long tenantId, Long userId, String promptCode) {
        if (isClientAbortException(exception)) {
            log.warn("AI chat aborted by client, requestId={}, sessionId={}, stage={}, message={}",
                    requestId, sessionId, exception.getStage(), exception.getMessage());
@@ -439,6 +482,7 @@
                    toolSuccessCount,
                    toolFailureCount
            );
            aiRedisSupport.markStreamState(requestId, tenantId, userId, sessionId, promptCode, "ABORTED", exception.getMessage());
            emitter.completeWithError(exception);
            return;
        }
@@ -468,6 +512,7 @@
                toolSuccessCount,
                toolFailureCount
        );
        aiRedisSupport.markStreamState(requestId, tenantId, userId, sessionId, promptCode, "FAILED", exception.getMessage());
        emitter.completeWithError(exception);
    }
@@ -921,6 +966,38 @@
            String mountName = delegate instanceof MountedToolCallback ? ((MountedToolCallback) delegate).getMountName() : null;
            String toolCallId = requestId + "-tool-" + toolCallSequence.incrementAndGet();
            long startedAt = System.currentTimeMillis();
            // 这里只对同一 request 内的重复工具调用做短期复用,避免把跨请求结果误当成通用缓存。
            AiRedisSupport.CachedToolResult cachedToolResult = aiRedisSupport.getToolResult(tenantId, requestId, toolName, toolInput);
            if (cachedToolResult != null) {
                emitSafely(emitter, "tool_result", AiChatToolEventDto.builder()
                        .requestId(requestId)
                        .sessionId(sessionId)
                        .toolCallId(toolCallId)
                        .toolName(toolName)
                        .mountName(mountName)
                        .status(cachedToolResult.isSuccess() ? "COMPLETED" : "FAILED")
                        .inputSummary(summarizeToolPayload(toolInput, 400))
                        .outputSummary(summarizeToolPayload(cachedToolResult.getOutput(), 600))
                        .errorMessage(cachedToolResult.getErrorMessage())
                        .durationMs(0L)
                        .timestamp(System.currentTimeMillis())
                        .build());
                if (thinkingTraceEmitter != null) {
                    thinkingTraceEmitter.onToolResult(toolName, toolCallId, !cachedToolResult.isSuccess());
                }
                if (cachedToolResult.isSuccess()) {
                    toolSuccessCount.incrementAndGet();
                    aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
                            "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(cachedToolResult.getOutput(), 600),
                            null, 0L, userId, tenantId);
                    return cachedToolResult.getOutput();
                }
                toolFailureCount.incrementAndGet();
                aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
                        "FAILED", summarizeToolPayload(toolInput, 400), null, cachedToolResult.getErrorMessage(),
                        0L, userId, tenantId);
                throw new CoolException(cachedToolResult.getErrorMessage());
            }
            if (thinkingTraceEmitter != null) {
                thinkingTraceEmitter.onToolStart(toolName, toolCallId);
            }
@@ -952,6 +1029,7 @@
                if (thinkingTraceEmitter != null) {
                    thinkingTraceEmitter.onToolResult(toolName, toolCallId, false);
                }
                aiRedisSupport.cacheToolResult(tenantId, requestId, toolName, toolInput, true, output, null);
                toolSuccessCount.incrementAndGet();
                aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
                        "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(output, 600),
@@ -974,6 +1052,7 @@
                if (thinkingTraceEmitter != null) {
                    thinkingTraceEmitter.onToolResult(toolName, toolCallId, true);
                }
                aiRedisSupport.cacheToolResult(tenantId, requestId, toolName, toolInput, false, null, e.getMessage());
                toolFailureCount.incrementAndGet();
                aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
                        "FAILED", summarizeToolPayload(toolInput, 400), null, e.getMessage(),