From 82624affb0251b75b62b35567d3eb260c06efe78 Mon Sep 17 00:00:00 2001
From: zhou zhou <3272660260@qq.com>
Date: Mon, 23 Mar 2026 12:48:07 +0800
Subject: [PATCH] #ai 代码优化
---
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java | 808 ++------------------------------------------------------
1 files changed, 39 insertions(+), 769 deletions(-)
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
index 40d5594..1d6da64 100644
--- a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
@@ -1,122 +1,81 @@
package com.vincent.rsf.server.ai.service.impl;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.vincent.rsf.framework.common.Cools;
-import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.config.AiDefaults;
-import com.vincent.rsf.server.ai.dto.AiChatDoneDto;
-import com.vincent.rsf.server.ai.dto.AiChatErrorDto;
import com.vincent.rsf.server.ai.dto.AiChatMemoryDto;
-import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
+import com.vincent.rsf.server.ai.dto.AiChatModelOptionDto;
import com.vincent.rsf.server.ai.dto.AiChatRequest;
import com.vincent.rsf.server.ai.dto.AiChatRuntimeDto;
-import com.vincent.rsf.server.ai.dto.AiChatStatusDto;
import com.vincent.rsf.server.ai.dto.AiChatSessionDto;
import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest;
import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest;
-import com.vincent.rsf.server.ai.dto.AiChatToolEventDto;
import com.vincent.rsf.server.ai.dto.AiResolvedConfig;
-import com.vincent.rsf.server.ai.entity.AiCallLog;
-import com.vincent.rsf.server.ai.entity.AiParam;
-import com.vincent.rsf.server.ai.entity.AiPrompt;
-import com.vincent.rsf.server.ai.entity.AiChatSession;
-import com.vincent.rsf.server.ai.enums.AiErrorCategory;
-import com.vincent.rsf.server.ai.exception.AiChatException;
-import com.vincent.rsf.server.ai.service.AiCallLogService;
-import com.vincent.rsf.server.ai.service.AiChatService;
import com.vincent.rsf.server.ai.service.AiChatMemoryService;
+import com.vincent.rsf.server.ai.service.AiChatService;
import com.vincent.rsf.server.ai.service.AiConfigResolverService;
-import com.vincent.rsf.server.ai.service.MountedToolCallback;
-import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
-import io.micrometer.observation.ObservationRegistry;
+import com.vincent.rsf.server.ai.service.AiParamService;
+import com.vincent.rsf.server.ai.service.impl.chat.AiChatOrchestrator;
+import com.vincent.rsf.server.ai.service.impl.chat.AiChatRuntimeAssembler;
+import com.vincent.rsf.server.ai.store.AiConversationCacheStore;
import lombok.RequiredArgsConstructor;
-import lombok.extern.slf4j.Slf4j;
-import org.springframework.ai.chat.messages.AssistantMessage;
-import org.springframework.ai.chat.messages.Message;
-import org.springframework.ai.chat.messages.SystemMessage;
-import org.springframework.ai.chat.messages.UserMessage;
-import org.springframework.ai.chat.metadata.ChatResponseMetadata;
-import org.springframework.ai.chat.metadata.Usage;
-import org.springframework.ai.chat.model.ChatResponse;
-import org.springframework.ai.chat.model.ToolContext;
-import org.springframework.ai.chat.prompt.Prompt;
-import org.springframework.ai.model.tool.DefaultToolCallingManager;
-import org.springframework.ai.model.tool.ToolCallingManager;
-import org.springframework.ai.openai.OpenAiChatModel;
-import org.springframework.ai.openai.OpenAiChatOptions;
-import org.springframework.ai.openai.api.OpenAiApi;
-import org.springframework.ai.tool.ToolCallback;
-import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
-import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver;
-import org.springframework.ai.util.json.schema.SchemaType;
-import org.springframework.context.support.GenericApplicationContext;
-import org.springframework.http.MediaType;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
-import org.springframework.util.StringUtils;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
-import reactor.core.publisher.Flux;
-import java.io.IOException;
-import java.time.Instant;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.LinkedHashMap;
import java.util.List;
-import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
-import java.util.concurrent.atomic.AtomicReference;
-import java.util.concurrent.atomic.AtomicLong;
-@Slf4j
@Service
@RequiredArgsConstructor
public class AiChatServiceImpl implements AiChatService {
private final AiConfigResolverService aiConfigResolverService;
private final AiChatMemoryService aiChatMemoryService;
- private final McpMountRuntimeFactory mcpMountRuntimeFactory;
- private final AiCallLogService aiCallLogService;
- private final GenericApplicationContext applicationContext;
- private final ObservationRegistry observationRegistry;
- private final ObjectMapper objectMapper;
+ private final AiParamService aiParamService;
+ private final AiConversationCacheStore aiConversationCacheStore;
+ private final AiChatRuntimeAssembler aiChatRuntimeAssembler;
+ private final AiChatOrchestrator aiChatOrchestrator;
@Qualifier("aiChatTaskExecutor")
private final Executor aiChatTaskExecutor;
- /**
- * 获取当前对话抽屉初始化所需的运行时数据。
- * 该方法不会触发模型调用,而是把配置解析结果和会话记忆聚合成前端一次渲染所需的快照。
- */
@Override
- public AiChatRuntimeDto getRuntime(String promptCode, Long sessionId, Long userId, Long tenantId) {
- AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId);
+ public AiChatRuntimeDto getRuntime(String promptCode, Long sessionId, Long aiParamId, Long userId, Long tenantId) {
+ AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId, aiParamId);
+ AiChatRuntimeDto cached = aiConversationCacheStore.getRuntime(tenantId, userId, config.getPromptCode(), sessionId, aiParamId);
+ if (cached != null) {
+ return cached;
+ }
AiChatMemoryDto memory = aiChatMemoryService.getMemory(userId, tenantId, config.getPromptCode(), sessionId);
- return AiChatRuntimeDto.builder()
- .requestId(null)
- .sessionId(memory.getSessionId())
- .promptCode(config.getPromptCode())
- .promptName(config.getPrompt().getName())
- .model(config.getAiParam().getModel())
- .configuredMcpCount(config.getMcpMounts().size())
- .mountedMcpCount(config.getMcpMounts().size())
- .mountedMcpNames(config.getMcpMounts().stream().map(item -> item.getName()).toList())
- .mountErrors(List.of())
- .memorySummary(memory.getMemorySummary())
- .memoryFacts(memory.getMemoryFacts())
- .recentMessageCount(memory.getRecentMessageCount())
- .persistedMessages(memory.getPersistedMessages())
- .build();
+ List<AiChatModelOptionDto> modelOptions = aiParamService.listChatModelOptions(tenantId);
+ AiChatRuntimeDto runtime = aiChatRuntimeAssembler.buildRuntimeSnapshot(
+ null,
+ memory.getSessionId(),
+ config,
+ modelOptions,
+ config.getMcpMounts().size(),
+ config.getMcpMounts().stream().map(item -> item.getName()).toList(),
+ List.of(),
+ memory
+ );
+ aiConversationCacheStore.cacheRuntime(tenantId, userId, config.getPromptCode(), sessionId, aiParamId, runtime);
+ if (memory.getSessionId() != null && !Objects.equals(memory.getSessionId(), sessionId)) {
+ aiConversationCacheStore.cacheRuntime(tenantId, userId, config.getPromptCode(), memory.getSessionId(), aiParamId, runtime);
+ }
+ return runtime;
}
- /**
- * 查询指定 Prompt 场景下的历史会话摘要列表。
- */
@Override
public List<AiChatSessionDto> listSessions(String promptCode, String keyword, Long userId, Long tenantId) {
AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId);
return aiChatMemoryService.listSessions(userId, tenantId, config.getPromptCode(), keyword);
+ }
+
+ @Override
+ public SseEmitter stream(AiChatRequest request, Long userId, Long tenantId) {
+ SseEmitter emitter = new SseEmitter(AiDefaults.SSE_TIMEOUT_MS);
+ CompletableFuture.runAsync(() -> aiChatOrchestrator.executeStream(request, userId, tenantId, emitter), aiChatTaskExecutor);
+ return emitter;
}
@Override
@@ -142,694 +101,5 @@
@Override
public void retainLatestRound(Long sessionId, Long userId, Long tenantId) {
aiChatMemoryService.retainLatestRound(userId, tenantId, sessionId);
- }
-
- /**
- * 启动一次新的 SSE 对话流。
- * 控制线程立即返回 emitter,真正的模型调用与工具执行交给 AI 专用线程池异步处理。
- */
- @Override
- public SseEmitter stream(AiChatRequest request, Long userId, Long tenantId) {
- SseEmitter emitter = new SseEmitter(AiDefaults.SSE_TIMEOUT_MS);
- CompletableFuture.runAsync(() -> doStream(request, userId, tenantId, emitter), aiChatTaskExecutor);
- return emitter;
- }
-
- private void doStream(AiChatRequest request, Long userId, Long tenantId, SseEmitter emitter) {
- /**
- * AI 对话的核心执行链路:
- * 1. 校验身份和解析租户配置
- * 2. 解析或创建会话,加载记忆
- * 3. 动态挂载 MCP 工具
- * 4. 发起模型流式/非流式调用
- * 5. 持久化本轮消息,输出 SSE 事件并记录审计日志
- */
- String requestId = request.getRequestId();
- long startedAt = System.currentTimeMillis();
- AtomicReference<Long> firstTokenAtRef = new AtomicReference<>();
- AtomicLong toolCallSequence = new AtomicLong(0);
- AtomicLong toolSuccessCount = new AtomicLong(0);
- AtomicLong toolFailureCount = new AtomicLong(0);
- Long sessionId = request.getSessionId();
- Long callLogId = null;
- String model = null;
- try {
- ensureIdentity(userId, tenantId);
- AiResolvedConfig config = resolveConfig(request, tenantId);
- final String resolvedModel = config.getAiParam().getModel();
- model = resolvedModel;
- AiChatSession session = resolveSession(request, userId, tenantId, config.getPromptCode());
- sessionId = session.getId();
- AiChatMemoryDto memory = loadMemory(userId, tenantId, config.getPromptCode(), session.getId());
- List<AiChatMessageDto> mergedMessages = mergeMessages(memory.getShortMemoryMessages(), request.getMessages());
- AiCallLog callLog = aiCallLogService.startCallLog(
- requestId,
- session.getId(),
- userId,
- tenantId,
- config.getPromptCode(),
- config.getPrompt().getName(),
- config.getAiParam().getModel(),
- config.getMcpMounts().size(),
- config.getMcpMounts().size(),
- config.getMcpMounts().stream().map(item -> item.getName()).toList()
- );
- callLogId = callLog.getId();
- try (McpMountRuntimeFactory.McpMountRuntime runtime = createRuntime(config, userId)) {
- emitStrict(emitter, "start", AiChatRuntimeDto.builder()
- .requestId(requestId)
- .sessionId(session.getId())
- .promptCode(config.getPromptCode())
- .promptName(config.getPrompt().getName())
- .model(config.getAiParam().getModel())
- .configuredMcpCount(config.getMcpMounts().size())
- .mountedMcpCount(runtime.getMountedCount())
- .mountedMcpNames(runtime.getMountedNames())
- .mountErrors(runtime.getErrors())
- .memorySummary(memory.getMemorySummary())
- .memoryFacts(memory.getMemoryFacts())
- .recentMessageCount(memory.getRecentMessageCount())
- .persistedMessages(memory.getPersistedMessages())
- .build());
- emitSafely(emitter, "status", AiChatStatusDto.builder()
- .requestId(requestId)
- .sessionId(session.getId())
- .status("STARTED")
- .model(resolvedModel)
- .timestamp(Instant.now().toEpochMilli())
- .elapsedMs(0L)
- .build());
- log.info("AI chat started, requestId={}, userId={}, tenantId={}, sessionId={}, model={}",
- requestId, userId, tenantId, session.getId(), resolvedModel);
-
- ToolCallback[] observableToolCallbacks = wrapToolCallbacks(
- runtime.getToolCallbacks(), emitter, requestId, session.getId(), toolCallSequence,
- toolSuccessCount, toolFailureCount, callLogId, userId, tenantId
- );
- Prompt prompt = new Prompt(
- buildPromptMessages(memory, mergedMessages, config.getPrompt(), request.getMetadata()),
- buildChatOptions(config.getAiParam(), observableToolCallbacks, userId, tenantId,
- requestId, session.getId(), request.getMetadata())
- );
- OpenAiChatModel chatModel = createChatModel(config.getAiParam());
- if (Boolean.FALSE.equals(config.getAiParam().getStreamingEnabled())) {
- ChatResponse response = invokeChatCall(chatModel, prompt);
- String content = extractContent(response);
- aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), content);
- if (StringUtils.hasText(content)) {
- markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt);
- emitStrict(emitter, "delta", buildMessagePayload("requestId", requestId, "content", content));
- }
- emitDone(emitter, requestId, response.getMetadata(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get());
- emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get()));
- aiCallLogService.completeCallLog(
- callLogId,
- "COMPLETED",
- System.currentTimeMillis() - startedAt,
- resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()),
- response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getPromptTokens(),
- response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getCompletionTokens(),
- response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getTotalTokens(),
- toolSuccessCount.get(),
- toolFailureCount.get()
- );
- log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}",
- requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()));
- emitter.complete();
- return;
- }
-
- Flux<ChatResponse> responseFlux = invokeChatStream(chatModel, prompt);
- AtomicReference<ChatResponseMetadata> lastMetadata = new AtomicReference<>();
- StringBuilder assistantContent = new StringBuilder();
- try {
- responseFlux.doOnNext(response -> {
- lastMetadata.set(response.getMetadata());
- String content = extractContent(response);
- if (StringUtils.hasText(content)) {
- markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt);
- assistantContent.append(content);
- emitStrict(emitter, "delta", buildMessagePayload("requestId", requestId, "content", content));
- }
- })
- .blockLast();
- } catch (Exception e) {
- throw buildAiException("AI_MODEL_STREAM_ERROR", AiErrorCategory.MODEL, "MODEL_STREAM",
- e == null ? "AI 妯″瀷娴佸紡璋冪敤澶辫触" : e.getMessage(), e);
- }
- aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), assistantContent.toString());
- emitDone(emitter, requestId, lastMetadata.get(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get());
- emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get()));
- aiCallLogService.completeCallLog(
- callLogId,
- "COMPLETED",
- System.currentTimeMillis() - startedAt,
- resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()),
- lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getPromptTokens(),
- lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getCompletionTokens(),
- lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getTotalTokens(),
- toolSuccessCount.get(),
- toolFailureCount.get()
- );
- log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}",
- requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()));
- emitter.complete();
- }
- } catch (AiChatException e) {
- handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e,
- callLogId, toolSuccessCount.get(), toolFailureCount.get());
- } catch (Exception e) {
- handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(),
- buildAiException("AI_INTERNAL_ERROR", AiErrorCategory.INTERNAL, "INTERNAL",
- e == null ? "AI 瀵硅瘽澶辫触" : e.getMessage(), e),
- callLogId, toolSuccessCount.get(), toolFailureCount.get());
- } finally {
- log.debug("AI chat stream finished, requestId={}", requestId);
- }
- }
-
- private void ensureIdentity(Long userId, Long tenantId) {
- if (userId == null) {
- throw buildAiException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "AUTH_VALIDATE", "褰撳墠鐧诲綍鐢ㄦ埛涓嶅瓨鍦�", null);
- }
- if (tenantId == null) {
- throw buildAiException("AI_AUTH_TENANT_MISSING", AiErrorCategory.AUTH, "AUTH_VALIDATE", "褰撳墠绉熸埛涓嶅瓨鍦�", null);
- }
- }
-
- private AiResolvedConfig resolveConfig(AiChatRequest request, Long tenantId) {
- /** 把请求里的 Prompt 场景解析成一份可直接执行的 AI 配置。 */
- try {
- return aiConfigResolverService.resolve(request.getPromptCode(), tenantId);
- } catch (Exception e) {
- throw buildAiException("AI_CONFIG_RESOLVE_ERROR", AiErrorCategory.CONFIG, "CONFIG_RESOLVE",
- e == null ? "AI 閰嶇疆瑙f瀽澶辫触" : e.getMessage(), e);
- }
- }
-
- private AiChatSession resolveSession(AiChatRequest request, Long userId, Long tenantId, String promptCode) {
- /** 根据 sessionId 复用历史会话,或在首次提问时创建新会话。 */
- try {
- return aiChatMemoryService.resolveSession(userId, tenantId, promptCode, request.getSessionId(), resolveTitleSeed(request.getMessages()));
- } catch (Exception e) {
- throw buildAiException("AI_SESSION_RESOLVE_ERROR", AiErrorCategory.REQUEST, "SESSION_RESOLVE",
- e == null ? "AI 浼氳瘽瑙f瀽澶辫触" : e.getMessage(), e);
- }
- }
-
- private AiChatMemoryDto loadMemory(Long userId, Long tenantId, String promptCode, Long sessionId) {
- /** 读取会话的短期记忆、摘要记忆和事实记忆,供模型组装上下文。 */
- try {
- return aiChatMemoryService.getMemory(userId, tenantId, promptCode, sessionId);
- } catch (Exception e) {
- throw buildAiException("AI_MEMORY_LOAD_ERROR", AiErrorCategory.REQUEST, "MEMORY_LOAD",
- e == null ? "AI 浼氳瘽璁板繂鍔犺浇澶辫触" : e.getMessage(), e);
- }
- }
-
- private McpMountRuntimeFactory.McpMountRuntime createRuntime(AiResolvedConfig config, Long userId) {
- /** 按配置中的 MCP 挂载记录构造本轮对话专属的工具运行时。 */
- try {
- return mcpMountRuntimeFactory.create(config.getMcpMounts(), userId);
- } catch (Exception e) {
- throw buildAiException("AI_MCP_MOUNT_ERROR", AiErrorCategory.MCP, "MCP_MOUNT",
- e == null ? "MCP 鎸傝浇澶辫触" : e.getMessage(), e);
- }
- }
-
- private ChatResponse invokeChatCall(OpenAiChatModel chatModel, Prompt prompt) {
- try {
- return chatModel.call(prompt);
- } catch (Exception e) {
- throw buildAiException("AI_MODEL_CALL_ERROR", AiErrorCategory.MODEL, "MODEL_CALL",
- e == null ? "AI 妯″瀷璋冪敤澶辫触" : e.getMessage(), e);
- }
- }
-
- private Flux<ChatResponse> invokeChatStream(OpenAiChatModel chatModel, Prompt prompt) {
- try {
- return chatModel.stream(prompt);
- } catch (Exception e) {
- throw buildAiException("AI_MODEL_STREAM_ERROR", AiErrorCategory.MODEL, "MODEL_STREAM_INIT",
- e == null ? "AI 妯″瀷娴佸紡璋冪敤澶辫触" : e.getMessage(), e);
- }
- }
-
- private void markFirstToken(AtomicReference<Long> firstTokenAtRef, SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt) {
- if (!firstTokenAtRef.compareAndSet(null, System.currentTimeMillis())) {
- return;
- }
- emitSafely(emitter, "status", AiChatStatusDto.builder()
- .requestId(requestId)
- .sessionId(sessionId)
- .status("FIRST_TOKEN")
- .model(model)
- .timestamp(Instant.now().toEpochMilli())
- .elapsedMs(System.currentTimeMillis() - startedAt)
- .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()))
- .build());
- }
-
- private AiChatStatusDto buildTerminalStatus(String requestId, Long sessionId, String status, String model, long startedAt, Long firstTokenAt) {
- return AiChatStatusDto.builder()
- .requestId(requestId)
- .sessionId(sessionId)
- .status(status)
- .model(model)
- .timestamp(Instant.now().toEpochMilli())
- .elapsedMs(System.currentTimeMillis() - startedAt)
- .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAt))
- .build();
- }
-
- private Long resolveFirstTokenLatency(long startedAt, Long firstTokenAt) {
- return firstTokenAt == null ? null : Math.max(0L, firstTokenAt - startedAt);
- }
-
- private void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt,
- Long firstTokenAt, AiChatException exception, Long callLogId,
- long toolSuccessCount, long toolFailureCount) {
- if (isClientAbortException(exception)) {
- log.warn("AI chat aborted by client, requestId={}, sessionId={}, stage={}, message={}",
- requestId, sessionId, exception.getStage(), exception.getMessage());
- emitSafely(emitter, "status", buildTerminalStatus(requestId, sessionId, "ABORTED", model, startedAt, firstTokenAt));
- aiCallLogService.failCallLog(
- callLogId,
- "ABORTED",
- exception.getCategory().name(),
- exception.getStage(),
- exception.getMessage(),
- System.currentTimeMillis() - startedAt,
- resolveFirstTokenLatency(startedAt, firstTokenAt),
- toolSuccessCount,
- toolFailureCount
- );
- emitter.completeWithError(exception);
- return;
- }
- log.error("AI chat failed, requestId={}, sessionId={}, category={}, stage={}, message={}",
- requestId, sessionId, exception.getCategory(), exception.getStage(), exception.getMessage(), exception);
- emitSafely(emitter, "status", buildTerminalStatus(requestId, sessionId, "FAILED", model, startedAt, firstTokenAt));
- emitSafely(emitter, "error", AiChatErrorDto.builder()
- .requestId(requestId)
- .sessionId(sessionId)
- .code(exception.getCode())
- .category(exception.getCategory().name())
- .stage(exception.getStage())
- .message(exception.getMessage())
- .timestamp(Instant.now().toEpochMilli())
- .build());
- aiCallLogService.failCallLog(
- callLogId,
- "FAILED",
- exception.getCategory().name(),
- exception.getStage(),
- exception.getMessage(),
- System.currentTimeMillis() - startedAt,
- resolveFirstTokenLatency(startedAt, firstTokenAt),
- toolSuccessCount,
- toolFailureCount
- );
- emitter.completeWithError(exception);
- }
-
- private OpenAiChatModel createChatModel(AiParam aiParam) {
- OpenAiApi openAiApi = buildOpenAiApi(aiParam);
- ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder()
- .observationRegistry(observationRegistry)
- .toolCallbackResolver(new SpringBeanToolCallbackResolver(applicationContext, SchemaType.OPEN_API_SCHEMA))
- .toolExecutionExceptionProcessor(new DefaultToolExecutionExceptionProcessor(false))
- .build();
- return new OpenAiChatModel(
- openAiApi,
- OpenAiChatOptions.builder()
- .model(aiParam.getModel())
- .temperature(aiParam.getTemperature())
- .topP(aiParam.getTopP())
- .maxTokens(aiParam.getMaxTokens())
- .streamUsage(true)
- .build(),
- toolCallingManager,
- org.springframework.retry.support.RetryTemplate.builder().maxAttempts(1).build(),
- observationRegistry
- );
- }
-
- private OpenAiApi buildOpenAiApi(AiParam aiParam) {
- return AiOpenAiApiSupport.buildOpenAiApi(aiParam);
- }
-
- private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Long tenantId,
- String requestId, Long sessionId, Map<String, Object> metadata) {
- /**
- * 组装一次聊天调用的全部模型选项和 Tool Context。
- * Tool Context 会透传给内置工具和外部 MCP,保证工具在租户和会话范围内执行。
- */
- if (userId == null) {
- throw buildAiException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "OPTIONS_BUILD", "褰撳墠鐧诲綍鐢ㄦ埛涓嶅瓨鍦�", null);
- }
- OpenAiChatOptions.Builder builder = OpenAiChatOptions.builder()
- .model(aiParam.getModel())
- .temperature(aiParam.getTemperature())
- .topP(aiParam.getTopP())
- .maxTokens(aiParam.getMaxTokens())
- .streamUsage(true)
- .user(String.valueOf(userId));
- if (!Cools.isEmpty(toolCallbacks)) {
- builder.toolCallbacks(Arrays.stream(toolCallbacks).toList());
- }
- Map<String, Object> toolContext = new LinkedHashMap<>();
- toolContext.put("userId", userId);
- toolContext.put("tenantId", tenantId);
- toolContext.put("requestId", requestId);
- toolContext.put("sessionId", sessionId);
- Map<String, String> metadataMap = new LinkedHashMap<>();
- if (metadata != null) {
- metadata.forEach((key, value) -> {
- String normalized = value == null ? "" : String.valueOf(value);
- metadataMap.put(key, normalized);
- toolContext.put(key, normalized);
- });
- }
- builder.toolContext(toolContext);
- if (!metadataMap.isEmpty()) {
- builder.metadata(metadataMap);
- }
- return builder.build();
- }
-
- private ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, SseEmitter emitter, String requestId,
- Long sessionId, AtomicLong toolCallSequence,
- AtomicLong toolSuccessCount, AtomicLong toolFailureCount,
- Long callLogId, Long userId, Long tenantId) {
- /** 给所有工具回调套上一层可观测包装,用于实时 SSE 轨迹和审计日志落库。 */
- if (Cools.isEmpty(toolCallbacks)) {
- return toolCallbacks;
- }
- List<ToolCallback> wrappedCallbacks = new ArrayList<>();
- for (ToolCallback callback : toolCallbacks) {
- if (callback == null) {
- continue;
- }
- wrappedCallbacks.add(new ObservableToolCallback(callback, emitter, requestId, sessionId, toolCallSequence,
- toolSuccessCount, toolFailureCount, callLogId, userId, tenantId));
- }
- return wrappedCallbacks.toArray(new ToolCallback[0]);
- }
-
- private List<Message> buildPromptMessages(AiChatMemoryDto memory, List<AiChatMessageDto> sourceMessages, AiPrompt aiPrompt, Map<String, Object> metadata) {
- /**
- * 组装最终提交给模型的消息列表。
- * 顺序上始终是:系统 Prompt -> 历史摘要 -> 关键事实 -> 最近对话 -> 当前用户输入。
- */
- if (Cools.isEmpty(sourceMessages)) {
- throw new CoolException("瀵硅瘽娑堟伅涓嶈兘涓虹┖");
- }
- List<Message> messages = new ArrayList<>();
- if (StringUtils.hasText(aiPrompt.getSystemPrompt())) {
- messages.add(new SystemMessage(aiPrompt.getSystemPrompt()));
- }
- if (memory != null && StringUtils.hasText(memory.getMemorySummary())) {
- messages.add(new SystemMessage("鍘嗗彶鎽樿:\n" + memory.getMemorySummary()));
- }
- if (memory != null && StringUtils.hasText(memory.getMemoryFacts())) {
- messages.add(new SystemMessage("鍏抽敭浜嬪疄:\n" + memory.getMemoryFacts()));
- }
- int lastUserIndex = -1;
- for (int i = 0; i < sourceMessages.size(); i++) {
- AiChatMessageDto item = sourceMessages.get(i);
- if (item != null && "user".equalsIgnoreCase(item.getRole())) {
- lastUserIndex = i;
- }
- }
- for (int i = 0; i < sourceMessages.size(); i++) {
- AiChatMessageDto item = sourceMessages.get(i);
- if (item == null || !StringUtils.hasText(item.getContent())) {
- continue;
- }
- String role = item.getRole() == null ? "user" : item.getRole().toLowerCase();
- if ("system".equals(role)) {
- continue;
- }
- String content = item.getContent();
- if ("user".equals(role) && i == lastUserIndex) {
- content = renderUserPrompt(aiPrompt.getUserPromptTemplate(), content, metadata);
- }
- if ("assistant".equals(role)) {
- messages.add(new AssistantMessage(content));
- } else {
- messages.add(new UserMessage(content));
- }
- }
- if (messages.stream().noneMatch(item -> item instanceof UserMessage)) {
- throw new CoolException("鑷冲皯闇�瑕佷竴鏉$敤鎴锋秷鎭�");
- }
- return messages;
- }
-
- private List<AiChatMessageDto> mergeMessages(List<AiChatMessageDto> persistedMessages, List<AiChatMessageDto> memoryMessages) {
- /** 把落库历史与本轮前端内存增量合并成模型可消费的完整上下文。 */
- List<AiChatMessageDto> merged = new ArrayList<>();
- if (!Cools.isEmpty(persistedMessages)) {
- merged.addAll(persistedMessages);
- }
- if (!Cools.isEmpty(memoryMessages)) {
- merged.addAll(memoryMessages);
- }
- if (merged.isEmpty()) {
- throw new CoolException("瀵硅瘽娑堟伅涓嶈兘涓虹┖");
- }
- return merged;
- }
-
- private String resolveTitleSeed(List<AiChatMessageDto> messages) {
- if (Cools.isEmpty(messages)) {
- throw new CoolException("瀵硅瘽娑堟伅涓嶈兘涓虹┖");
- }
- for (int i = messages.size() - 1; i >= 0; i--) {
- AiChatMessageDto item = messages.get(i);
- if (item != null && "user".equalsIgnoreCase(item.getRole()) && StringUtils.hasText(item.getContent())) {
- return item.getContent();
- }
- }
- throw new CoolException("鑷冲皯闇�瑕佷竴鏉$敤鎴锋秷鎭�");
- }
-
- private String renderUserPrompt(String userPromptTemplate, String content, Map<String, Object> metadata) {
- if (!StringUtils.hasText(userPromptTemplate)) {
- return content;
- }
- String rendered = userPromptTemplate
- .replace("{{input}}", content)
- .replace("{input}", content);
- if (metadata != null) {
- for (Map.Entry<String, Object> entry : metadata.entrySet()) {
- String value = entry.getValue() == null ? "" : String.valueOf(entry.getValue());
- rendered = rendered.replace("{{" + entry.getKey() + "}}", value);
- rendered = rendered.replace("{" + entry.getKey() + "}", value);
- }
- }
- if (Objects.equals(rendered, userPromptTemplate)) {
- return userPromptTemplate + "\n\n" + content;
- }
- return rendered;
- }
-
- private String extractContent(ChatResponse response) {
- if (response == null || response.getResult() == null || response.getResult().getOutput() == null) {
- return null;
- }
- return response.getResult().getOutput().getText();
- }
-
- private String summarizeToolPayload(String content, int maxLength) {
- if (!StringUtils.hasText(content)) {
- return null;
- }
- String normalized = content.trim()
- .replace("\r", " ")
- .replace("\n", " ")
- .replaceAll("\\s+", " ");
- return normalized.length() > maxLength ? normalized.substring(0, maxLength) : normalized;
- }
-
- private void emitDone(SseEmitter emitter, String requestId, ChatResponseMetadata metadata, String fallbackModel, Long sessionId, long startedAt, Long firstTokenAt) {
- /** 输出对话完成事件,统一封装耗时、首包延迟和 token 用量。 */
- Usage usage = metadata == null ? null : metadata.getUsage();
- emitStrict(emitter, "done", AiChatDoneDto.builder()
- .requestId(requestId)
- .sessionId(sessionId)
- .model(metadata != null && StringUtils.hasText(metadata.getModel()) ? metadata.getModel() : fallbackModel)
- .elapsedMs(System.currentTimeMillis() - startedAt)
- .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAt))
- .promptTokens(usage == null ? null : usage.getPromptTokens())
- .completionTokens(usage == null ? null : usage.getCompletionTokens())
- .totalTokens(usage == null ? null : usage.getTotalTokens())
- .build());
- }
-
- private Map<String, String> buildMessagePayload(String... keyValues) {
- Map<String, String> payload = new LinkedHashMap<>();
- if (keyValues == null || keyValues.length == 0) {
- return payload;
- }
- if (keyValues.length % 2 != 0) {
- throw new CoolException("娑堟伅杞借嵎鍙傛暟蹇呴』鎴愬鍑虹幇");
- }
- for (int i = 0; i < keyValues.length; i += 2) {
- payload.put(keyValues[i], keyValues[i + 1] == null ? "" : keyValues[i + 1]);
- }
- return payload;
- }
-
- private void emitStrict(SseEmitter emitter, String eventName, Object payload) {
- /** 严格发送 SSE 事件;一旦发送失败,直接上抛为流式输出异常。 */
- try {
- String data = objectMapper.writeValueAsString(payload);
- emitter.send(SseEmitter.event()
- .name(eventName)
- .data(data, MediaType.APPLICATION_JSON));
- } catch (IOException e) {
- throw buildAiException("AI_SSE_EMIT_ERROR", AiErrorCategory.STREAM, "SSE_EMIT", "SSE 杈撳嚭澶辫触: " + e.getMessage(), e);
- }
- }
-
- private void emitSafely(SseEmitter emitter, String eventName, Object payload) {
- /** 尝试发送非关键事件,发送失败只记录日志,不打断主对话流程。 */
- try {
- emitStrict(emitter, eventName, payload);
- } catch (Exception e) {
- log.warn("AI SSE event emit skipped, eventName={}, message={}", eventName, e.getMessage());
- }
- }
-
- private AiChatException buildAiException(String code, AiErrorCategory category, String stage, String message, Throwable cause) {
- return new AiChatException(code, category, stage, message, cause);
- }
-
- /**
- * Heuristically decides whether a failure looks like a client disconnect by scanning the
- * messages of the throwable and every cause in its chain.
- *
- * @param throwable the failure to inspect; null yields false
- * @return true if any non-null message in the chain, lower-cased, contains "broken pipe",
- *         "connection reset", "forcibly closed" or "abort"; false otherwise
- */
- private boolean isClientAbortException(Throwable throwable) {
- Throwable current = throwable;
- while (current != null) {
- String message = current.getMessage();
- if (message != null) {
- // NOTE(review): toLowerCase() without a Locale uses the JVM default locale, which can
- // change how some letters are lower-cased (e.g. 'I' under a Turkish locale).
- String normalized = message.toLowerCase();
- // NOTE(review): "abort" is a plain substring match, so any message containing it qualifies.
- if (normalized.contains("broken pipe")
- || normalized.contains("connection reset")
- || normalized.contains("forcibly closed")
- || normalized.contains("abort")) {
- return true;
- }
- }
- // Move on to the next cause; the loop ends when the chain runs out.
- current = current.getCause();
- }
- return false;
- }
-
- /**
- * ToolCallback decorator that reports every invocation of the wrapped callback.
- * Around the delegate call it emits tool_start / tool_result / tool_error events through
- * emitSafely, increments the success or failure counter it was given, and records the call
- * through aiCallLogService.saveMcpCallLog. Tool definition and metadata come from the delegate.
- */
- private class ObservableToolCallback implements ToolCallback {
-
- // Wrapped callback that performs the real tool call.
- private final ToolCallback delegate;
- // Emitter that receives the tool_* events.
- private final SseEmitter emitter;
- // Identifiers copied into every event and call-log record.
- private final String requestId;
- private final Long sessionId;
- // Counter used to number tool calls when building the tool call id.
- private final AtomicLong toolCallSequence;
- // Counters incremented after a completed / failed tool call respectively.
- private final AtomicLong toolSuccessCount;
- private final AtomicLong toolFailureCount;
- // Values forwarded unchanged to saveMcpCallLog.
- private final Long callLogId;
- private final Long userId;
- private final Long tenantId;
-
- /** Stores the delegate and the reporting context; no other work is done here. */
- private ObservableToolCallback(ToolCallback delegate, SseEmitter emitter, String requestId,
- Long sessionId, AtomicLong toolCallSequence,
- AtomicLong toolSuccessCount, AtomicLong toolFailureCount,
- Long callLogId, Long userId, Long tenantId) {
- this.delegate = delegate;
- this.emitter = emitter;
- this.requestId = requestId;
- this.sessionId = sessionId;
- this.toolCallSequence = toolCallSequence;
- this.toolSuccessCount = toolSuccessCount;
- this.toolFailureCount = toolFailureCount;
- this.callLogId = callLogId;
- this.userId = userId;
- this.tenantId = tenantId;
- }
-
- /** Returns the delegate's tool definition unchanged. */
- @Override
- public org.springframework.ai.tool.definition.ToolDefinition getToolDefinition() {
- return delegate.getToolDefinition();
- }
-
- /** Returns the delegate's tool metadata unchanged. */
- @Override
- public org.springframework.ai.tool.metadata.ToolMetadata getToolMetadata() {
- return delegate.getToolMetadata();
- }
-
- /** Same as call(toolInput, null): runs the observed call without a tool context. */
- @Override
- public String call(String toolInput) {
- return call(toolInput, null);
- }
-
- @Override
- public String call(String toolInput, ToolContext toolContext) {
- /**
- * Observability wrapper around tool execution.
- * Sends tool_start / tool_result / tool_error before and after the real call,
- * and also writes a summary of the call to the MCP call log (saveMcpCallLog).
- */
- // Falls back to "unknown" when the delegate exposes no tool definition.
- String toolName = delegate.getToolDefinition() == null ? "unknown" : delegate.getToolDefinition().name();
- // A mount name is only available when the delegate is a MountedToolCallback; otherwise null.
- String mountName = delegate instanceof MountedToolCallback ? ((MountedToolCallback) delegate).getMountName() : null;
- // Tool call id = request id + "-tool-" + next value of the shared sequence counter.
- String toolCallId = requestId + "-tool-" + toolCallSequence.incrementAndGet();
- long startedAt = System.currentTimeMillis();
- // presumably 400 / 600 are length limits applied by summarizeToolPayload - TODO confirm
- emitSafely(emitter, "tool_start", AiChatToolEventDto.builder()
- .requestId(requestId)
- .sessionId(sessionId)
- .toolCallId(toolCallId)
- .toolName(toolName)
- .mountName(mountName)
- .status("STARTED")
- .inputSummary(summarizeToolPayload(toolInput, 400))
- .timestamp(startedAt)
- .build());
- try {
- // Without a tool context the single-argument overload of the delegate is used.
- String output = toolContext == null ? delegate.call(toolInput) : delegate.call(toolInput, toolContext);
- long durationMs = System.currentTimeMillis() - startedAt;
- emitSafely(emitter, "tool_result", AiChatToolEventDto.builder()
- .requestId(requestId)
- .sessionId(sessionId)
- .toolCallId(toolCallId)
- .toolName(toolName)
- .mountName(mountName)
- .status("COMPLETED")
- .inputSummary(summarizeToolPayload(toolInput, 400))
- .outputSummary(summarizeToolPayload(output, 600))
- .durationMs(durationMs)
- .timestamp(System.currentTimeMillis())
- .build());
- toolSuccessCount.incrementAndGet();
- // NOTE(review): this bookkeeping runs inside the try block, so a RuntimeException thrown here
- // after the delegate succeeded is handled by the catch below and reported as a tool failure.
- aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
- "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(output, 600),
- null, durationMs, userId, tenantId);
- return output;
- } catch (RuntimeException e) {
- // Only RuntimeException is observed here; it is reported, logged and then rethrown unchanged.
- long durationMs = System.currentTimeMillis() - startedAt;
- emitSafely(emitter, "tool_error", AiChatToolEventDto.builder()
- .requestId(requestId)
- .sessionId(sessionId)
- .toolCallId(toolCallId)
- .toolName(toolName)
- .mountName(mountName)
- .status("FAILED")
- .inputSummary(summarizeToolPayload(toolInput, 400))
- .errorMessage(e.getMessage())
- .durationMs(durationMs)
- .timestamp(System.currentTimeMillis())
- .build());
- toolFailureCount.incrementAndGet();
- aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
- "FAILED", summarizeToolPayload(toolInput, 400), null, e.getMessage(),
- durationMs, userId, tenantId);
- throw e;
- }
- }
 }
}
--
Gitblit v1.9.1