| | |
| | | import com.vincent.rsf.server.ai.dto.AiChatErrorDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatMemoryDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatMessageDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatModelOptionDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatRequest; |
| | | import com.vincent.rsf.server.ai.dto.AiChatRuntimeDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatStatusDto; |
| | |
| | | import com.vincent.rsf.server.ai.service.AiChatService; |
| | | import com.vincent.rsf.server.ai.service.AiChatMemoryService; |
| | | import com.vincent.rsf.server.ai.service.AiConfigResolverService; |
| | | import com.vincent.rsf.server.ai.service.AiParamService; |
| | | import com.vincent.rsf.server.ai.service.MountedToolCallback; |
| | | import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory; |
| | | import io.micrometer.observation.ObservationRegistry; |
| | |
| | | |
| | | private final AiConfigResolverService aiConfigResolverService; |
| | | private final AiChatMemoryService aiChatMemoryService; |
| | | private final AiParamService aiParamService; |
| | | private final McpMountRuntimeFactory mcpMountRuntimeFactory; |
| | | private final AiCallLogService aiCallLogService; |
| | | private final AiRedisSupport aiRedisSupport; |
| | | private final GenericApplicationContext applicationContext; |
| | | private final ObservationRegistry observationRegistry; |
| | | private final ObjectMapper objectMapper; |
| | |
| | | * 该方法不会触发模型调用,而是把配置解析结果和会话记忆聚合成前端一次渲染所需的快照。 |
| | | */ |
| | | @Override |
| | | public AiChatRuntimeDto getRuntime(String promptCode, Long sessionId, Long userId, Long tenantId) { |
| | | AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId); |
| | | public AiChatRuntimeDto getRuntime(String promptCode, Long sessionId, Long aiParamId, Long userId, Long tenantId) { |
| | | AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId, aiParamId); |
| | | Long runtimeCacheAiParamId = aiParamId; |
| | | // runtime 是配置快照和会话记忆的聚合视图,单独缓存能减少一次页面进入时的重复拼装。 |
| | | AiChatRuntimeDto cached = aiRedisSupport.getRuntime(tenantId, userId, config.getPromptCode(), sessionId, runtimeCacheAiParamId); |
| | | if (cached != null) { |
| | | return cached; |
| | | } |
| | | AiChatMemoryDto memory = aiChatMemoryService.getMemory(userId, tenantId, config.getPromptCode(), sessionId); |
| | | return AiChatRuntimeDto.builder() |
| | | .requestId(null) |
| | | .sessionId(memory.getSessionId()) |
| | | .promptCode(config.getPromptCode()) |
| | | .promptName(config.getPrompt().getName()) |
| | | .model(config.getAiParam().getModel()) |
| | | .configuredMcpCount(config.getMcpMounts().size()) |
| | | .mountedMcpCount(config.getMcpMounts().size()) |
| | | .mountedMcpNames(config.getMcpMounts().stream().map(item -> item.getName()).toList()) |
| | | .mountErrors(List.of()) |
| | | .memorySummary(memory.getMemorySummary()) |
| | | .memoryFacts(memory.getMemoryFacts()) |
| | | .recentMessageCount(memory.getRecentMessageCount()) |
| | | .persistedMessages(memory.getPersistedMessages()) |
| | | .build(); |
| | | List<AiChatModelOptionDto> modelOptions = aiParamService.listChatModelOptions(tenantId); |
| | | AiChatRuntimeDto runtime = buildRuntimeSnapshot( |
| | | null, |
| | | memory.getSessionId(), |
| | | config, |
| | | modelOptions, |
| | | config.getMcpMounts().size(), |
| | | config.getMcpMounts().stream().map(item -> item.getName()).toList(), |
| | | List.of(), |
| | | memory |
| | | ); |
| | | aiRedisSupport.cacheRuntime(tenantId, userId, config.getPromptCode(), sessionId, runtimeCacheAiParamId, runtime); |
| | | if (memory.getSessionId() != null && !Objects.equals(memory.getSessionId(), sessionId)) { |
| | | aiRedisSupport.cacheRuntime(tenantId, userId, config.getPromptCode(), memory.getSessionId(), runtimeCacheAiParamId, runtime); |
| | | } |
| | | return runtime; |
| | | } |
| | | |
| | | /** |
| | |
| | | Long sessionId = request.getSessionId(); |
| | | Long callLogId = null; |
| | | String model = null; |
| | | String resolvedPromptCode = request.getPromptCode(); |
| | | ThinkingTraceEmitter thinkingTraceEmitter = null; |
| | | try { |
| | | ensureIdentity(userId, tenantId); |
| | | AiResolvedConfig config = resolveConfig(request, tenantId); |
| | | List<AiChatModelOptionDto> modelOptions = aiParamService.listChatModelOptions(tenantId); |
| | | resolvedPromptCode = config.getPromptCode(); |
| | | if (!aiRedisSupport.allowChatRequest(tenantId, userId, config.getPromptCode())) { |
| | | throw buildAiException("AI_RATE_LIMITED", AiErrorCategory.REQUEST, "RATE_LIMIT", |
| | | "当前提问过于频繁,请稍后再试", null); |
| | | } |
| | | final String resolvedModel = config.getAiParam().getModel(); |
| | | model = resolvedModel; |
| | | AiChatSession session = resolveSession(request, userId, tenantId, config.getPromptCode()); |
| | | sessionId = session.getId(); |
| | | // 流状态落 Redis 的目标是给多实例和后续运维查询留统一入口,不替代数据库日志。 |
| | | aiRedisSupport.markStreamState(requestId, tenantId, userId, sessionId, config.getPromptCode(), "RUNNING", null); |
| | | AiChatMemoryDto memory = loadMemory(userId, tenantId, config.getPromptCode(), session.getId()); |
| | | List<AiChatMessageDto> mergedMessages = mergeMessages(memory.getShortMemoryMessages(), request.getMessages()); |
| | | AiCallLog callLog = aiCallLogService.startCallLog( |
| | |
| | | ); |
| | | callLogId = callLog.getId(); |
| | | try (McpMountRuntimeFactory.McpMountRuntime runtime = createRuntime(config, userId)) { |
| | | emitStrict(emitter, "start", AiChatRuntimeDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(session.getId()) |
| | | .promptCode(config.getPromptCode()) |
| | | .promptName(config.getPrompt().getName()) |
| | | .model(config.getAiParam().getModel()) |
| | | .configuredMcpCount(config.getMcpMounts().size()) |
| | | .mountedMcpCount(runtime.getMountedCount()) |
| | | .mountedMcpNames(runtime.getMountedNames()) |
| | | .mountErrors(runtime.getErrors()) |
| | | .memorySummary(memory.getMemorySummary()) |
| | | .memoryFacts(memory.getMemoryFacts()) |
| | | .recentMessageCount(memory.getRecentMessageCount()) |
| | | .persistedMessages(memory.getPersistedMessages()) |
| | | .build()); |
| | | emitStrict(emitter, "start", buildRuntimeSnapshot( |
| | | requestId, |
| | | session.getId(), |
| | | config, |
| | | modelOptions, |
| | | runtime.getMountedCount(), |
| | | runtime.getMountedNames(), |
| | | runtime.getErrors(), |
| | | memory |
| | | )); |
| | | emitSafely(emitter, "status", AiChatStatusDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(session.getId()) |
| | |
| | | toolSuccessCount.get(), |
| | | toolFailureCount.get() |
| | | ); |
| | | aiRedisSupport.markStreamState(requestId, tenantId, userId, session.getId(), config.getPromptCode(), "COMPLETED", null); |
| | | log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}", |
| | | requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get())); |
| | | emitter.complete(); |
| | |
| | | toolSuccessCount.get(), |
| | | toolFailureCount.get() |
| | | ); |
| | | aiRedisSupport.markStreamState(requestId, tenantId, userId, session.getId(), config.getPromptCode(), "COMPLETED", null); |
| | | log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}", |
| | | requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get())); |
| | | emitter.complete(); |
| | | } |
| | | } catch (AiChatException e) { |
| | | handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e, |
| | | callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter); |
| | | callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter, |
| | | tenantId, userId, resolvedPromptCode); |
| | | } catch (Exception e) { |
| | | handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), |
| | | buildAiException("AI_INTERNAL_ERROR", AiErrorCategory.INTERNAL, "INTERNAL", |
| | | e == null ? "AI 对话失败" : e.getMessage(), e), |
| | | callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter); |
| | | callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter, |
| | | tenantId, userId, resolvedPromptCode); |
| | | } finally { |
| | | log.debug("AI chat stream finished, requestId={}", requestId); |
| | | } |
| | |
| | | private AiResolvedConfig resolveConfig(AiChatRequest request, Long tenantId) { |
| | | /** 把请求里的 Prompt 场景解析成一份可直接执行的 AI 配置。 */ |
| | | try { |
| | | return aiConfigResolverService.resolve(request.getPromptCode(), tenantId); |
| | | return aiConfigResolverService.resolve(request.getPromptCode(), tenantId, request.getAiParamId()); |
| | | } catch (Exception e) { |
| | | throw buildAiException("AI_CONFIG_RESOLVE_ERROR", AiErrorCategory.CONFIG, "CONFIG_RESOLVE", |
| | | e == null ? "AI 配置解析失败" : e.getMessage(), e); |
| | | } |
| | | } |
| | | |
| | | private AiChatRuntimeDto buildRuntimeSnapshot(String requestId, Long sessionId, AiResolvedConfig config, |
| | | List<AiChatModelOptionDto> modelOptions, Integer mountedMcpCount, |
| | | List<String> mountedMcpNames, List<String> mountErrors, |
| | | AiChatMemoryDto memory) { |
| | | return AiChatRuntimeDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .aiParamId(config.getAiParam().getId()) |
| | | .promptCode(config.getPromptCode()) |
| | | .promptName(config.getPrompt().getName()) |
| | | .model(config.getAiParam().getModel()) |
| | | .modelOptions(modelOptions) |
| | | .configuredMcpCount(config.getMcpMounts().size()) |
| | | .mountedMcpCount(mountedMcpCount) |
| | | .mountedMcpNames(mountedMcpNames) |
| | | .mountErrors(mountErrors) |
| | | .memorySummary(memory.getMemorySummary()) |
| | | .memoryFacts(memory.getMemoryFacts()) |
| | | .recentMessageCount(memory.getRecentMessageCount()) |
| | | .persistedMessages(memory.getPersistedMessages()) |
| | | .build(); |
| | | } |
| | | |
| | | private AiChatSession resolveSession(AiChatRequest request, Long userId, Long tenantId, String promptCode) { |
| | |
| | | private void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt, |
| | | Long firstTokenAt, AiChatException exception, Long callLogId, |
| | | long toolSuccessCount, long toolFailureCount, |
| | | ThinkingTraceEmitter thinkingTraceEmitter) { |
| | | ThinkingTraceEmitter thinkingTraceEmitter, |
| | | Long tenantId, Long userId, String promptCode) { |
| | | if (isClientAbortException(exception)) { |
| | | log.warn("AI chat aborted by client, requestId={}, sessionId={}, stage={}, message={}", |
| | | requestId, sessionId, exception.getStage(), exception.getMessage()); |
| | |
| | | toolSuccessCount, |
| | | toolFailureCount |
| | | ); |
| | | aiRedisSupport.markStreamState(requestId, tenantId, userId, sessionId, promptCode, "ABORTED", exception.getMessage()); |
| | | emitter.completeWithError(exception); |
| | | return; |
| | | } |
| | |
| | | toolSuccessCount, |
| | | toolFailureCount |
| | | ); |
| | | aiRedisSupport.markStreamState(requestId, tenantId, userId, sessionId, promptCode, "FAILED", exception.getMessage()); |
| | | emitter.completeWithError(exception); |
| | | } |
| | | |
| | |
| | | String mountName = delegate instanceof MountedToolCallback ? ((MountedToolCallback) delegate).getMountName() : null; |
| | | String toolCallId = requestId + "-tool-" + toolCallSequence.incrementAndGet(); |
| | | long startedAt = System.currentTimeMillis(); |
| | | // 这里只对同一 request 内的重复工具调用做短期复用,避免把跨请求结果误当成通用缓存。 |
| | | AiRedisSupport.CachedToolResult cachedToolResult = aiRedisSupport.getToolResult(tenantId, requestId, toolName, toolInput); |
| | | if (cachedToolResult != null) { |
| | | emitSafely(emitter, "tool_result", AiChatToolEventDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .toolCallId(toolCallId) |
| | | .toolName(toolName) |
| | | .mountName(mountName) |
| | | .status(cachedToolResult.isSuccess() ? "COMPLETED" : "FAILED") |
| | | .inputSummary(summarizeToolPayload(toolInput, 400)) |
| | | .outputSummary(summarizeToolPayload(cachedToolResult.getOutput(), 600)) |
| | | .errorMessage(cachedToolResult.getErrorMessage()) |
| | | .durationMs(0L) |
| | | .timestamp(System.currentTimeMillis()) |
| | | .build()); |
| | | if (thinkingTraceEmitter != null) { |
| | | thinkingTraceEmitter.onToolResult(toolName, toolCallId, !cachedToolResult.isSuccess()); |
| | | } |
| | | if (cachedToolResult.isSuccess()) { |
| | | toolSuccessCount.incrementAndGet(); |
| | | aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, |
| | | "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(cachedToolResult.getOutput(), 600), |
| | | null, 0L, userId, tenantId); |
| | | return cachedToolResult.getOutput(); |
| | | } |
| | | toolFailureCount.incrementAndGet(); |
| | | aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, |
| | | "FAILED", summarizeToolPayload(toolInput, 400), null, cachedToolResult.getErrorMessage(), |
| | | 0L, userId, tenantId); |
| | | throw new CoolException(cachedToolResult.getErrorMessage()); |
| | | } |
| | | if (thinkingTraceEmitter != null) { |
| | | thinkingTraceEmitter.onToolStart(toolName, toolCallId); |
| | | } |
| | |
| | | if (thinkingTraceEmitter != null) { |
| | | thinkingTraceEmitter.onToolResult(toolName, toolCallId, false); |
| | | } |
| | | aiRedisSupport.cacheToolResult(tenantId, requestId, toolName, toolInput, true, output, null); |
| | | toolSuccessCount.incrementAndGet(); |
| | | aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, |
| | | "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(output, 600), |
| | |
| | | if (thinkingTraceEmitter != null) { |
| | | thinkingTraceEmitter.onToolResult(toolName, toolCallId, true); |
| | | } |
| | | aiRedisSupport.cacheToolResult(tenantId, requestId, toolName, toolInput, false, null, e.getMessage()); |
| | | toolFailureCount.incrementAndGet(); |
| | | aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, |
| | | "FAILED", summarizeToolPayload(toolInput, 400), null, e.getMessage(), |