From 82624affb0251b75b62b35567d3eb260c06efe78 Mon Sep 17 00:00:00 2001
From: zhou zhou <3272660260@qq.com>
Date: 星期一, 23 三月 2026 12:48:07 +0800
Subject: [PATCH] #ai 代码优化
---
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java | 360 +++++++++--------------------------------------------------
1 files changed, 58 insertions(+), 302 deletions(-)
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
index 5cb9552..1d6da64 100644
--- a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
@@ -1,96 +1,81 @@
package com.vincent.rsf.server.ai.service.impl;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.vincent.rsf.framework.common.Cools;
-import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.config.AiDefaults;
-import com.vincent.rsf.server.ai.dto.AiChatDoneDto;
import com.vincent.rsf.server.ai.dto.AiChatMemoryDto;
-import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
+import com.vincent.rsf.server.ai.dto.AiChatModelOptionDto;
import com.vincent.rsf.server.ai.dto.AiChatRequest;
import com.vincent.rsf.server.ai.dto.AiChatRuntimeDto;
import com.vincent.rsf.server.ai.dto.AiChatSessionDto;
+import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest;
+import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest;
import com.vincent.rsf.server.ai.dto.AiResolvedConfig;
-import com.vincent.rsf.server.ai.entity.AiParam;
-import com.vincent.rsf.server.ai.entity.AiPrompt;
-import com.vincent.rsf.server.ai.entity.AiChatSession;
-import com.vincent.rsf.server.ai.service.AiChatService;
import com.vincent.rsf.server.ai.service.AiChatMemoryService;
+import com.vincent.rsf.server.ai.service.AiChatService;
import com.vincent.rsf.server.ai.service.AiConfigResolverService;
-import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
-import io.micrometer.observation.ObservationRegistry;
+import com.vincent.rsf.server.ai.service.AiParamService;
+import com.vincent.rsf.server.ai.service.impl.chat.AiChatOrchestrator;
+import com.vincent.rsf.server.ai.service.impl.chat.AiChatRuntimeAssembler;
+import com.vincent.rsf.server.ai.store.AiConversationCacheStore;
import lombok.RequiredArgsConstructor;
-import lombok.extern.slf4j.Slf4j;
-import org.springframework.ai.chat.messages.AssistantMessage;
-import org.springframework.ai.chat.messages.Message;
-import org.springframework.ai.chat.messages.SystemMessage;
-import org.springframework.ai.chat.messages.UserMessage;
-import org.springframework.ai.chat.metadata.ChatResponseMetadata;
-import org.springframework.ai.chat.metadata.Usage;
-import org.springframework.ai.chat.model.ChatResponse;
-import org.springframework.ai.chat.prompt.Prompt;
-import org.springframework.ai.model.tool.DefaultToolCallingManager;
-import org.springframework.ai.model.tool.ToolCallingManager;
-import org.springframework.ai.openai.OpenAiChatModel;
-import org.springframework.ai.openai.OpenAiChatOptions;
-import org.springframework.ai.openai.api.OpenAiApi;
-import org.springframework.ai.tool.ToolCallback;
-import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
-import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver;
-import org.springframework.ai.util.json.schema.SchemaType;
-import org.springframework.context.support.GenericApplicationContext;
-import org.springframework.http.MediaType;
-import org.springframework.http.client.SimpleClientHttpRequestFactory;
+import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
-import org.springframework.util.StringUtils;
-import org.springframework.web.client.RestClient;
-import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
-import reactor.core.publisher.Flux;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.LinkedHashMap;
import java.util.List;
-import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.atomic.AtomicReference;
+import java.util.concurrent.Executor;
-@Slf4j
@Service
@RequiredArgsConstructor
public class AiChatServiceImpl implements AiChatService {
private final AiConfigResolverService aiConfigResolverService;
private final AiChatMemoryService aiChatMemoryService;
- private final McpMountRuntimeFactory mcpMountRuntimeFactory;
- private final GenericApplicationContext applicationContext;
- private final ObservationRegistry observationRegistry;
- private final ObjectMapper objectMapper;
+ private final AiParamService aiParamService;
+ private final AiConversationCacheStore aiConversationCacheStore;
+ private final AiChatRuntimeAssembler aiChatRuntimeAssembler;
+ private final AiChatOrchestrator aiChatOrchestrator;
+ @Qualifier("aiChatTaskExecutor")
+ private final Executor aiChatTaskExecutor;
@Override
- public AiChatRuntimeDto getRuntime(String promptCode, Long sessionId, Long userId, Long tenantId) {
- AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId);
+ public AiChatRuntimeDto getRuntime(String promptCode, Long sessionId, Long aiParamId, Long userId, Long tenantId) {
+ AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId, aiParamId);
+ AiChatRuntimeDto cached = aiConversationCacheStore.getRuntime(tenantId, userId, config.getPromptCode(), sessionId, aiParamId);
+ if (cached != null) {
+ return cached;
+ }
AiChatMemoryDto memory = aiChatMemoryService.getMemory(userId, tenantId, config.getPromptCode(), sessionId);
- return AiChatRuntimeDto.builder()
- .sessionId(memory.getSessionId())
- .promptCode(config.getPromptCode())
- .promptName(config.getPrompt().getName())
- .model(config.getAiParam().getModel())
- .configuredMcpCount(config.getMcpMounts().size())
- .mountedMcpCount(config.getMcpMounts().size())
- .mountedMcpNames(config.getMcpMounts().stream().map(item -> item.getName()).toList())
- .mountErrors(List.of())
- .persistedMessages(memory.getPersistedMessages())
- .build();
+ List<AiChatModelOptionDto> modelOptions = aiParamService.listChatModelOptions(tenantId);
+ AiChatRuntimeDto runtime = aiChatRuntimeAssembler.buildRuntimeSnapshot(
+ null,
+ memory.getSessionId(),
+ config,
+ modelOptions,
+ config.getMcpMounts().size(),
+ config.getMcpMounts().stream().map(item -> item.getName()).toList(),
+ List.of(),
+ memory
+ );
+ aiConversationCacheStore.cacheRuntime(tenantId, userId, config.getPromptCode(), sessionId, aiParamId, runtime);
+ if (memory.getSessionId() != null && !Objects.equals(memory.getSessionId(), sessionId)) {
+ aiConversationCacheStore.cacheRuntime(tenantId, userId, config.getPromptCode(), memory.getSessionId(), aiParamId, runtime);
+ }
+ return runtime;
}
@Override
- public List<AiChatSessionDto> listSessions(String promptCode, Long userId, Long tenantId) {
+ public List<AiChatSessionDto> listSessions(String promptCode, String keyword, Long userId, Long tenantId) {
AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId);
- return aiChatMemoryService.listSessions(userId, tenantId, config.getPromptCode());
+ return aiChatMemoryService.listSessions(userId, tenantId, config.getPromptCode(), keyword);
+ }
+
+ @Override
+ public SseEmitter stream(AiChatRequest request, Long userId, Long tenantId) {
+ SseEmitter emitter = new SseEmitter(AiDefaults.SSE_TIMEOUT_MS);
+ CompletableFuture.runAsync(() -> aiChatOrchestrator.executeStream(request, userId, tenantId, emitter), aiChatTaskExecutor);
+ return emitter;
}
@Override
@@ -99,251 +84,22 @@
}
@Override
- public SseEmitter stream(AiChatRequest request, Long userId, Long tenantId) {
- SseEmitter emitter = new SseEmitter(AiDefaults.SSE_TIMEOUT_MS);
- CompletableFuture.runAsync(() -> doStream(request, userId, tenantId, emitter));
- return emitter;
+ public AiChatSessionDto renameSession(Long sessionId, AiChatSessionRenameRequest request, Long userId, Long tenantId) {
+ return aiChatMemoryService.renameSession(userId, tenantId, sessionId, request);
}
- private void doStream(AiChatRequest request, Long userId, Long tenantId, SseEmitter emitter) {
- try {
- AiResolvedConfig config = aiConfigResolverService.resolve(request.getPromptCode(), tenantId);
- AiChatSession session = aiChatMemoryService.resolveSession(userId, tenantId, config.getPromptCode(), request.getSessionId(), resolveTitleSeed(request.getMessages()));
- AiChatMemoryDto memory = aiChatMemoryService.getMemory(userId, tenantId, config.getPromptCode(), session.getId());
- List<AiChatMessageDto> mergedMessages = mergeMessages(memory.getPersistedMessages(), request.getMessages());
- try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(config.getMcpMounts(), userId)) {
- emit(emitter, "start", AiChatRuntimeDto.builder()
- .sessionId(session.getId())
- .promptCode(config.getPromptCode())
- .promptName(config.getPrompt().getName())
- .model(config.getAiParam().getModel())
- .configuredMcpCount(config.getMcpMounts().size())
- .mountedMcpCount(runtime.getMountedCount())
- .mountedMcpNames(runtime.getMountedNames())
- .mountErrors(runtime.getErrors())
- .persistedMessages(memory.getPersistedMessages())
- .build());
-
- Prompt prompt = new Prompt(
- buildPromptMessages(mergedMessages, config.getPrompt(), request.getMetadata()),
- buildChatOptions(config.getAiParam(), runtime.getToolCallbacks(), userId, request.getMetadata())
- );
- OpenAiChatModel chatModel = createChatModel(config.getAiParam());
- if (Boolean.FALSE.equals(config.getAiParam().getStreamingEnabled())) {
- ChatResponse response = chatModel.call(prompt);
- String content = extractContent(response);
- aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), content);
- if (StringUtils.hasText(content)) {
- emit(emitter, "delta", buildMessagePayload("content", content));
- }
- emitDone(emitter, response.getMetadata(), config.getAiParam().getModel(), session.getId());
- emitter.complete();
- return;
- }
-
- Flux<ChatResponse> responseFlux = chatModel.stream(prompt);
- AtomicReference<ChatResponseMetadata> lastMetadata = new AtomicReference<>();
- StringBuilder assistantContent = new StringBuilder();
- responseFlux.doOnNext(response -> {
- lastMetadata.set(response.getMetadata());
- String content = extractContent(response);
- if (StringUtils.hasText(content)) {
- assistantContent.append(content);
- emit(emitter, "delta", buildMessagePayload("content", content));
- }
- })
- .doOnError(error -> emit(emitter, "error", buildMessagePayload("message", error == null ? "AI 瀵硅瘽澶辫触" : error.getMessage())))
- .blockLast();
- aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), assistantContent.toString());
- emitDone(emitter, lastMetadata.get(), config.getAiParam().getModel(), session.getId());
- emitter.complete();
- }
- } catch (Exception e) {
- log.error("AI stream error", e);
- emit(emitter, "error", buildMessagePayload("message", e == null ? "AI 瀵硅瘽澶辫触" : e.getMessage()));
- emitter.completeWithError(e);
- }
+ @Override
+ public AiChatSessionDto pinSession(Long sessionId, AiChatSessionPinRequest request, Long userId, Long tenantId) {
+ return aiChatMemoryService.pinSession(userId, tenantId, sessionId, request);
}
- private OpenAiChatModel createChatModel(AiParam aiParam) {
- OpenAiApi openAiApi = buildOpenAiApi(aiParam);
- ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder()
- .observationRegistry(observationRegistry)
- .toolCallbackResolver(new SpringBeanToolCallbackResolver(applicationContext, SchemaType.OPEN_API_SCHEMA))
- .toolExecutionExceptionProcessor(new DefaultToolExecutionExceptionProcessor(false))
- .build();
- return new OpenAiChatModel(
- openAiApi,
- OpenAiChatOptions.builder()
- .model(aiParam.getModel())
- .temperature(aiParam.getTemperature())
- .topP(aiParam.getTopP())
- .maxTokens(aiParam.getMaxTokens())
- .streamUsage(true)
- .build(),
- toolCallingManager,
- org.springframework.retry.support.RetryTemplate.builder().maxAttempts(1).build(),
- observationRegistry
- );
+ @Override
+ public void clearSessionMemory(Long sessionId, Long userId, Long tenantId) {
+ aiChatMemoryService.clearSessionMemory(userId, tenantId, sessionId);
}
- private OpenAiApi buildOpenAiApi(AiParam aiParam) {
- int timeoutMs = aiParam.getTimeoutMs() == null ? AiDefaults.DEFAULT_TIMEOUT_MS : aiParam.getTimeoutMs();
- SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
- requestFactory.setConnectTimeout(timeoutMs);
- requestFactory.setReadTimeout(timeoutMs);
-
- return OpenAiApi.builder()
- .baseUrl(aiParam.getBaseUrl())
- .apiKey(aiParam.getApiKey())
- .restClientBuilder(RestClient.builder().requestFactory(requestFactory))
- .webClientBuilder(WebClient.builder())
- .build();
- }
-
- private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Map<String, Object> metadata) {
- if (userId == null) {
- throw new CoolException("褰撳墠鐧诲綍鐢ㄦ埛涓嶅瓨鍦�");
- }
- OpenAiChatOptions.Builder builder = OpenAiChatOptions.builder()
- .model(aiParam.getModel())
- .temperature(aiParam.getTemperature())
- .topP(aiParam.getTopP())
- .maxTokens(aiParam.getMaxTokens())
- .streamUsage(true)
- .user(String.valueOf(userId));
- if (!Cools.isEmpty(toolCallbacks)) {
- builder.toolCallbacks(Arrays.asList(toolCallbacks));
- }
- Map<String, String> metadataMap = new LinkedHashMap<>();
- if (metadata != null) {
- metadata.forEach((key, value) -> metadataMap.put(key, value == null ? "" : String.valueOf(value)));
- }
- if (!metadataMap.isEmpty()) {
- builder.metadata(metadataMap);
- }
- return builder.build();
- }
-
- private List<Message> buildPromptMessages(List<AiChatMessageDto> sourceMessages, AiPrompt aiPrompt, Map<String, Object> metadata) {
- if (Cools.isEmpty(sourceMessages)) {
- throw new CoolException("瀵硅瘽娑堟伅涓嶈兘涓虹┖");
- }
- List<Message> messages = new ArrayList<>();
- if (StringUtils.hasText(aiPrompt.getSystemPrompt())) {
- messages.add(new SystemMessage(aiPrompt.getSystemPrompt()));
- }
- int lastUserIndex = -1;
- for (int i = 0; i < sourceMessages.size(); i++) {
- AiChatMessageDto item = sourceMessages.get(i);
- if (item != null && "user".equalsIgnoreCase(item.getRole())) {
- lastUserIndex = i;
- }
- }
- for (int i = 0; i < sourceMessages.size(); i++) {
- AiChatMessageDto item = sourceMessages.get(i);
- if (item == null || !StringUtils.hasText(item.getContent())) {
- continue;
- }
- String role = item.getRole() == null ? "user" : item.getRole().toLowerCase();
- if ("system".equals(role)) {
- continue;
- }
- String content = item.getContent();
- if ("user".equals(role) && i == lastUserIndex) {
- content = renderUserPrompt(aiPrompt.getUserPromptTemplate(), content, metadata);
- }
- if ("assistant".equals(role)) {
- messages.add(new AssistantMessage(content));
- } else {
- messages.add(new UserMessage(content));
- }
- }
- if (messages.stream().noneMatch(item -> item instanceof UserMessage)) {
- throw new CoolException("鑷冲皯闇�瑕佷竴鏉$敤鎴锋秷鎭�");
- }
- return messages;
- }
-
- private List<AiChatMessageDto> mergeMessages(List<AiChatMessageDto> persistedMessages, List<AiChatMessageDto> memoryMessages) {
- List<AiChatMessageDto> merged = new ArrayList<>();
- if (!Cools.isEmpty(persistedMessages)) {
- merged.addAll(persistedMessages);
- }
- if (!Cools.isEmpty(memoryMessages)) {
- merged.addAll(memoryMessages);
- }
- if (merged.isEmpty()) {
- throw new CoolException("瀵硅瘽娑堟伅涓嶈兘涓虹┖");
- }
- return merged;
- }
-
- private String resolveTitleSeed(List<AiChatMessageDto> messages) {
- if (Cools.isEmpty(messages)) {
- throw new CoolException("瀵硅瘽娑堟伅涓嶈兘涓虹┖");
- }
- for (int i = messages.size() - 1; i >= 0; i--) {
- AiChatMessageDto item = messages.get(i);
- if (item != null && "user".equalsIgnoreCase(item.getRole()) && StringUtils.hasText(item.getContent())) {
- return item.getContent();
- }
- }
- throw new CoolException("鑷冲皯闇�瑕佷竴鏉$敤鎴锋秷鎭�");
- }
-
- private String renderUserPrompt(String userPromptTemplate, String content, Map<String, Object> metadata) {
- if (!StringUtils.hasText(userPromptTemplate)) {
- return content;
- }
- String rendered = userPromptTemplate
- .replace("{{input}}", content)
- .replace("{input}", content);
- if (metadata != null) {
- for (Map.Entry<String, Object> entry : metadata.entrySet()) {
- String value = entry.getValue() == null ? "" : String.valueOf(entry.getValue());
- rendered = rendered.replace("{{" + entry.getKey() + "}}", value);
- rendered = rendered.replace("{" + entry.getKey() + "}", value);
- }
- }
- if (Objects.equals(rendered, userPromptTemplate)) {
- return userPromptTemplate + "\n\n" + content;
- }
- return rendered;
- }
-
- private String extractContent(ChatResponse response) {
- if (response == null || response.getResult() == null || response.getResult().getOutput() == null) {
- return null;
- }
- return response.getResult().getOutput().getText();
- }
-
- private void emitDone(SseEmitter emitter, ChatResponseMetadata metadata, String fallbackModel, Long sessionId) {
- Usage usage = metadata == null ? null : metadata.getUsage();
- emit(emitter, "done", AiChatDoneDto.builder()
- .sessionId(sessionId)
- .model(metadata != null && StringUtils.hasText(metadata.getModel()) ? metadata.getModel() : fallbackModel)
- .promptTokens(usage == null ? null : usage.getPromptTokens())
- .completionTokens(usage == null ? null : usage.getCompletionTokens())
- .totalTokens(usage == null ? null : usage.getTotalTokens())
- .build());
- }
-
- private Map<String, String> buildMessagePayload(String key, String value) {
- Map<String, String> payload = new LinkedHashMap<>();
- payload.put(key, value == null ? "" : value);
- return payload;
- }
-
- private void emit(SseEmitter emitter, String eventName, Object payload) {
- try {
- String data = objectMapper.writeValueAsString(payload);
- emitter.send(SseEmitter.event()
- .name(eventName)
- .data(data, MediaType.APPLICATION_JSON));
- } catch (IOException e) {
- throw new CoolException("SSE 杈撳嚭澶辫触: " + e.getMessage());
- }
+ @Override
+ public void retainLatestRound(Long sessionId, Long userId, Long tenantId) {
+ aiChatMemoryService.retainLatestRound(userId, tenantId, sessionId);
}
}
--
Gitblit v1.9.1