From 82624affb0251b75b62b35567d3eb260c06efe78 Mon Sep 17 00:00:00 2001
From: zhou zhou <3272660260@qq.com>
Date: 星期一, 23 三月 2026 12:48:07 +0800
Subject: [PATCH] #ai 代码优化

---
 rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java |  360 +++++++++--------------------------------------------------
 1 files changed, 58 insertions(+), 302 deletions(-)

diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
index 5cb9552..1d6da64 100644
--- a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
@@ -1,96 +1,81 @@
 package com.vincent.rsf.server.ai.service.impl;
 
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.vincent.rsf.framework.common.Cools;
-import com.vincent.rsf.framework.exception.CoolException;
 import com.vincent.rsf.server.ai.config.AiDefaults;
-import com.vincent.rsf.server.ai.dto.AiChatDoneDto;
 import com.vincent.rsf.server.ai.dto.AiChatMemoryDto;
-import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
+import com.vincent.rsf.server.ai.dto.AiChatModelOptionDto;
 import com.vincent.rsf.server.ai.dto.AiChatRequest;
 import com.vincent.rsf.server.ai.dto.AiChatRuntimeDto;
 import com.vincent.rsf.server.ai.dto.AiChatSessionDto;
+import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest;
+import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest;
 import com.vincent.rsf.server.ai.dto.AiResolvedConfig;
-import com.vincent.rsf.server.ai.entity.AiParam;
-import com.vincent.rsf.server.ai.entity.AiPrompt;
-import com.vincent.rsf.server.ai.entity.AiChatSession;
-import com.vincent.rsf.server.ai.service.AiChatService;
 import com.vincent.rsf.server.ai.service.AiChatMemoryService;
+import com.vincent.rsf.server.ai.service.AiChatService;
 import com.vincent.rsf.server.ai.service.AiConfigResolverService;
-import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
-import io.micrometer.observation.ObservationRegistry;
+import com.vincent.rsf.server.ai.service.AiParamService;
+import com.vincent.rsf.server.ai.service.impl.chat.AiChatOrchestrator;
+import com.vincent.rsf.server.ai.service.impl.chat.AiChatRuntimeAssembler;
+import com.vincent.rsf.server.ai.store.AiConversationCacheStore;
 import lombok.RequiredArgsConstructor;
-import lombok.extern.slf4j.Slf4j;
-import org.springframework.ai.chat.messages.AssistantMessage;
-import org.springframework.ai.chat.messages.Message;
-import org.springframework.ai.chat.messages.SystemMessage;
-import org.springframework.ai.chat.messages.UserMessage;
-import org.springframework.ai.chat.metadata.ChatResponseMetadata;
-import org.springframework.ai.chat.metadata.Usage;
-import org.springframework.ai.chat.model.ChatResponse;
-import org.springframework.ai.chat.prompt.Prompt;
-import org.springframework.ai.model.tool.DefaultToolCallingManager;
-import org.springframework.ai.model.tool.ToolCallingManager;
-import org.springframework.ai.openai.OpenAiChatModel;
-import org.springframework.ai.openai.OpenAiChatOptions;
-import org.springframework.ai.openai.api.OpenAiApi;
-import org.springframework.ai.tool.ToolCallback;
-import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
-import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver;
-import org.springframework.ai.util.json.schema.SchemaType;
-import org.springframework.context.support.GenericApplicationContext;
-import org.springframework.http.MediaType;
-import org.springframework.http.client.SimpleClientHttpRequestFactory;
+import org.springframework.beans.factory.annotation.Qualifier;
 import org.springframework.stereotype.Service;
-import org.springframework.util.StringUtils;
-import org.springframework.web.client.RestClient;
-import org.springframework.web.reactive.function.client.WebClient;
 import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
-import reactor.core.publisher.Flux;
 
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.LinkedHashMap;
 import java.util.List;
-import java.util.Map;
 import java.util.Objects;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.atomic.AtomicReference;
+import java.util.concurrent.Executor;
 
