package com.vincent.rsf.server.ai.service.impl.chat; import com.vincent.rsf.server.ai.dto.AiChatMemoryDto; import com.vincent.rsf.server.ai.dto.AiChatMessageDto; import com.vincent.rsf.server.ai.dto.AiChatModelOptionDto; import com.vincent.rsf.server.ai.dto.AiChatRequest; import com.vincent.rsf.server.ai.dto.AiChatStatusDto; import com.vincent.rsf.server.ai.dto.AiResolvedConfig; import com.vincent.rsf.server.ai.entity.AiCallLog; import com.vincent.rsf.server.ai.entity.AiChatSession; import com.vincent.rsf.server.ai.enums.AiErrorCategory; import com.vincent.rsf.server.ai.exception.AiChatException; import com.vincent.rsf.server.ai.service.AiCallLogService; import com.vincent.rsf.server.ai.service.AiChatMemoryService; import com.vincent.rsf.server.ai.service.AiConfigResolverService; import com.vincent.rsf.server.ai.service.AiParamService; import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory; import com.vincent.rsf.server.ai.store.AiChatRateLimiter; import com.vincent.rsf.server.ai.store.AiStreamStateStore; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.tool.ToolCallback; import org.springframework.stereotype.Component; import org.springframework.util.StringUtils; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import reactor.core.publisher.Flux; import java.time.Instant; import java.util.List; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; @Slf4j @Component @RequiredArgsConstructor public class AiChatOrchestrator { private final AiConfigResolverService aiConfigResolverService; private final AiChatMemoryService aiChatMemoryService; private final AiParamService aiParamService; private final McpMountRuntimeFactory mcpMountRuntimeFactory; private final AiCallLogService aiCallLogService; private final AiChatRateLimiter aiChatRateLimiter; private final AiStreamStateStore aiStreamStateStore; private final AiChatRuntimeAssembler aiChatRuntimeAssembler; private final AiPromptMessageBuilder aiPromptMessageBuilder; private final AiOpenAiChatModelFactory aiOpenAiChatModelFactory; private final AiToolObservationService aiToolObservationService; private final AiSseEventPublisher aiSseEventPublisher; private final AiChatFailureHandler aiChatFailureHandler; public void executeStream(AiChatRequest request, Long userId, Long tenantId, SseEmitter emitter) { String requestId = request.getRequestId(); long startedAt = System.currentTimeMillis(); AtomicReference firstTokenAtRef = new AtomicReference<>(); AtomicLong toolCallSequence = new AtomicLong(0); AtomicLong toolSuccessCount = new AtomicLong(0); AtomicLong toolFailureCount = new AtomicLong(0); Long sessionId = request.getSessionId(); Long callLogId = null; String model = null; String resolvedPromptCode = request.getPromptCode(); AiThinkingTraceEmitter thinkingTraceEmitter = null; try { ensureIdentity(userId, tenantId); AiResolvedConfig config = resolveConfig(request, tenantId); List modelOptions = aiParamService.listChatModelOptions(tenantId); resolvedPromptCode = config.getPromptCode(); if (!aiChatRateLimiter.allowChatRequest(tenantId, userId, config.getPromptCode())) { throw aiChatFailureHandler.buildAiException("AI_RATE_LIMITED", AiErrorCategory.REQUEST, "RATE_LIMIT", "当前提问过于频繁,请稍后再试", null); } final String resolvedModel = config.getAiParam().getModel(); model = resolvedModel; AiChatSession session = resolveSession(request, userId, tenantId, config.getPromptCode()); sessionId = session.getId(); aiStreamStateStore.markStreamState(requestId, tenantId, userId, sessionId, config.getPromptCode(), "RUNNING", null); AiChatMemoryDto memory = loadMemory(userId, tenantId, config.getPromptCode(), session.getId()); List mergedMessages = aiPromptMessageBuilder.mergeMessages(memory.getShortMemoryMessages(), request.getMessages()); AiCallLog callLog = aiCallLogService.startCallLog( requestId, session.getId(), userId, tenantId, config.getPromptCode(), config.getPrompt().getName(), config.getAiParam().getModel(), config.getMcpMounts().size(), config.getMcpMounts().size(), config.getMcpMounts().stream().map(item -> item.getName()).toList() ); callLogId = callLog.getId(); try (McpMountRuntimeFactory.McpMountRuntime runtime = createRuntime(config, userId)) { aiSseEventPublisher.emitStrict(emitter, "start", aiChatRuntimeAssembler.buildRuntimeSnapshot( requestId, session.getId(), config, modelOptions, runtime.getMountedCount(), runtime.getMountedNames(), runtime.getErrors(), memory )); aiSseEventPublisher.emitSafely(emitter, "status", AiChatStatusDto.builder() .requestId(requestId) .sessionId(session.getId()) .status("STARTED") .model(resolvedModel) .timestamp(Instant.now().toEpochMilli()) .elapsedMs(0L) .build()); log.info("AI chat started, requestId={}, userId={}, tenantId={}, sessionId={}, model={}", requestId, userId, tenantId, session.getId(), resolvedModel); thinkingTraceEmitter = new AiThinkingTraceEmitter(aiSseEventPublisher, emitter, requestId, session.getId()); thinkingTraceEmitter.startAnalyze(); AiThinkingTraceEmitter activeThinkingTraceEmitter = thinkingTraceEmitter; ToolCallback[] observableToolCallbacks = aiToolObservationService.wrapToolCallbacks( runtime.getToolCallbacks(), emitter, requestId, session.getId(), toolCallSequence, toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, activeThinkingTraceEmitter ); Prompt prompt = new Prompt( aiPromptMessageBuilder.buildPromptMessages(memory, mergedMessages, config.getPrompt(), request.getMetadata()), aiOpenAiChatModelFactory.buildChatOptions(config.getAiParam(), observableToolCallbacks, userId, tenantId, requestId, session.getId(), request.getMetadata()) ); OpenAiChatModel chatModel = aiOpenAiChatModelFactory.createChatModel(config.getAiParam()); if (Boolean.FALSE.equals(config.getAiParam().getStreamingEnabled())) { ChatResponse response = invokeChatCall(chatModel, prompt); String content = extractContent(response); aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), content); if (StringUtils.hasText(content)) { aiSseEventPublisher.markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt, activeThinkingTraceEmitter); aiSseEventPublisher.emitStrict(emitter, "delta", aiSseEventPublisher.buildMessagePayload("requestId", requestId, "content", content)); } activeThinkingTraceEmitter.completeCurrentPhase(); aiSseEventPublisher.emitDone(emitter, requestId, response.getMetadata(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get()); aiSseEventPublisher.emitSafely(emitter, "status", aiSseEventPublisher.buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get())); aiCallLogService.completeCallLog( callLogId, "COMPLETED", System.currentTimeMillis() - startedAt, aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()), response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getPromptTokens(), response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getCompletionTokens(), response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getTotalTokens(), toolSuccessCount.get(), toolFailureCount.get() ); aiStreamStateStore.markStreamState(requestId, tenantId, userId, session.getId(), config.getPromptCode(), "COMPLETED", null); log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}", requestId, session.getId(), System.currentTimeMillis() - startedAt, aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAtRef.get())); emitter.complete(); return; } Flux responseFlux = invokeChatStream(chatModel, prompt); AtomicReference lastMetadata = new AtomicReference<>(); StringBuilder assistantContent = new StringBuilder(); try { responseFlux.doOnNext(response -> { lastMetadata.set(response.getMetadata()); String content = extractContent(response); if (StringUtils.hasText(content)) { aiSseEventPublisher.markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt, activeThinkingTraceEmitter); assistantContent.append(content); aiSseEventPublisher.emitStrict(emitter, "delta", aiSseEventPublisher.buildMessagePayload("requestId", requestId, "content", content)); } }) .blockLast(); } catch (Exception e) { throw aiChatFailureHandler.buildAiException("AI_MODEL_STREAM_ERROR", AiErrorCategory.MODEL, "MODEL_STREAM", e == null ? "AI 模型流式调用失败" : e.getMessage(), e); } aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), assistantContent.toString()); activeThinkingTraceEmitter.completeCurrentPhase(); aiSseEventPublisher.emitDone(emitter, requestId, lastMetadata.get(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get()); aiSseEventPublisher.emitSafely(emitter, "status", aiSseEventPublisher.buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get())); aiCallLogService.completeCallLog( callLogId, "COMPLETED", System.currentTimeMillis() - startedAt, aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()), lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getPromptTokens(), lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getCompletionTokens(), lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getTotalTokens(), toolSuccessCount.get(), toolFailureCount.get() ); aiStreamStateStore.markStreamState(requestId, tenantId, userId, session.getId(), config.getPromptCode(), "COMPLETED", null); log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}", requestId, session.getId(), System.currentTimeMillis() - startedAt, aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAtRef.get())); emitter.complete(); } } catch (AiChatException e) { aiChatFailureHandler.handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e, callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter, tenantId, userId, resolvedPromptCode); } catch (Exception e) { aiChatFailureHandler.handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), aiChatFailureHandler.buildAiException("AI_INTERNAL_ERROR", AiErrorCategory.INTERNAL, "INTERNAL", e == null ? "AI 对话失败" : e.getMessage(), e), callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter, tenantId, userId, resolvedPromptCode); } finally { log.debug("AI chat stream finished, requestId={}", requestId); } } private void ensureIdentity(Long userId, Long tenantId) { if (userId == null) { throw aiChatFailureHandler.buildAiException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "AUTH_VALIDATE", "当前登录用户不存在", null); } if (tenantId == null) { throw aiChatFailureHandler.buildAiException("AI_AUTH_TENANT_MISSING", AiErrorCategory.AUTH, "AUTH_VALIDATE", "当前租户不存在", null); } } private AiResolvedConfig resolveConfig(AiChatRequest request, Long tenantId) { try { return aiConfigResolverService.resolve(request.getPromptCode(), tenantId, request.getAiParamId()); } catch (Exception e) { throw aiChatFailureHandler.buildAiException("AI_CONFIG_RESOLVE_ERROR", AiErrorCategory.CONFIG, "CONFIG_RESOLVE", e == null ? "AI 配置解析失败" : e.getMessage(), e); } } private AiChatSession resolveSession(AiChatRequest request, Long userId, Long tenantId, String promptCode) { try { return aiChatMemoryService.resolveSession(userId, tenantId, promptCode, request.getSessionId(), aiPromptMessageBuilder.resolveTitleSeed(request.getMessages())); } catch (Exception e) { throw aiChatFailureHandler.buildAiException("AI_SESSION_RESOLVE_ERROR", AiErrorCategory.REQUEST, "SESSION_RESOLVE", e == null ? "AI 会话解析失败" : e.getMessage(), e); } } private AiChatMemoryDto loadMemory(Long userId, Long tenantId, String promptCode, Long sessionId) { try { return aiChatMemoryService.getMemory(userId, tenantId, promptCode, sessionId); } catch (Exception e) { throw aiChatFailureHandler.buildAiException("AI_MEMORY_LOAD_ERROR", AiErrorCategory.REQUEST, "MEMORY_LOAD", e == null ? "AI 会话记忆加载失败" : e.getMessage(), e); } } private McpMountRuntimeFactory.McpMountRuntime createRuntime(AiResolvedConfig config, Long userId) { try { return mcpMountRuntimeFactory.create(config.getMcpMounts(), userId); } catch (Exception e) { throw aiChatFailureHandler.buildAiException("AI_MCP_MOUNT_ERROR", AiErrorCategory.MCP, "MCP_MOUNT", e == null ? "MCP 挂载失败" : e.getMessage(), e); } } private ChatResponse invokeChatCall(OpenAiChatModel chatModel, Prompt prompt) { try { return chatModel.call(prompt); } catch (Exception e) { throw aiChatFailureHandler.buildAiException("AI_MODEL_CALL_ERROR", AiErrorCategory.MODEL, "MODEL_CALL", e == null ? "AI 模型调用失败" : e.getMessage(), e); } } private Flux invokeChatStream(OpenAiChatModel chatModel, Prompt prompt) { try { return chatModel.stream(prompt); } catch (Exception e) { throw aiChatFailureHandler.buildAiException("AI_MODEL_STREAM_ERROR", AiErrorCategory.MODEL, "MODEL_STREAM_INIT", e == null ? "AI 模型流式调用失败" : e.getMessage(), e); } } private String extractContent(ChatResponse response) { if (response == null || response.getResult() == null || response.getResult().getOutput() == null) { return null; } return response.getResult().getOutput().getText(); } }