package com.vincent.rsf.server.ai.service.impl; import com.fasterxml.jackson.databind.ObjectMapper; import com.vincent.rsf.framework.common.Cools; import com.vincent.rsf.framework.exception.CoolException; import com.vincent.rsf.server.ai.config.AiDefaults; import com.vincent.rsf.server.ai.dto.AiChatDoneDto; import com.vincent.rsf.server.ai.dto.AiChatErrorDto; import com.vincent.rsf.server.ai.dto.AiChatMemoryDto; import com.vincent.rsf.server.ai.dto.AiChatMessageDto; import com.vincent.rsf.server.ai.dto.AiChatRequest; import com.vincent.rsf.server.ai.dto.AiChatRuntimeDto; import com.vincent.rsf.server.ai.dto.AiChatStatusDto; import com.vincent.rsf.server.ai.dto.AiChatSessionDto; import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest; import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest; import com.vincent.rsf.server.ai.dto.AiChatToolEventDto; import com.vincent.rsf.server.ai.dto.AiResolvedConfig; import com.vincent.rsf.server.ai.entity.AiCallLog; import com.vincent.rsf.server.ai.entity.AiParam; import com.vincent.rsf.server.ai.entity.AiPrompt; import com.vincent.rsf.server.ai.entity.AiChatSession; import com.vincent.rsf.server.ai.enums.AiErrorCategory; import com.vincent.rsf.server.ai.exception.AiChatException; import com.vincent.rsf.server.ai.service.AiCallLogService; import com.vincent.rsf.server.ai.service.AiChatService; import com.vincent.rsf.server.ai.service.AiChatMemoryService; import com.vincent.rsf.server.ai.service.AiConfigResolverService; import com.vincent.rsf.server.ai.service.MountedToolCallback; import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory; import io.micrometer.observation.ObservationRegistry; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.tool.DefaultToolCallingManager; import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver; import org.springframework.ai.util.json.schema.SchemaType; import org.springframework.context.support.GenericApplicationContext; import org.springframework.http.MediaType; import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.stereotype.Service; import org.springframework.util.StringUtils; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import reactor.core.publisher.Flux; import java.io.IOException; import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicLong; @Slf4j @Service @RequiredArgsConstructor public class AiChatServiceImpl implements AiChatService { private final AiConfigResolverService aiConfigResolverService; private final AiChatMemoryService aiChatMemoryService; private final McpMountRuntimeFactory mcpMountRuntimeFactory; private final AiCallLogService aiCallLogService; private final GenericApplicationContext applicationContext; private final ObservationRegistry observationRegistry; private final ObjectMapper objectMapper; @Qualifier("aiChatTaskExecutor") private final Executor aiChatTaskExecutor; @Override public AiChatRuntimeDto getRuntime(String promptCode, Long sessionId, Long userId, Long tenantId) { AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId); AiChatMemoryDto memory = aiChatMemoryService.getMemory(userId, tenantId, config.getPromptCode(), sessionId); return AiChatRuntimeDto.builder() .requestId(null) .sessionId(memory.getSessionId()) .promptCode(config.getPromptCode()) .promptName(config.getPrompt().getName()) .model(config.getAiParam().getModel()) .configuredMcpCount(config.getMcpMounts().size()) .mountedMcpCount(config.getMcpMounts().size()) .mountedMcpNames(config.getMcpMounts().stream().map(item -> item.getName()).toList()) .mountErrors(List.of()) .memorySummary(memory.getMemorySummary()) .memoryFacts(memory.getMemoryFacts()) .recentMessageCount(memory.getRecentMessageCount()) .persistedMessages(memory.getPersistedMessages()) .build(); } @Override public List listSessions(String promptCode, String keyword, Long userId, Long tenantId) { AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId); return aiChatMemoryService.listSessions(userId, tenantId, config.getPromptCode(), keyword); } @Override public void removeSession(Long sessionId, Long userId, Long tenantId) { aiChatMemoryService.removeSession(userId, tenantId, sessionId); } @Override public AiChatSessionDto renameSession(Long sessionId, AiChatSessionRenameRequest request, Long userId, Long tenantId) { return aiChatMemoryService.renameSession(userId, tenantId, sessionId, request); } @Override public AiChatSessionDto pinSession(Long sessionId, AiChatSessionPinRequest request, Long userId, Long tenantId) { return aiChatMemoryService.pinSession(userId, tenantId, sessionId, request); } @Override public void clearSessionMemory(Long sessionId, Long userId, Long tenantId) { aiChatMemoryService.clearSessionMemory(userId, tenantId, sessionId); } @Override public void retainLatestRound(Long sessionId, Long userId, Long tenantId) { aiChatMemoryService.retainLatestRound(userId, tenantId, sessionId); } @Override public SseEmitter stream(AiChatRequest request, Long userId, Long tenantId) { SseEmitter emitter = new SseEmitter(AiDefaults.SSE_TIMEOUT_MS); CompletableFuture.runAsync(() -> doStream(request, userId, tenantId, emitter), aiChatTaskExecutor); return emitter; } private void doStream(AiChatRequest request, Long userId, Long tenantId, SseEmitter emitter) { String requestId = request.getRequestId(); long startedAt = System.currentTimeMillis(); AtomicReference firstTokenAtRef = new AtomicReference<>(); AtomicLong toolCallSequence = new AtomicLong(0); AtomicLong toolSuccessCount = new AtomicLong(0); AtomicLong toolFailureCount = new AtomicLong(0); Long sessionId = request.getSessionId(); Long callLogId = null; String model = null; try { ensureIdentity(userId, tenantId); AiResolvedConfig config = resolveConfig(request, tenantId); final String resolvedModel = config.getAiParam().getModel(); model = resolvedModel; AiChatSession session = resolveSession(request, userId, tenantId, config.getPromptCode()); sessionId = session.getId(); AiChatMemoryDto memory = loadMemory(userId, tenantId, config.getPromptCode(), session.getId()); List mergedMessages = mergeMessages(memory.getShortMemoryMessages(), request.getMessages()); AiCallLog callLog = aiCallLogService.startCallLog( requestId, session.getId(), userId, tenantId, config.getPromptCode(), config.getPrompt().getName(), config.getAiParam().getModel(), config.getMcpMounts().size(), config.getMcpMounts().size(), config.getMcpMounts().stream().map(item -> item.getName()).toList() ); callLogId = callLog.getId(); try (McpMountRuntimeFactory.McpMountRuntime runtime = createRuntime(config, userId)) { emitStrict(emitter, "start", AiChatRuntimeDto.builder() .requestId(requestId) .sessionId(session.getId()) .promptCode(config.getPromptCode()) .promptName(config.getPrompt().getName()) .model(config.getAiParam().getModel()) .configuredMcpCount(config.getMcpMounts().size()) .mountedMcpCount(runtime.getMountedCount()) .mountedMcpNames(runtime.getMountedNames()) .mountErrors(runtime.getErrors()) .memorySummary(memory.getMemorySummary()) .memoryFacts(memory.getMemoryFacts()) .recentMessageCount(memory.getRecentMessageCount()) .persistedMessages(memory.getPersistedMessages()) .build()); emitSafely(emitter, "status", AiChatStatusDto.builder() .requestId(requestId) .sessionId(session.getId()) .status("STARTED") .model(resolvedModel) .timestamp(Instant.now().toEpochMilli()) .elapsedMs(0L) .build()); log.info("AI chat started, requestId={}, userId={}, tenantId={}, sessionId={}, model={}", requestId, userId, tenantId, session.getId(), resolvedModel); ToolCallback[] observableToolCallbacks = wrapToolCallbacks( runtime.getToolCallbacks(), emitter, requestId, session.getId(), toolCallSequence, toolSuccessCount, toolFailureCount, callLogId, userId, tenantId ); Prompt prompt = new Prompt( buildPromptMessages(memory, mergedMessages, config.getPrompt(), request.getMetadata()), buildChatOptions(config.getAiParam(), observableToolCallbacks, userId, tenantId, requestId, session.getId(), request.getMetadata()) ); OpenAiChatModel chatModel = createChatModel(config.getAiParam()); if (Boolean.FALSE.equals(config.getAiParam().getStreamingEnabled())) { ChatResponse response = invokeChatCall(chatModel, prompt); String content = extractContent(response); aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), content); if (StringUtils.hasText(content)) { markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt); emitStrict(emitter, "delta", buildMessagePayload("requestId", requestId, "content", content)); } emitDone(emitter, requestId, response.getMetadata(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get()); emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get())); aiCallLogService.completeCallLog( callLogId, "COMPLETED", System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()), response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getPromptTokens(), response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getCompletionTokens(), response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getTotalTokens(), toolSuccessCount.get(), toolFailureCount.get() ); log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}", requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get())); emitter.complete(); return; } Flux responseFlux = invokeChatStream(chatModel, prompt); AtomicReference lastMetadata = new AtomicReference<>(); StringBuilder assistantContent = new StringBuilder(); try { responseFlux.doOnNext(response -> { lastMetadata.set(response.getMetadata()); String content = extractContent(response); if (StringUtils.hasText(content)) { markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt); assistantContent.append(content); emitStrict(emitter, "delta", buildMessagePayload("requestId", requestId, "content", content)); } }) .blockLast(); } catch (Exception e) { throw buildAiException("AI_MODEL_STREAM_ERROR", AiErrorCategory.MODEL, "MODEL_STREAM", e == null ? "AI 模型流式调用失败" : e.getMessage(), e); } aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), assistantContent.toString()); emitDone(emitter, requestId, lastMetadata.get(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get()); emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get())); aiCallLogService.completeCallLog( callLogId, "COMPLETED", System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()), lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getPromptTokens(), lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getCompletionTokens(), lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getTotalTokens(), toolSuccessCount.get(), toolFailureCount.get() ); log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}", requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get())); emitter.complete(); } } catch (AiChatException e) { handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e, callLogId, toolSuccessCount.get(), toolFailureCount.get()); } catch (Exception e) { handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), buildAiException("AI_INTERNAL_ERROR", AiErrorCategory.INTERNAL, "INTERNAL", e == null ? "AI 对话失败" : e.getMessage(), e), callLogId, toolSuccessCount.get(), toolFailureCount.get()); } finally { log.debug("AI chat stream finished, requestId={}", requestId); } } private void ensureIdentity(Long userId, Long tenantId) { if (userId == null) { throw buildAiException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "AUTH_VALIDATE", "当前登录用户不存在", null); } if (tenantId == null) { throw buildAiException("AI_AUTH_TENANT_MISSING", AiErrorCategory.AUTH, "AUTH_VALIDATE", "当前租户不存在", null); } } private AiResolvedConfig resolveConfig(AiChatRequest request, Long tenantId) { try { return aiConfigResolverService.resolve(request.getPromptCode(), tenantId); } catch (Exception e) { throw buildAiException("AI_CONFIG_RESOLVE_ERROR", AiErrorCategory.CONFIG, "CONFIG_RESOLVE", e == null ? "AI 配置解析失败" : e.getMessage(), e); } } private AiChatSession resolveSession(AiChatRequest request, Long userId, Long tenantId, String promptCode) { try { return aiChatMemoryService.resolveSession(userId, tenantId, promptCode, request.getSessionId(), resolveTitleSeed(request.getMessages())); } catch (Exception e) { throw buildAiException("AI_SESSION_RESOLVE_ERROR", AiErrorCategory.REQUEST, "SESSION_RESOLVE", e == null ? "AI 会话解析失败" : e.getMessage(), e); } } private AiChatMemoryDto loadMemory(Long userId, Long tenantId, String promptCode, Long sessionId) { try { return aiChatMemoryService.getMemory(userId, tenantId, promptCode, sessionId); } catch (Exception e) { throw buildAiException("AI_MEMORY_LOAD_ERROR", AiErrorCategory.REQUEST, "MEMORY_LOAD", e == null ? "AI 会话记忆加载失败" : e.getMessage(), e); } } private McpMountRuntimeFactory.McpMountRuntime createRuntime(AiResolvedConfig config, Long userId) { try { return mcpMountRuntimeFactory.create(config.getMcpMounts(), userId); } catch (Exception e) { throw buildAiException("AI_MCP_MOUNT_ERROR", AiErrorCategory.MCP, "MCP_MOUNT", e == null ? "MCP 挂载失败" : e.getMessage(), e); } } private ChatResponse invokeChatCall(OpenAiChatModel chatModel, Prompt prompt) { try { return chatModel.call(prompt); } catch (Exception e) { throw buildAiException("AI_MODEL_CALL_ERROR", AiErrorCategory.MODEL, "MODEL_CALL", e == null ? "AI 模型调用失败" : e.getMessage(), e); } } private Flux invokeChatStream(OpenAiChatModel chatModel, Prompt prompt) { try { return chatModel.stream(prompt); } catch (Exception e) { throw buildAiException("AI_MODEL_STREAM_ERROR", AiErrorCategory.MODEL, "MODEL_STREAM_INIT", e == null ? "AI 模型流式调用失败" : e.getMessage(), e); } } private void markFirstToken(AtomicReference firstTokenAtRef, SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt) { if (!firstTokenAtRef.compareAndSet(null, System.currentTimeMillis())) { return; } emitSafely(emitter, "status", AiChatStatusDto.builder() .requestId(requestId) .sessionId(sessionId) .status("FIRST_TOKEN") .model(model) .timestamp(Instant.now().toEpochMilli()) .elapsedMs(System.currentTimeMillis() - startedAt) .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAtRef.get())) .build()); } private AiChatStatusDto buildTerminalStatus(String requestId, Long sessionId, String status, String model, long startedAt, Long firstTokenAt) { return AiChatStatusDto.builder() .requestId(requestId) .sessionId(sessionId) .status(status) .model(model) .timestamp(Instant.now().toEpochMilli()) .elapsedMs(System.currentTimeMillis() - startedAt) .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAt)) .build(); } private Long resolveFirstTokenLatency(long startedAt, Long firstTokenAt) { return firstTokenAt == null ? null : Math.max(0L, firstTokenAt - startedAt); } private void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt, Long firstTokenAt, AiChatException exception, Long callLogId, long toolSuccessCount, long toolFailureCount) { if (isClientAbortException(exception)) { log.warn("AI chat aborted by client, requestId={}, sessionId={}, stage={}, message={}", requestId, sessionId, exception.getStage(), exception.getMessage()); emitSafely(emitter, "status", buildTerminalStatus(requestId, sessionId, "ABORTED", model, startedAt, firstTokenAt)); aiCallLogService.failCallLog( callLogId, "ABORTED", exception.getCategory().name(), exception.getStage(), exception.getMessage(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAt), toolSuccessCount, toolFailureCount ); emitter.completeWithError(exception); return; } log.error("AI chat failed, requestId={}, sessionId={}, category={}, stage={}, message={}", requestId, sessionId, exception.getCategory(), exception.getStage(), exception.getMessage(), exception); emitSafely(emitter, "status", buildTerminalStatus(requestId, sessionId, "FAILED", model, startedAt, firstTokenAt)); emitSafely(emitter, "error", AiChatErrorDto.builder() .requestId(requestId) .sessionId(sessionId) .code(exception.getCode()) .category(exception.getCategory().name()) .stage(exception.getStage()) .message(exception.getMessage()) .timestamp(Instant.now().toEpochMilli()) .build()); aiCallLogService.failCallLog( callLogId, "FAILED", exception.getCategory().name(), exception.getStage(), exception.getMessage(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAt), toolSuccessCount, toolFailureCount ); emitter.completeWithError(exception); } private OpenAiChatModel createChatModel(AiParam aiParam) { OpenAiApi openAiApi = buildOpenAiApi(aiParam); ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder() .observationRegistry(observationRegistry) .toolCallbackResolver(new SpringBeanToolCallbackResolver(applicationContext, SchemaType.OPEN_API_SCHEMA)) .toolExecutionExceptionProcessor(new DefaultToolExecutionExceptionProcessor(false)) .build(); return new OpenAiChatModel( openAiApi, OpenAiChatOptions.builder() .model(aiParam.getModel()) .temperature(aiParam.getTemperature()) .topP(aiParam.getTopP()) .maxTokens(aiParam.getMaxTokens()) .streamUsage(true) .build(), toolCallingManager, org.springframework.retry.support.RetryTemplate.builder().maxAttempts(1).build(), observationRegistry ); } private OpenAiApi buildOpenAiApi(AiParam aiParam) { int timeoutMs = aiParam.getTimeoutMs() == null ? AiDefaults.DEFAULT_TIMEOUT_MS : aiParam.getTimeoutMs(); SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); requestFactory.setConnectTimeout(timeoutMs); requestFactory.setReadTimeout(timeoutMs); return OpenAiApi.builder() .baseUrl(aiParam.getBaseUrl()) .apiKey(aiParam.getApiKey()) .restClientBuilder(RestClient.builder().requestFactory(requestFactory)) .webClientBuilder(WebClient.builder()) .build(); } private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Long tenantId, String requestId, Long sessionId, Map metadata) { if (userId == null) { throw buildAiException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "OPTIONS_BUILD", "当前登录用户不存在", null); } OpenAiChatOptions.Builder builder = OpenAiChatOptions.builder() .model(aiParam.getModel()) .temperature(aiParam.getTemperature()) .topP(aiParam.getTopP()) .maxTokens(aiParam.getMaxTokens()) .streamUsage(true) .user(String.valueOf(userId)); if (!Cools.isEmpty(toolCallbacks)) { builder.toolCallbacks(Arrays.stream(toolCallbacks).toList()); } Map toolContext = new LinkedHashMap<>(); toolContext.put("userId", userId); toolContext.put("tenantId", tenantId); toolContext.put("requestId", requestId); toolContext.put("sessionId", sessionId); Map metadataMap = new LinkedHashMap<>(); if (metadata != null) { metadata.forEach((key, value) -> { String normalized = value == null ? "" : String.valueOf(value); metadataMap.put(key, normalized); toolContext.put(key, normalized); }); } builder.toolContext(toolContext); if (!metadataMap.isEmpty()) { builder.metadata(metadataMap); } return builder.build(); } private ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, SseEmitter emitter, String requestId, Long sessionId, AtomicLong toolCallSequence, AtomicLong toolSuccessCount, AtomicLong toolFailureCount, Long callLogId, Long userId, Long tenantId) { if (Cools.isEmpty(toolCallbacks)) { return toolCallbacks; } List wrappedCallbacks = new ArrayList<>(); for (ToolCallback callback : toolCallbacks) { if (callback == null) { continue; } wrappedCallbacks.add(new ObservableToolCallback(callback, emitter, requestId, sessionId, toolCallSequence, toolSuccessCount, toolFailureCount, callLogId, userId, tenantId)); } return wrappedCallbacks.toArray(new ToolCallback[0]); } private List buildPromptMessages(AiChatMemoryDto memory, List sourceMessages, AiPrompt aiPrompt, Map metadata) { if (Cools.isEmpty(sourceMessages)) { throw new CoolException("对话消息不能为空"); } List messages = new ArrayList<>(); if (StringUtils.hasText(aiPrompt.getSystemPrompt())) { messages.add(new SystemMessage(aiPrompt.getSystemPrompt())); } if (memory != null && StringUtils.hasText(memory.getMemorySummary())) { messages.add(new SystemMessage("历史摘要:\n" + memory.getMemorySummary())); } if (memory != null && StringUtils.hasText(memory.getMemoryFacts())) { messages.add(new SystemMessage("关键事实:\n" + memory.getMemoryFacts())); } int lastUserIndex = -1; for (int i = 0; i < sourceMessages.size(); i++) { AiChatMessageDto item = sourceMessages.get(i); if (item != null && "user".equalsIgnoreCase(item.getRole())) { lastUserIndex = i; } } for (int i = 0; i < sourceMessages.size(); i++) { AiChatMessageDto item = sourceMessages.get(i); if (item == null || !StringUtils.hasText(item.getContent())) { continue; } String role = item.getRole() == null ? "user" : item.getRole().toLowerCase(); if ("system".equals(role)) { continue; } String content = item.getContent(); if ("user".equals(role) && i == lastUserIndex) { content = renderUserPrompt(aiPrompt.getUserPromptTemplate(), content, metadata); } if ("assistant".equals(role)) { messages.add(new AssistantMessage(content)); } else { messages.add(new UserMessage(content)); } } if (messages.stream().noneMatch(item -> item instanceof UserMessage)) { throw new CoolException("至少需要一条用户消息"); } return messages; } private List mergeMessages(List persistedMessages, List memoryMessages) { List merged = new ArrayList<>(); if (!Cools.isEmpty(persistedMessages)) { merged.addAll(persistedMessages); } if (!Cools.isEmpty(memoryMessages)) { merged.addAll(memoryMessages); } if (merged.isEmpty()) { throw new CoolException("对话消息不能为空"); } return merged; } private String resolveTitleSeed(List messages) { if (Cools.isEmpty(messages)) { throw new CoolException("对话消息不能为空"); } for (int i = messages.size() - 1; i >= 0; i--) { AiChatMessageDto item = messages.get(i); if (item != null && "user".equalsIgnoreCase(item.getRole()) && StringUtils.hasText(item.getContent())) { return item.getContent(); } } throw new CoolException("至少需要一条用户消息"); } private String renderUserPrompt(String userPromptTemplate, String content, Map metadata) { if (!StringUtils.hasText(userPromptTemplate)) { return content; } String rendered = userPromptTemplate .replace("{{input}}", content) .replace("{input}", content); if (metadata != null) { for (Map.Entry entry : metadata.entrySet()) { String value = entry.getValue() == null ? "" : String.valueOf(entry.getValue()); rendered = rendered.replace("{{" + entry.getKey() + "}}", value); rendered = rendered.replace("{" + entry.getKey() + "}", value); } } if (Objects.equals(rendered, userPromptTemplate)) { return userPromptTemplate + "\n\n" + content; } return rendered; } private String extractContent(ChatResponse response) { if (response == null || response.getResult() == null || response.getResult().getOutput() == null) { return null; } return response.getResult().getOutput().getText(); } private String summarizeToolPayload(String content, int maxLength) { if (!StringUtils.hasText(content)) { return null; } String normalized = content.trim() .replace("\r", " ") .replace("\n", " ") .replaceAll("\\s+", " "); return normalized.length() > maxLength ? normalized.substring(0, maxLength) : normalized; } private void emitDone(SseEmitter emitter, String requestId, ChatResponseMetadata metadata, String fallbackModel, Long sessionId, long startedAt, Long firstTokenAt) { Usage usage = metadata == null ? null : metadata.getUsage(); emitStrict(emitter, "done", AiChatDoneDto.builder() .requestId(requestId) .sessionId(sessionId) .model(metadata != null && StringUtils.hasText(metadata.getModel()) ? metadata.getModel() : fallbackModel) .elapsedMs(System.currentTimeMillis() - startedAt) .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAt)) .promptTokens(usage == null ? null : usage.getPromptTokens()) .completionTokens(usage == null ? null : usage.getCompletionTokens()) .totalTokens(usage == null ? null : usage.getTotalTokens()) .build()); } private Map buildMessagePayload(String... keyValues) { Map payload = new LinkedHashMap<>(); if (keyValues == null || keyValues.length == 0) { return payload; } if (keyValues.length % 2 != 0) { throw new CoolException("消息载荷参数必须成对出现"); } for (int i = 0; i < keyValues.length; i += 2) { payload.put(keyValues[i], keyValues[i + 1] == null ? "" : keyValues[i + 1]); } return payload; } private void emitStrict(SseEmitter emitter, String eventName, Object payload) { try { String data = objectMapper.writeValueAsString(payload); emitter.send(SseEmitter.event() .name(eventName) .data(data, MediaType.APPLICATION_JSON)); } catch (IOException e) { throw buildAiException("AI_SSE_EMIT_ERROR", AiErrorCategory.STREAM, "SSE_EMIT", "SSE 输出失败: " + e.getMessage(), e); } } private void emitSafely(SseEmitter emitter, String eventName, Object payload) { try { emitStrict(emitter, eventName, payload); } catch (Exception e) { log.warn("AI SSE event emit skipped, eventName={}, message={}", eventName, e.getMessage()); } } private AiChatException buildAiException(String code, AiErrorCategory category, String stage, String message, Throwable cause) { return new AiChatException(code, category, stage, message, cause); } private boolean isClientAbortException(Throwable throwable) { Throwable current = throwable; while (current != null) { String message = current.getMessage(); if (message != null) { String normalized = message.toLowerCase(); if (normalized.contains("broken pipe") || normalized.contains("connection reset") || normalized.contains("forcibly closed") || normalized.contains("abort")) { return true; } } current = current.getCause(); } return false; } private class ObservableToolCallback implements ToolCallback { private final ToolCallback delegate; private final SseEmitter emitter; private final String requestId; private final Long sessionId; private final AtomicLong toolCallSequence; private final AtomicLong toolSuccessCount; private final AtomicLong toolFailureCount; private final Long callLogId; private final Long userId; private final Long tenantId; private ObservableToolCallback(ToolCallback delegate, SseEmitter emitter, String requestId, Long sessionId, AtomicLong toolCallSequence, AtomicLong toolSuccessCount, AtomicLong toolFailureCount, Long callLogId, Long userId, Long tenantId) { this.delegate = delegate; this.emitter = emitter; this.requestId = requestId; this.sessionId = sessionId; this.toolCallSequence = toolCallSequence; this.toolSuccessCount = toolSuccessCount; this.toolFailureCount = toolFailureCount; this.callLogId = callLogId; this.userId = userId; this.tenantId = tenantId; } @Override public org.springframework.ai.tool.definition.ToolDefinition getToolDefinition() { return delegate.getToolDefinition(); } @Override public org.springframework.ai.tool.metadata.ToolMetadata getToolMetadata() { return delegate.getToolMetadata(); } @Override public String call(String toolInput) { return call(toolInput, null); } @Override public String call(String toolInput, ToolContext toolContext) { String toolName = delegate.getToolDefinition() == null ? "unknown" : delegate.getToolDefinition().name(); String mountName = delegate instanceof MountedToolCallback ? ((MountedToolCallback) delegate).getMountName() : null; String toolCallId = requestId + "-tool-" + toolCallSequence.incrementAndGet(); long startedAt = System.currentTimeMillis(); emitSafely(emitter, "tool_start", AiChatToolEventDto.builder() .requestId(requestId) .sessionId(sessionId) .toolCallId(toolCallId) .toolName(toolName) .mountName(mountName) .status("STARTED") .inputSummary(summarizeToolPayload(toolInput, 400)) .timestamp(startedAt) .build()); try { String output = toolContext == null ? delegate.call(toolInput) : delegate.call(toolInput, toolContext); long durationMs = System.currentTimeMillis() - startedAt; emitSafely(emitter, "tool_result", AiChatToolEventDto.builder() .requestId(requestId) .sessionId(sessionId) .toolCallId(toolCallId) .toolName(toolName) .mountName(mountName) .status("COMPLETED") .inputSummary(summarizeToolPayload(toolInput, 400)) .outputSummary(summarizeToolPayload(output, 600)) .durationMs(durationMs) .timestamp(System.currentTimeMillis()) .build()); toolSuccessCount.incrementAndGet(); aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(output, 600), null, durationMs, userId, tenantId); return output; } catch (RuntimeException e) { long durationMs = System.currentTimeMillis() - startedAt; emitSafely(emitter, "tool_error", AiChatToolEventDto.builder() .requestId(requestId) .sessionId(sessionId) .toolCallId(toolCallId) .toolName(toolName) .mountName(mountName) .status("FAILED") .inputSummary(summarizeToolPayload(toolInput, 400)) .errorMessage(e.getMessage()) .durationMs(durationMs) .timestamp(System.currentTimeMillis()) .build()); toolFailureCount.incrementAndGet(); aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, "FAILED", summarizeToolPayload(toolInput, 400), null, e.getMessage(), durationMs, userId, tenantId); throw e; } } } }