| | |
| | | import com.vincent.rsf.server.ai.dto.AiChatSessionDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest; |
| | | import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest; |
| | | import com.vincent.rsf.server.ai.dto.AiChatToolEventDto; |
| | | import com.vincent.rsf.server.ai.dto.AiResolvedConfig; |
| | | import com.vincent.rsf.server.ai.entity.AiParam; |
| | | import com.vincent.rsf.server.ai.entity.AiPrompt; |
| | |
| | | import org.springframework.ai.chat.metadata.ChatResponseMetadata; |
| | | import org.springframework.ai.chat.metadata.Usage; |
| | | import org.springframework.ai.chat.model.ChatResponse; |
| | | import org.springframework.ai.chat.model.ToolContext; |
| | | import org.springframework.ai.chat.prompt.Prompt; |
| | | import org.springframework.ai.model.tool.DefaultToolCallingManager; |
| | | import org.springframework.ai.model.tool.ToolCallingManager; |
| | |
| | | import java.util.concurrent.CompletableFuture; |
| | | import java.util.concurrent.Executor; |
| | | import java.util.concurrent.atomic.AtomicReference; |
| | | import java.util.concurrent.atomic.AtomicLong; |
| | | |
| | | @Slf4j |
| | | @Service |
| | |
| | | .mountedMcpCount(config.getMcpMounts().size()) |
| | | .mountedMcpNames(config.getMcpMounts().stream().map(item -> item.getName()).toList()) |
| | | .mountErrors(List.of()) |
| | | .memorySummary(memory.getMemorySummary()) |
| | | .memoryFacts(memory.getMemoryFacts()) |
| | | .recentMessageCount(memory.getRecentMessageCount()) |
| | | .persistedMessages(memory.getPersistedMessages()) |
| | | .build(); |
| | | } |
| | |
| | | } |
| | | |
| | | @Override |
| | | public void clearSessionMemory(Long sessionId, Long userId, Long tenantId) { |
| | | aiChatMemoryService.clearSessionMemory(userId, tenantId, sessionId); |
| | | } |
| | | |
| | | @Override |
| | | public void retainLatestRound(Long sessionId, Long userId, Long tenantId) { |
| | | aiChatMemoryService.retainLatestRound(userId, tenantId, sessionId); |
| | | } |
| | | |
| | | @Override |
| | | public SseEmitter stream(AiChatRequest request, Long userId, Long tenantId) { |
| | | SseEmitter emitter = new SseEmitter(AiDefaults.SSE_TIMEOUT_MS); |
| | | CompletableFuture.runAsync(() -> doStream(request, userId, tenantId, emitter), aiChatTaskExecutor); |
| | |
| | | String requestId = request.getRequestId(); |
| | | long startedAt = System.currentTimeMillis(); |
| | | AtomicReference<Long> firstTokenAtRef = new AtomicReference<>(); |
| | | AtomicLong toolCallSequence = new AtomicLong(0); |
| | | Long sessionId = request.getSessionId(); |
| | | String model = null; |
| | | try { |
| | |
| | | AiChatSession session = resolveSession(request, userId, tenantId, config.getPromptCode()); |
| | | sessionId = session.getId(); |
| | | AiChatMemoryDto memory = loadMemory(userId, tenantId, config.getPromptCode(), session.getId()); |
| | | List<AiChatMessageDto> mergedMessages = mergeMessages(memory.getPersistedMessages(), request.getMessages()); |
| | | List<AiChatMessageDto> mergedMessages = mergeMessages(memory.getShortMemoryMessages(), request.getMessages()); |
| | | try (McpMountRuntimeFactory.McpMountRuntime runtime = createRuntime(config, userId)) { |
| | | emitStrict(emitter, "start", AiChatRuntimeDto.builder() |
| | | .requestId(requestId) |
| | |
| | | .mountedMcpCount(runtime.getMountedCount()) |
| | | .mountedMcpNames(runtime.getMountedNames()) |
| | | .mountErrors(runtime.getErrors()) |
| | | .memorySummary(memory.getMemorySummary()) |
| | | .memoryFacts(memory.getMemoryFacts()) |
| | | .recentMessageCount(memory.getRecentMessageCount()) |
| | | .persistedMessages(memory.getPersistedMessages()) |
| | | .build()); |
| | | emitSafely(emitter, "status", AiChatStatusDto.builder() |
| | |
| | | log.info("AI chat started, requestId={}, userId={}, tenantId={}, sessionId={}, model={}", |
| | | requestId, userId, tenantId, session.getId(), resolvedModel); |
| | | |
| | | ToolCallback[] observableToolCallbacks = wrapToolCallbacks( |
| | | runtime.getToolCallbacks(), emitter, requestId, session.getId(), toolCallSequence |
| | | ); |
| | | Prompt prompt = new Prompt( |
| | | buildPromptMessages(mergedMessages, config.getPrompt(), request.getMetadata()), |
| | | buildChatOptions(config.getAiParam(), runtime.getToolCallbacks(), userId, request.getMetadata()) |
| | | buildPromptMessages(memory, mergedMessages, config.getPrompt(), request.getMetadata()), |
| | | buildChatOptions(config.getAiParam(), observableToolCallbacks, userId, tenantId, |
| | | requestId, session.getId(), request.getMetadata()) |
| | | ); |
| | | OpenAiChatModel chatModel = createChatModel(config.getAiParam()); |
| | | if (Boolean.FALSE.equals(config.getAiParam().getStreamingEnabled())) { |
| | |
| | | .build(); |
| | | } |
| | | |
| | | private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Map<String, Object> metadata) { |
| | | private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Long tenantId, |
| | | String requestId, Long sessionId, Map<String, Object> metadata) { |
| | | if (userId == null) { |
| | | throw buildAiException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "OPTIONS_BUILD", "当前登录用户不存在", null); |
| | | } |
| | |
| | | .streamUsage(true) |
| | | .user(String.valueOf(userId)); |
| | | if (!Cools.isEmpty(toolCallbacks)) { |
| | | builder.toolCallbacks(Arrays.asList(toolCallbacks)); |
| | | builder.toolCallbacks(Arrays.stream(toolCallbacks).toList()); |
| | | } |
| | | Map<String, Object> toolContext = new LinkedHashMap<>(); |
| | | toolContext.put("userId", userId); |
| | | toolContext.put("tenantId", tenantId); |
| | | toolContext.put("requestId", requestId); |
| | | toolContext.put("sessionId", sessionId); |
| | | Map<String, String> metadataMap = new LinkedHashMap<>(); |
| | | if (metadata != null) { |
| | | metadata.forEach((key, value) -> metadataMap.put(key, value == null ? "" : String.valueOf(value))); |
| | | metadata.forEach((key, value) -> { |
| | | String normalized = value == null ? "" : String.valueOf(value); |
| | | metadataMap.put(key, normalized); |
| | | toolContext.put(key, normalized); |
| | | }); |
| | | } |
| | | builder.toolContext(toolContext); |
| | | if (!metadataMap.isEmpty()) { |
| | | builder.metadata(metadataMap); |
| | | } |
| | | return builder.build(); |
| | | } |
| | | |
| | | private List<Message> buildPromptMessages(List<AiChatMessageDto> sourceMessages, AiPrompt aiPrompt, Map<String, Object> metadata) { |
| | | private ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, SseEmitter emitter, String requestId, |
| | | Long sessionId, AtomicLong toolCallSequence) { |
| | | if (Cools.isEmpty(toolCallbacks)) { |
| | | return toolCallbacks; |
| | | } |
| | | List<ToolCallback> wrappedCallbacks = new ArrayList<>(); |
| | | for (ToolCallback callback : toolCallbacks) { |
| | | if (callback == null) { |
| | | continue; |
| | | } |
| | | wrappedCallbacks.add(new ObservableToolCallback(callback, emitter, requestId, sessionId, toolCallSequence)); |
| | | } |
| | | return wrappedCallbacks.toArray(new ToolCallback[0]); |
| | | } |
| | | |
| | | private List<Message> buildPromptMessages(AiChatMemoryDto memory, List<AiChatMessageDto> sourceMessages, AiPrompt aiPrompt, Map<String, Object> metadata) { |
| | | if (Cools.isEmpty(sourceMessages)) { |
| | | throw new CoolException("对话消息不能为空"); |
| | | } |
| | | List<Message> messages = new ArrayList<>(); |
| | | if (StringUtils.hasText(aiPrompt.getSystemPrompt())) { |
| | | messages.add(new SystemMessage(aiPrompt.getSystemPrompt())); |
| | | } |
| | | if (memory != null && StringUtils.hasText(memory.getMemorySummary())) { |
| | | messages.add(new SystemMessage("历史摘要:\n" + memory.getMemorySummary())); |
| | | } |
| | | if (memory != null && StringUtils.hasText(memory.getMemoryFacts())) { |
| | | messages.add(new SystemMessage("关键事实:\n" + memory.getMemoryFacts())); |
| | | } |
| | | int lastUserIndex = -1; |
| | | for (int i = 0; i < sourceMessages.size(); i++) { |
| | |
| | | return response.getResult().getOutput().getText(); |
| | | } |
| | | |
| | | private String summarizeToolPayload(String content, int maxLength) { |
| | | if (!StringUtils.hasText(content)) { |
| | | return null; |
| | | } |
| | | String normalized = content.trim() |
| | | .replace("\r", " ") |
| | | .replace("\n", " ") |
| | | .replaceAll("\\s+", " "); |
| | | return normalized.length() > maxLength ? normalized.substring(0, maxLength) : normalized; |
| | | } |
| | | |
| | | private void emitDone(SseEmitter emitter, String requestId, ChatResponseMetadata metadata, String fallbackModel, Long sessionId, long startedAt, Long firstTokenAt) { |
| | | Usage usage = metadata == null ? null : metadata.getUsage(); |
| | | emitStrict(emitter, "done", AiChatDoneDto.builder() |
| | |
| | | } |
| | | return false; |
| | | } |
| | | |
| | | private class ObservableToolCallback implements ToolCallback { |
| | | |
| | | private final ToolCallback delegate; |
| | | private final SseEmitter emitter; |
| | | private final String requestId; |
| | | private final Long sessionId; |
| | | private final AtomicLong toolCallSequence; |
| | | |
| | | private ObservableToolCallback(ToolCallback delegate, SseEmitter emitter, String requestId, |
| | | Long sessionId, AtomicLong toolCallSequence) { |
| | | this.delegate = delegate; |
| | | this.emitter = emitter; |
| | | this.requestId = requestId; |
| | | this.sessionId = sessionId; |
| | | this.toolCallSequence = toolCallSequence; |
| | | } |
| | | |
| | | @Override |
| | | public org.springframework.ai.tool.definition.ToolDefinition getToolDefinition() { |
| | | return delegate.getToolDefinition(); |
| | | } |
| | | |
| | | @Override |
| | | public org.springframework.ai.tool.metadata.ToolMetadata getToolMetadata() { |
| | | return delegate.getToolMetadata(); |
| | | } |
| | | |
| | | @Override |
| | | public String call(String toolInput) { |
| | | return call(toolInput, null); |
| | | } |
| | | |
| | | @Override |
| | | public String call(String toolInput, ToolContext toolContext) { |
| | | String toolName = delegate.getToolDefinition() == null ? "unknown" : delegate.getToolDefinition().name(); |
| | | String toolCallId = requestId + "-tool-" + toolCallSequence.incrementAndGet(); |
| | | long startedAt = System.currentTimeMillis(); |
| | | emitSafely(emitter, "tool_start", AiChatToolEventDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .toolCallId(toolCallId) |
| | | .toolName(toolName) |
| | | .status("STARTED") |
| | | .inputSummary(summarizeToolPayload(toolInput, 400)) |
| | | .timestamp(startedAt) |
| | | .build()); |
| | | try { |
| | | String output = toolContext == null ? delegate.call(toolInput) : delegate.call(toolInput, toolContext); |
| | | emitSafely(emitter, "tool_result", AiChatToolEventDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .toolCallId(toolCallId) |
| | | .toolName(toolName) |
| | | .status("COMPLETED") |
| | | .inputSummary(summarizeToolPayload(toolInput, 400)) |
| | | .outputSummary(summarizeToolPayload(output, 600)) |
| | | .durationMs(System.currentTimeMillis() - startedAt) |
| | | .timestamp(System.currentTimeMillis()) |
| | | .build()); |
| | | return output; |
| | | } catch (RuntimeException e) { |
| | | emitSafely(emitter, "tool_error", AiChatToolEventDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .toolCallId(toolCallId) |
| | | .toolName(toolName) |
| | | .status("FAILED") |
| | | .inputSummary(summarizeToolPayload(toolInput, 400)) |
| | | .errorMessage(e.getMessage()) |
| | | .durationMs(System.currentTimeMillis() - startedAt) |
| | | .timestamp(System.currentTimeMillis()) |
| | | .build()); |
| | | throw e; |
| | | } |
| | | } |
| | | } |
| | | } |