| | |
| | | package com.vincent.rsf.server.ai.service.impl; |
| | | |
| | | import com.fasterxml.jackson.databind.ObjectMapper; |
| | | import com.vincent.rsf.framework.common.Cools; |
| | | import com.vincent.rsf.framework.exception.CoolException; |
| | | import com.vincent.rsf.server.ai.config.AiDefaults; |
| | | import com.vincent.rsf.server.ai.dto.AiChatDoneDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatMemoryDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatMessageDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatModelOptionDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatRequest; |
| | | import com.vincent.rsf.server.ai.dto.AiChatRuntimeDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatSessionDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest; |
| | | import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest; |
| | | import com.vincent.rsf.server.ai.dto.AiResolvedConfig; |
| | | import com.vincent.rsf.server.ai.entity.AiParam; |
| | | import com.vincent.rsf.server.ai.entity.AiPrompt; |
| | | import com.vincent.rsf.server.ai.entity.AiChatSession; |
| | | import com.vincent.rsf.server.ai.service.AiChatService; |
| | | import com.vincent.rsf.server.ai.service.AiChatMemoryService; |
| | | import com.vincent.rsf.server.ai.service.AiChatService; |
| | | import com.vincent.rsf.server.ai.service.AiConfigResolverService; |
| | | import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory; |
| | | import io.micrometer.observation.ObservationRegistry; |
| | | import com.vincent.rsf.server.ai.service.AiParamService; |
| | | import com.vincent.rsf.server.ai.service.impl.chat.AiChatOrchestrator; |
| | | import com.vincent.rsf.server.ai.service.impl.chat.AiChatRuntimeAssembler; |
| | | import com.vincent.rsf.server.ai.store.AiConversationCacheStore; |
| | | import lombok.RequiredArgsConstructor; |
| | | import lombok.extern.slf4j.Slf4j; |
| | | import org.springframework.ai.chat.messages.AssistantMessage; |
| | | import org.springframework.ai.chat.messages.Message; |
| | | import org.springframework.ai.chat.messages.SystemMessage; |
| | | import org.springframework.ai.chat.messages.UserMessage; |
| | | import org.springframework.ai.chat.metadata.ChatResponseMetadata; |
| | | import org.springframework.ai.chat.metadata.Usage; |
| | | import org.springframework.ai.chat.model.ChatResponse; |
| | | import org.springframework.ai.chat.prompt.Prompt; |
| | | import org.springframework.ai.model.tool.DefaultToolCallingManager; |
| | | import org.springframework.ai.model.tool.ToolCallingManager; |
| | | import org.springframework.ai.openai.OpenAiChatModel; |
| | | import org.springframework.ai.openai.OpenAiChatOptions; |
| | | import org.springframework.ai.openai.api.OpenAiApi; |
| | | import org.springframework.ai.tool.ToolCallback; |
| | | import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; |
| | | import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver; |
| | | import org.springframework.ai.util.json.schema.SchemaType; |
| | | import org.springframework.context.support.GenericApplicationContext; |
| | | import org.springframework.http.MediaType; |
| | | import org.springframework.http.client.SimpleClientHttpRequestFactory; |
| | | import org.springframework.beans.factory.annotation.Qualifier; |
| | | import org.springframework.stereotype.Service; |
| | | import org.springframework.util.StringUtils; |
| | | import org.springframework.web.client.RestClient; |
| | | import org.springframework.web.reactive.function.client.WebClient; |
| | | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; |
| | | import reactor.core.publisher.Flux; |
| | | |
| | | import java.io.IOException; |
| | | import java.util.ArrayList; |
| | | import java.util.Arrays; |
| | | import java.util.LinkedHashMap; |
| | | import java.util.List; |
| | | import java.util.Map; |
| | | import java.util.Objects; |
| | | import java.util.concurrent.CompletableFuture; |
| | | import java.util.concurrent.atomic.AtomicReference; |
| | | import java.util.concurrent.Executor; |
| | | |
| | | @Slf4j |
| | | @Service |
| | | @RequiredArgsConstructor |
| | | public class AiChatServiceImpl implements AiChatService { |
| | | |
| | | private final AiConfigResolverService aiConfigResolverService; |
| | | private final AiChatMemoryService aiChatMemoryService; |
| | | private final McpMountRuntimeFactory mcpMountRuntimeFactory; |
| | | private final GenericApplicationContext applicationContext; |
| | | private final ObservationRegistry observationRegistry; |
| | | private final ObjectMapper objectMapper; |
| | | private final AiParamService aiParamService; |
| | | private final AiConversationCacheStore aiConversationCacheStore; |
| | | private final AiChatRuntimeAssembler aiChatRuntimeAssembler; |
| | | private final AiChatOrchestrator aiChatOrchestrator; |
| | | @Qualifier("aiChatTaskExecutor") |
| | | private final Executor aiChatTaskExecutor; |
| | | |
| | | @Override |
| | | public AiChatRuntimeDto getRuntime(String promptCode, Long sessionId, Long userId, Long tenantId) { |
| | | AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId); |
| | | public AiChatRuntimeDto getRuntime(String promptCode, Long sessionId, Long aiParamId, Long userId, Long tenantId) { |
| | | AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId, aiParamId); |
| | | AiChatRuntimeDto cached = aiConversationCacheStore.getRuntime(tenantId, userId, config.getPromptCode(), sessionId, aiParamId); |
| | | if (cached != null) { |
| | | return cached; |
| | | } |
| | | AiChatMemoryDto memory = aiChatMemoryService.getMemory(userId, tenantId, config.getPromptCode(), sessionId); |
| | | return AiChatRuntimeDto.builder() |
| | | .sessionId(memory.getSessionId()) |
| | | .promptCode(config.getPromptCode()) |
| | | .promptName(config.getPrompt().getName()) |
| | | .model(config.getAiParam().getModel()) |
| | | .configuredMcpCount(config.getMcpMounts().size()) |
| | | .mountedMcpCount(config.getMcpMounts().size()) |
| | | .mountedMcpNames(config.getMcpMounts().stream().map(item -> item.getName()).toList()) |
| | | .mountErrors(List.of()) |
| | | .persistedMessages(memory.getPersistedMessages()) |
| | | .build(); |
| | | List<AiChatModelOptionDto> modelOptions = aiParamService.listChatModelOptions(tenantId); |
| | | AiChatRuntimeDto runtime = aiChatRuntimeAssembler.buildRuntimeSnapshot( |
| | | null, |
| | | memory.getSessionId(), |
| | | config, |
| | | modelOptions, |
| | | config.getMcpMounts().size(), |
| | | config.getMcpMounts().stream().map(item -> item.getName()).toList(), |
| | | List.of(), |
| | | memory |
| | | ); |
| | | aiConversationCacheStore.cacheRuntime(tenantId, userId, config.getPromptCode(), sessionId, aiParamId, runtime); |
| | | if (memory.getSessionId() != null && !Objects.equals(memory.getSessionId(), sessionId)) { |
| | | aiConversationCacheStore.cacheRuntime(tenantId, userId, config.getPromptCode(), memory.getSessionId(), aiParamId, runtime); |
| | | } |
| | | return runtime; |
| | | } |
| | | |
| | | @Override |
| | | public List<AiChatSessionDto> listSessions(String promptCode, Long userId, Long tenantId) { |
| | | public List<AiChatSessionDto> listSessions(String promptCode, String keyword, Long userId, Long tenantId) { |
| | | AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId); |
| | | return aiChatMemoryService.listSessions(userId, tenantId, config.getPromptCode()); |
| | | return aiChatMemoryService.listSessions(userId, tenantId, config.getPromptCode(), keyword); |
| | | } |
| | | |
| | | @Override |
| | | public SseEmitter stream(AiChatRequest request, Long userId, Long tenantId) { |
| | | SseEmitter emitter = new SseEmitter(AiDefaults.SSE_TIMEOUT_MS); |
| | | CompletableFuture.runAsync(() -> aiChatOrchestrator.executeStream(request, userId, tenantId, emitter), aiChatTaskExecutor); |
| | | return emitter; |
| | | } |
| | | |
| | | @Override |
| | |
| | | } |
| | | |
| | | @Override |
| | | public SseEmitter stream(AiChatRequest request, Long userId, Long tenantId) { |
| | | SseEmitter emitter = new SseEmitter(AiDefaults.SSE_TIMEOUT_MS); |
| | | CompletableFuture.runAsync(() -> doStream(request, userId, tenantId, emitter)); |
| | | return emitter; |
| | | public AiChatSessionDto renameSession(Long sessionId, AiChatSessionRenameRequest request, Long userId, Long tenantId) { |
| | | return aiChatMemoryService.renameSession(userId, tenantId, sessionId, request); |
| | | } |
| | | |
| | | private void doStream(AiChatRequest request, Long userId, Long tenantId, SseEmitter emitter) { |
| | | try { |
| | | AiResolvedConfig config = aiConfigResolverService.resolve(request.getPromptCode(), tenantId); |
| | | AiChatSession session = aiChatMemoryService.resolveSession(userId, tenantId, config.getPromptCode(), request.getSessionId(), resolveTitleSeed(request.getMessages())); |
| | | AiChatMemoryDto memory = aiChatMemoryService.getMemory(userId, tenantId, config.getPromptCode(), session.getId()); |
| | | List<AiChatMessageDto> mergedMessages = mergeMessages(memory.getPersistedMessages(), request.getMessages()); |
| | | try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(config.getMcpMounts(), userId)) { |
| | | emit(emitter, "start", AiChatRuntimeDto.builder() |
| | | .sessionId(session.getId()) |
| | | .promptCode(config.getPromptCode()) |
| | | .promptName(config.getPrompt().getName()) |
| | | .model(config.getAiParam().getModel()) |
| | | .configuredMcpCount(config.getMcpMounts().size()) |
| | | .mountedMcpCount(runtime.getMountedCount()) |
| | | .mountedMcpNames(runtime.getMountedNames()) |
| | | .mountErrors(runtime.getErrors()) |
| | | .persistedMessages(memory.getPersistedMessages()) |
| | | .build()); |
| | | |
| | | Prompt prompt = new Prompt( |
| | | buildPromptMessages(mergedMessages, config.getPrompt(), request.getMetadata()), |
| | | buildChatOptions(config.getAiParam(), runtime.getToolCallbacks(), userId, request.getMetadata()) |
| | | ); |
| | | OpenAiChatModel chatModel = createChatModel(config.getAiParam()); |
| | | if (Boolean.FALSE.equals(config.getAiParam().getStreamingEnabled())) { |
| | | ChatResponse response = chatModel.call(prompt); |
| | | String content = extractContent(response); |
| | | aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), content); |
| | | if (StringUtils.hasText(content)) { |
| | | emit(emitter, "delta", buildMessagePayload("content", content)); |
| | | } |
| | | emitDone(emitter, response.getMetadata(), config.getAiParam().getModel(), session.getId()); |
| | | emitter.complete(); |
| | | return; |
| | | } |
| | | |
| | | Flux<ChatResponse> responseFlux = chatModel.stream(prompt); |
| | | AtomicReference<ChatResponseMetadata> lastMetadata = new AtomicReference<>(); |
| | | StringBuilder assistantContent = new StringBuilder(); |
| | | responseFlux.doOnNext(response -> { |
| | | lastMetadata.set(response.getMetadata()); |
| | | String content = extractContent(response); |
| | | if (StringUtils.hasText(content)) { |
| | | assistantContent.append(content); |
| | | emit(emitter, "delta", buildMessagePayload("content", content)); |
| | | } |
| | | }) |
| | | .doOnError(error -> emit(emitter, "error", buildMessagePayload("message", error == null ? "AI 对话失败" : error.getMessage()))) |
| | | .blockLast(); |
| | | aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), assistantContent.toString()); |
| | | emitDone(emitter, lastMetadata.get(), config.getAiParam().getModel(), session.getId()); |
| | | emitter.complete(); |
| | | } |
| | | } catch (Exception e) { |
| | | log.error("AI stream error", e); |
| | | emit(emitter, "error", buildMessagePayload("message", e == null ? "AI 对话失败" : e.getMessage())); |
| | | emitter.completeWithError(e); |
| | | } |
| | | @Override |
| | | public AiChatSessionDto pinSession(Long sessionId, AiChatSessionPinRequest request, Long userId, Long tenantId) { |
| | | return aiChatMemoryService.pinSession(userId, tenantId, sessionId, request); |
| | | } |
| | | |
| | | private OpenAiChatModel createChatModel(AiParam aiParam) { |
| | | OpenAiApi openAiApi = buildOpenAiApi(aiParam); |
| | | ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder() |
| | | .observationRegistry(observationRegistry) |
| | | .toolCallbackResolver(new SpringBeanToolCallbackResolver(applicationContext, SchemaType.OPEN_API_SCHEMA)) |
| | | .toolExecutionExceptionProcessor(new DefaultToolExecutionExceptionProcessor(false)) |
| | | .build(); |
| | | return new OpenAiChatModel( |
| | | openAiApi, |
| | | OpenAiChatOptions.builder() |
| | | .model(aiParam.getModel()) |
| | | .temperature(aiParam.getTemperature()) |
| | | .topP(aiParam.getTopP()) |
| | | .maxTokens(aiParam.getMaxTokens()) |
| | | .streamUsage(true) |
| | | .build(), |
| | | toolCallingManager, |
| | | org.springframework.retry.support.RetryTemplate.builder().maxAttempts(1).build(), |
| | | observationRegistry |
| | | ); |
| | | @Override |
| | | public void clearSessionMemory(Long sessionId, Long userId, Long tenantId) { |
| | | aiChatMemoryService.clearSessionMemory(userId, tenantId, sessionId); |
| | | } |
| | | |
| | | private OpenAiApi buildOpenAiApi(AiParam aiParam) { |
| | | int timeoutMs = aiParam.getTimeoutMs() == null ? AiDefaults.DEFAULT_TIMEOUT_MS : aiParam.getTimeoutMs(); |
| | | SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); |
| | | requestFactory.setConnectTimeout(timeoutMs); |
| | | requestFactory.setReadTimeout(timeoutMs); |
| | | |
| | | return OpenAiApi.builder() |
| | | .baseUrl(aiParam.getBaseUrl()) |
| | | .apiKey(aiParam.getApiKey()) |
| | | .restClientBuilder(RestClient.builder().requestFactory(requestFactory)) |
| | | .webClientBuilder(WebClient.builder()) |
| | | .build(); |
| | | } |
| | | |
| | | private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Map<String, Object> metadata) { |
| | | if (userId == null) { |
| | | throw new CoolException("当前登录用户不存在"); |
| | | } |
| | | OpenAiChatOptions.Builder builder = OpenAiChatOptions.builder() |
| | | .model(aiParam.getModel()) |
| | | .temperature(aiParam.getTemperature()) |
| | | .topP(aiParam.getTopP()) |
| | | .maxTokens(aiParam.getMaxTokens()) |
| | | .streamUsage(true) |
| | | .user(String.valueOf(userId)); |
| | | if (!Cools.isEmpty(toolCallbacks)) { |
| | | builder.toolCallbacks(Arrays.asList(toolCallbacks)); |
| | | } |
| | | Map<String, String> metadataMap = new LinkedHashMap<>(); |
| | | if (metadata != null) { |
| | | metadata.forEach((key, value) -> metadataMap.put(key, value == null ? "" : String.valueOf(value))); |
| | | } |
| | | if (!metadataMap.isEmpty()) { |
| | | builder.metadata(metadataMap); |
| | | } |
| | | return builder.build(); |
| | | } |
| | | |
| | | private List<Message> buildPromptMessages(List<AiChatMessageDto> sourceMessages, AiPrompt aiPrompt, Map<String, Object> metadata) { |
| | | if (Cools.isEmpty(sourceMessages)) { |
| | | throw new CoolException("对话消息不能为空"); |
| | | } |
| | | List<Message> messages = new ArrayList<>(); |
| | | if (StringUtils.hasText(aiPrompt.getSystemPrompt())) { |
| | | messages.add(new SystemMessage(aiPrompt.getSystemPrompt())); |
| | | } |
| | | int lastUserIndex = -1; |
| | | for (int i = 0; i < sourceMessages.size(); i++) { |
| | | AiChatMessageDto item = sourceMessages.get(i); |
| | | if (item != null && "user".equalsIgnoreCase(item.getRole())) { |
| | | lastUserIndex = i; |
| | | } |
| | | } |
| | | for (int i = 0; i < sourceMessages.size(); i++) { |
| | | AiChatMessageDto item = sourceMessages.get(i); |
| | | if (item == null || !StringUtils.hasText(item.getContent())) { |
| | | continue; |
| | | } |
| | | String role = item.getRole() == null ? "user" : item.getRole().toLowerCase(); |
| | | if ("system".equals(role)) { |
| | | continue; |
| | | } |
| | | String content = item.getContent(); |
| | | if ("user".equals(role) && i == lastUserIndex) { |
| | | content = renderUserPrompt(aiPrompt.getUserPromptTemplate(), content, metadata); |
| | | } |
| | | if ("assistant".equals(role)) { |
| | | messages.add(new AssistantMessage(content)); |
| | | } else { |
| | | messages.add(new UserMessage(content)); |
| | | } |
| | | } |
| | | if (messages.stream().noneMatch(item -> item instanceof UserMessage)) { |
| | | throw new CoolException("至少需要一条用户消息"); |
| | | } |
| | | return messages; |
| | | } |
| | | |
| | | private List<AiChatMessageDto> mergeMessages(List<AiChatMessageDto> persistedMessages, List<AiChatMessageDto> memoryMessages) { |
| | | List<AiChatMessageDto> merged = new ArrayList<>(); |
| | | if (!Cools.isEmpty(persistedMessages)) { |
| | | merged.addAll(persistedMessages); |
| | | } |
| | | if (!Cools.isEmpty(memoryMessages)) { |
| | | merged.addAll(memoryMessages); |
| | | } |
| | | if (merged.isEmpty()) { |
| | | throw new CoolException("对话消息不能为空"); |
| | | } |
| | | return merged; |
| | | } |
| | | |
| | | private String resolveTitleSeed(List<AiChatMessageDto> messages) { |
| | | if (Cools.isEmpty(messages)) { |
| | | throw new CoolException("对话消息不能为空"); |
| | | } |
| | | for (int i = messages.size() - 1; i >= 0; i--) { |
| | | AiChatMessageDto item = messages.get(i); |
| | | if (item != null && "user".equalsIgnoreCase(item.getRole()) && StringUtils.hasText(item.getContent())) { |
| | | return item.getContent(); |
| | | } |
| | | } |
| | | throw new CoolException("至少需要一条用户消息"); |
| | | } |
| | | |
| | | private String renderUserPrompt(String userPromptTemplate, String content, Map<String, Object> metadata) { |
| | | if (!StringUtils.hasText(userPromptTemplate)) { |
| | | return content; |
| | | } |
| | | String rendered = userPromptTemplate |
| | | .replace("{{input}}", content) |
| | | .replace("{input}", content); |
| | | if (metadata != null) { |
| | | for (Map.Entry<String, Object> entry : metadata.entrySet()) { |
| | | String value = entry.getValue() == null ? "" : String.valueOf(entry.getValue()); |
| | | rendered = rendered.replace("{{" + entry.getKey() + "}}", value); |
| | | rendered = rendered.replace("{" + entry.getKey() + "}", value); |
| | | } |
| | | } |
| | | if (Objects.equals(rendered, userPromptTemplate)) { |
| | | return userPromptTemplate + "\n\n" + content; |
| | | } |
| | | return rendered; |
| | | } |
| | | |
| | | private String extractContent(ChatResponse response) { |
| | | if (response == null || response.getResult() == null || response.getResult().getOutput() == null) { |
| | | return null; |
| | | } |
| | | return response.getResult().getOutput().getText(); |
| | | } |
| | | |
| | | private void emitDone(SseEmitter emitter, ChatResponseMetadata metadata, String fallbackModel, Long sessionId) { |
| | | Usage usage = metadata == null ? null : metadata.getUsage(); |
| | | emit(emitter, "done", AiChatDoneDto.builder() |
| | | .sessionId(sessionId) |
| | | .model(metadata != null && StringUtils.hasText(metadata.getModel()) ? metadata.getModel() : fallbackModel) |
| | | .promptTokens(usage == null ? null : usage.getPromptTokens()) |
| | | .completionTokens(usage == null ? null : usage.getCompletionTokens()) |
| | | .totalTokens(usage == null ? null : usage.getTotalTokens()) |
| | | .build()); |
| | | } |
| | | |
| | | private Map<String, String> buildMessagePayload(String key, String value) { |
| | | Map<String, String> payload = new LinkedHashMap<>(); |
| | | payload.put(key, value == null ? "" : value); |
| | | return payload; |
| | | } |
| | | |
| | | private void emit(SseEmitter emitter, String eventName, Object payload) { |
| | | try { |
| | | String data = objectMapper.writeValueAsString(payload); |
| | | emitter.send(SseEmitter.event() |
| | | .name(eventName) |
| | | .data(data, MediaType.APPLICATION_JSON)); |
| | | } catch (IOException e) { |
| | | throw new CoolException("SSE 输出失败: " + e.getMessage()); |
| | | } |
| | | @Override |
| | | public void retainLatestRound(Long sessionId, Long userId, Long tenantId) { |
| | | aiChatMemoryService.retainLatestRound(userId, tenantId, sessionId); |
| | | } |
| | | } |