-@Slf4j
 @Service
 @RequiredArgsConstructor
 public class AiChatServiceImpl implements AiChatService {
 
     private final AiConfigResolverService aiConfigResolverService;
     private final AiChatMemoryService aiChatMemoryService;
-    private final McpMountRuntimeFactory mcpMountRuntimeFactory;
-    private final GenericApplicationContext applicationContext;
-    private final ObservationRegistry observationRegistry;
-    private final ObjectMapper objectMapper;
+    private final AiParamService aiParamService;
+    private final AiConversationCacheStore aiConversationCacheStore;
+    private final AiChatRuntimeAssembler aiChatRuntimeAssembler;
+    private final AiChatOrchestrator aiChatOrchestrator;
+    @Qualifier("aiChatTaskExecutor")
+    private final Executor aiChatTaskExecutor;
 
     @Override
-    public AiChatRuntimeDto getRuntime(String promptCode, Long sessionId, Long userId, Long tenantId) {
-        AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId);
+    public AiChatRuntimeDto getRuntime(String promptCode, Long sessionId, Long aiParamId, Long userId, Long tenantId) {
+        AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId, aiParamId);
+        AiChatRuntimeDto cached = aiConversationCacheStore.getRuntime(tenantId, userId, config.getPromptCode(), sessionId, aiParamId);
+        if (cached != null) {
+            return cached;
+        }
         AiChatMemoryDto memory = aiChatMemoryService.getMemory(userId, tenantId, config.getPromptCode(), sessionId);
-        return AiChatRuntimeDto.builder()
-                .sessionId(memory.getSessionId())
-                .promptCode(config.getPromptCode())
-                .promptName(config.getPrompt().getName())
-                .model(config.getAiParam().getModel())
-                .configuredMcpCount(config.getMcpMounts().size())
-                .mountedMcpCount(config.getMcpMounts().size())
-                .mountedMcpNames(config.getMcpMounts().stream().map(item -> item.getName()).toList())
-                .mountErrors(List.of())
-                .persistedMessages(memory.getPersistedMessages())
-                .build();
+        List<AiChatModelOptionDto> modelOptions = aiParamService.listChatModelOptions(tenantId);
+        AiChatRuntimeDto runtime = aiChatRuntimeAssembler.buildRuntimeSnapshot(
+                null,
+                memory.getSessionId(),
+                config,
+                modelOptions,
+                config.getMcpMounts().size(),
+                config.getMcpMounts().stream().map(item -> item.getName()).toList(),
+                List.of(),
+                memory
+        );
+        aiConversationCacheStore.cacheRuntime(tenantId, userId, config.getPromptCode(), sessionId, aiParamId, runtime);
+        if (memory.getSessionId() != null && !Objects.equals(memory.getSessionId(), sessionId)) {
+            aiConversationCacheStore.cacheRuntime(tenantId, userId, config.getPromptCode(), memory.getSessionId(), aiParamId, runtime);
+        }
+        return runtime;
     }
 
     @Override
-    public List<AiChatSessionDto> listSessions(String promptCode, Long userId, Long tenantId) {
+    public List<AiChatSessionDto> listSessions(String promptCode, String keyword, Long userId, Long tenantId) {
         AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId);
-        return aiChatMemoryService.listSessions(userId, tenantId, config.getPromptCode());
+        return aiChatMemoryService.listSessions(userId, tenantId, config.getPromptCode(), keyword);
+    }
+
+    @Override
+    public SseEmitter stream(AiChatRequest request, Long userId, Long tenantId) {
+        SseEmitter emitter = new SseEmitter(AiDefaults.SSE_TIMEOUT_MS);
+        CompletableFuture.runAsync(() -> aiChatOrchestrator.executeStream(request, userId, tenantId, emitter), aiChatTaskExecutor);
+        return emitter;
     }
 
     @Override
@@ -99,251 +84,22 @@
     }
 
     @Override
-    public SseEmitter stream(AiChatRequest request, Long userId, Long tenantId) {
-        SseEmitter emitter = new SseEmitter(AiDefaults.SSE_TIMEOUT_MS);
-        CompletableFuture.runAsync(() -> doStream(request, userId, tenantId, emitter));
-        return emitter;
+    public AiChatSessionDto renameSession(Long sessionId, AiChatSessionRenameRequest request, Long userId, Long tenantId) {
+        return aiChatMemoryService.renameSession(userId, tenantId, sessionId, request);
     }
 
-    private void doStream(AiChatRequest request, Long userId, Long tenantId, SseEmitter emitter) {
-        try {
-            AiResolvedConfig config = aiConfigResolverService.resolve(request.getPromptCode(), tenantId);
-            AiChatSession session = aiChatMemoryService.resolveSession(userId, tenantId, config.getPromptCode(), request.getSessionId(), resolveTitleSeed(request.getMessages()));
-            AiChatMemoryDto memory = aiChatMemoryService.getMemory(userId, tenantId, config.getPromptCode(), session.getId());
-            List<AiChatMessageDto> mergedMessages = mergeMessages(memory.getPersistedMessages(), request.getMessages());
-            try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(config.getMcpMounts(), userId)) {
-                emit(emitter, "start", AiChatRuntimeDto.builder()
-                        .sessionId(session.getId())
-                        .promptCode(config.getPromptCode())
-                        .promptName(config.getPrompt().getName())
-                        .model(config.getAiParam().getModel())
-                        .configuredMcpCount(config.getMcpMounts().size())
-                        .mountedMcpCount(runtime.getMountedCount())
-                        .mountedMcpNames(runtime.getMountedNames())
-                        .mountErrors(runtime.getErrors())
-                        .persistedMessages(memory.getPersistedMessages())
-                        .build());
-
-                Prompt prompt = new Prompt(
-                        buildPromptMessages(mergedMessages, config.getPrompt(), request.getMetadata()),
-                        buildChatOptions(config.getAiParam(), runtime.getToolCallbacks(), userId, request.getMetadata())
-                );
-                OpenAiChatModel chatModel = createChatModel(config.getAiParam());
-                if (Boolean.FALSE.equals(config.getAiParam().getStreamingEnabled())) {
-                    ChatResponse response = chatModel.call(prompt);
-                    String content = extractContent(response);
-                    aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), content);
-                    if (StringUtils.hasText(content)) {
-                        emit(emitter, "delta", buildMessagePayload("content", content));
-                    }
-                    emitDone(emitter, response.getMetadata(), config.getAiParam().getModel(), session.getId());
-                    emitter.complete();
-                    return;
-                }
-
-                Flux<ChatResponse> responseFlux = chatModel.stream(prompt);
-                AtomicReference<ChatResponseMetadata> lastMetadata = new AtomicReference<>();
-                StringBuilder assistantContent = new StringBuilder();
-                responseFlux.doOnNext(response -> {
-                            lastMetadata.set(response.getMetadata());
-                            String content = extractContent(response);
-                            if (StringUtils.hasText(content)) {
-                                assistantContent.append(content);
-                                emit(emitter, "delta", buildMessagePayload("content", content));
-                            }
-                        })
-                        .doOnError(error -> emit(emitter, "error", buildMessagePayload("message", error == null ? "AI 瀵硅瘽澶辫触" : error.getMessage())))
-                        .blockLast();
-                aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), assistantContent.toString());
-                emitDone(emitter, lastMetadata.get(), config.getAiParam().getModel(), session.getId());
-                emitter.complete();
-            }
-        } catch (Exception e) {
-            log.error("AI stream error", e);
-            emit(emitter, "error", buildMessagePayload("message", e == null ? "AI 瀵硅瘽澶辫触" : e.getMessage()));
-            emitter.completeWithError(e);
-        }
+    @Override
+    public AiChatSessionDto pinSession(Long sessionId, AiChatSessionPinRequest request, Long userId, Long tenantId) {
+        return aiChatMemoryService.pinSession(userId, tenantId, sessionId, request);
     }
 
-    private OpenAiChatModel createChatModel(AiParam aiParam) {
-        OpenAiApi openAiApi = buildOpenAiApi(aiParam);
-        ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder()
-                .observationRegistry(observationRegistry)
-                .toolCallbackResolver(new SpringBeanToolCallbackResolver(applicationContext, SchemaType.OPEN_API_SCHEMA))
-                .toolExecutionExceptionProcessor(new DefaultToolExecutionExceptionProcessor(false))
-                .build();
-        return new OpenAiChatModel(
-                openAiApi,
-                OpenAiChatOptions.builder()
-                        .model(aiParam.getModel())
-                        .temperature(aiParam.getTemperature())
-                        .topP(aiParam.getTopP())
-                        .maxTokens(aiParam.getMaxTokens())
-                        .streamUsage(true)
-                        .build(),
-                toolCallingManager,
-                org.springframework.retry.support.RetryTemplate.builder().maxAttempts(1).build(),
-                observationRegistry
-        );
+    @Override
+    public void clearSessionMemory(Long sessionId, Long userId, Long tenantId) {
+        aiChatMemoryService.clearSessionMemory(userId, tenantId, sessionId);
     }
 
-    private OpenAiApi buildOpenAiApi(AiParam aiParam) {
-        int timeoutMs = aiParam.getTimeoutMs() == null ? AiDefaults.DEFAULT_TIMEOUT_MS : aiParam.getTimeoutMs();
-        SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
-        requestFactory.setConnectTimeout(timeoutMs);
-        requestFactory.setReadTimeout(timeoutMs);
-
-        return OpenAiApi.builder()
-                .baseUrl(aiParam.getBaseUrl())
-                .apiKey(aiParam.getApiKey())
-                .restClientBuilder(RestClient.builder().requestFactory(requestFactory))
-                .webClientBuilder(WebClient.builder())
-                .build();
-    }
-
-    private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Map<String, Object> metadata) {
-        if (userId == null) {
-            throw new CoolException("褰撳墠鐧诲綍鐢ㄦ埛涓嶅瓨鍦�");
-        }
-        OpenAiChatOptions.Builder builder = OpenAiChatOptions.builder()
-                .model(aiParam.getModel())
-                .temperature(aiParam.getTemperature())
-                .topP(aiParam.getTopP())
-                .maxTokens(aiParam.getMaxTokens())
-                .streamUsage(true)
-                .user(String.valueOf(userId));
-        if (!Cools.isEmpty(toolCallbacks)) {
-            builder.toolCallbacks(Arrays.asList(toolCallbacks));
-        }
-        Map<String, String> metadataMap = new LinkedHashMap<>();
-        if (metadata != null) {
-            metadata.forEach((key, value) -> metadataMap.put(key, value == null ? "" : String.valueOf(value)));
-        }
-        if (!metadataMap.isEmpty()) {
-            builder.metadata(metadataMap);
-        }
-        return builder.build();
-    }
-
-    private List<Message> buildPromptMessages(List<AiChatMessageDto> sourceMessages, AiPrompt aiPrompt, Map<String, Object> metadata) {
-        if (Cools.isEmpty(sourceMessages)) {
-            throw new CoolException("瀵硅瘽娑堟伅涓嶈兘涓虹┖");
-        }
-        List<Message> messages = new ArrayList<>();
-        if (StringUtils.hasText(aiPrompt.getSystemPrompt())) {
-            messages.add(new SystemMessage(aiPrompt.getSystemPrompt()));
-        }
-        int lastUserIndex = -1;
-        for (int i = 0; i < sourceMessages.size(); i++) {
-            AiChatMessageDto item = sourceMessages.get(i);
-            if (item != null && "user".equalsIgnoreCase(item.getRole())) {
-                lastUserIndex = i;
-            }
-        }
-        for (int i = 0; i < sourceMessages.size(); i++) {
-            AiChatMessageDto item = sourceMessages.get(i);
-            if (item == null || !StringUtils.hasText(item.getContent())) {
-                continue;
-            }
-            String role = item.getRole() == null ? "user" : item.getRole().toLowerCase();
-            if ("system".equals(role)) {
-                continue;
-            }
-            String content = item.getContent();
-            if ("user".equals(role) && i == lastUserIndex) {
-                content = renderUserPrompt(aiPrompt.getUserPromptTemplate(), content, metadata);
-            }
-            if ("assistant".equals(role)) {
-                messages.add(new AssistantMessage(content));
-            } else {
-                messages.add(new UserMessage(content));
-            }
-        }
-        if (messages.stream().noneMatch(item -> item instanceof UserMessage)) {
-            throw new CoolException("鑷冲皯闇�瑕佷竴鏉$敤鎴锋秷鎭�");
-        }
-        return messages;
-    }
-
-    private List<AiChatMessageDto> mergeMessages(List<AiChatMessageDto> persistedMessages, List<AiChatMessageDto> memoryMessages) {
-        List<AiChatMessageDto> merged = new ArrayList<>();
-        if (!Cools.isEmpty(persistedMessages)) {
-            merged.addAll(persistedMessages);
-        }
-        if (!Cools.isEmpty(memoryMessages)) {
-            merged.addAll(memoryMessages);
-        }
-        if (merged.isEmpty()) {
-            throw new CoolException("瀵硅瘽娑堟伅涓嶈兘涓虹┖");
-        }
-        return merged;
-    }
-
-    private String resolveTitleSeed(List<AiChatMessageDto> messages) {
-        if (Cools.isEmpty(messages)) {
-            throw new CoolException("瀵硅瘽娑堟伅涓嶈兘涓虹┖");
-        }
-        for (int i = messages.size() - 1; i >= 0; i--) {
-            AiChatMessageDto item = messages.get(i);
-            if (item != null && "user".equalsIgnoreCase(item.getRole()) && StringUtils.hasText(item.getContent())) {
-                return item.getContent();
-            }
-        }
-        throw new CoolException("鑷冲皯闇�瑕佷竴鏉$敤鎴锋秷鎭�");
-    }
-
-    private String renderUserPrompt(String userPromptTemplate, String content, Map<String, Object> metadata) {
-        if (!StringUtils.hasText(userPromptTemplate)) {
-            return content;
-        }
-        String rendered = userPromptTemplate
-                .replace("{{input}}", content)
-                .replace("{input}", content);
-        if (metadata != null) {
-            for (Map.Entry<String, Object> entry : metadata.entrySet()) {
-                String value = entry.getValue() == null ? "" : String.valueOf(entry.getValue());
-                rendered = rendered.replace("{{" + entry.getKey() + "}}", value);
-                rendered = rendered.replace("{" + entry.getKey() + "}", value);
-            }
-        }
-        if (Objects.equals(rendered, userPromptTemplate)) {
-            return userPromptTemplate + "\n\n" + content;
-        }
-        return rendered;
-    }
-
-    private String extractContent(ChatResponse response) {
-        if (response == null || response.getResult() == null || response.getResult().getOutput() == null) {
-            return null;
-        }
-        return response.getResult().getOutput().getText();
-    }
-
-    private void emitDone(SseEmitter emitter, ChatResponseMetadata metadata, String fallbackModel, Long sessionId) {
-        Usage usage = metadata == null ? null : metadata.getUsage();
-        emit(emitter, "done", AiChatDoneDto.builder()
-                .sessionId(sessionId)
-                .model(metadata != null && StringUtils.hasText(metadata.getModel()) ? metadata.getModel() : fallbackModel)
-                .promptTokens(usage == null ? null : usage.getPromptTokens())
-                .completionTokens(usage == null ? null : usage.getCompletionTokens())
-                .totalTokens(usage == null ? null : usage.getTotalTokens())
-                .build());
-    }
-
-    private Map<String, String> buildMessagePayload(String key, String value) {
-        Map<String, String> payload = new LinkedHashMap<>();
-        payload.put(key, value == null ? "" : value);
-        return payload;
-    }
-
-    private void emit(SseEmitter emitter, String eventName, Object payload) {
-        try {
-            String data = objectMapper.writeValueAsString(payload);
-            emitter.send(SseEmitter.event()
-                    .name(eventName)
-                    .data(data, MediaType.APPLICATION_JSON));
-        } catch (IOException e) {
-            throw new CoolException("SSE 杈撳嚭澶辫触: " + e.getMessage());
-        }
+    @Override
+    public void retainLatestRound(Long sessionId, Long userId, Long tenantId) {
+        aiChatMemoryService.retainLatestRound(userId, tenantId, sessionId);
     }
 }

--
Gitblit v1.9.1