From b0728aba5c01842e24da3cff04e44be06c6bb655 Mon Sep 17 00:00:00 2001
From: zhou zhou <3272660260@qq.com>
Date: 星期四, 19 三月 2026 13:38:38 +0800
Subject: [PATCH] #AI.去除多余mcp

---
 rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java |  278 +++++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 files changed, 266 insertions(+), 12 deletions(-)

diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
index 05dc09b..8a784ea 100644
--- a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
@@ -12,15 +12,21 @@
 import com.vincent.rsf.server.ai.dto.AiChatRuntimeDto;
 import com.vincent.rsf.server.ai.dto.AiChatStatusDto;
 import com.vincent.rsf.server.ai.dto.AiChatSessionDto;
+import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest;
+import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest;
+import com.vincent.rsf.server.ai.dto.AiChatToolEventDto;
 import com.vincent.rsf.server.ai.dto.AiResolvedConfig;
+import com.vincent.rsf.server.ai.entity.AiCallLog;
 import com.vincent.rsf.server.ai.entity.AiParam;
 import com.vincent.rsf.server.ai.entity.AiPrompt;
 import com.vincent.rsf.server.ai.entity.AiChatSession;
 import com.vincent.rsf.server.ai.enums.AiErrorCategory;
 import com.vincent.rsf.server.ai.exception.AiChatException;
+import com.vincent.rsf.server.ai.service.AiCallLogService;
 import com.vincent.rsf.server.ai.service.AiChatService;
 import com.vincent.rsf.server.ai.service.AiChatMemoryService;
 import com.vincent.rsf.server.ai.service.AiConfigResolverService;
+import com.vincent.rsf.server.ai.service.MountedToolCallback;
 import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
 import io.micrometer.observation.ObservationRegistry;
 import lombok.RequiredArgsConstructor;
@@ -32,6 +38,7 @@
 import org.springframework.ai.chat.metadata.ChatResponseMetadata;
 import org.springframework.ai.chat.metadata.Usage;
 import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.model.ToolContext;
 import org.springframework.ai.chat.prompt.Prompt;
 import org.springframework.ai.model.tool.DefaultToolCallingManager;
 import org.springframework.ai.model.tool.ToolCallingManager;
@@ -64,6 +71,7 @@
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.Executor;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.concurrent.atomic.AtomicLong;
 
 @Slf4j
 @Service
@@ -73,6 +81,7 @@
     private final AiConfigResolverService aiConfigResolverService;
     private final AiChatMemoryService aiChatMemoryService;
     private final McpMountRuntimeFactory mcpMountRuntimeFactory;
+    private final AiCallLogService aiCallLogService;
     private final GenericApplicationContext applicationContext;
     private final ObservationRegistry observationRegistry;
     private final ObjectMapper objectMapper;
@@ -93,19 +102,42 @@
                 .mountedMcpCount(config.getMcpMounts().size())
                 .mountedMcpNames(config.getMcpMounts().stream().map(item -> item.getName()).toList())
                 .mountErrors(List.of())
+                .memorySummary(memory.getMemorySummary())
+                .memoryFacts(memory.getMemoryFacts())
+                .recentMessageCount(memory.getRecentMessageCount())
                 .persistedMessages(memory.getPersistedMessages())
                 .build();
     }
 
     @Override
-    public List<AiChatSessionDto> listSessions(String promptCode, Long userId, Long tenantId) {
+    public List<AiChatSessionDto> listSessions(String promptCode, String keyword, Long userId, Long tenantId) {
         AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId);
-        return aiChatMemoryService.listSessions(userId, tenantId, config.getPromptCode());
+        return aiChatMemoryService.listSessions(userId, tenantId, config.getPromptCode(), keyword);
     }
 
     @Override
     public void removeSession(Long sessionId, Long userId, Long tenantId) {
         aiChatMemoryService.removeSession(userId, tenantId, sessionId);
+    }
+
+    @Override
+    public AiChatSessionDto renameSession(Long sessionId, AiChatSessionRenameRequest request, Long userId, Long tenantId) {
+        return aiChatMemoryService.renameSession(userId, tenantId, sessionId, request);
+    }
+
+    @Override
+    public AiChatSessionDto pinSession(Long sessionId, AiChatSessionPinRequest request, Long userId, Long tenantId) {
+        return aiChatMemoryService.pinSession(userId, tenantId, sessionId, request);
+    }
+
+    @Override
+    public void clearSessionMemory(Long sessionId, Long userId, Long tenantId) {
+        aiChatMemoryService.clearSessionMemory(userId, tenantId, sessionId);
+    }
+
+    @Override
+    public void retainLatestRound(Long sessionId, Long userId, Long tenantId) {
+        aiChatMemoryService.retainLatestRound(userId, tenantId, sessionId);
     }
 
     @Override
@@ -119,7 +151,11 @@
         String requestId = request.getRequestId();
         long startedAt = System.currentTimeMillis();
         AtomicReference<Long> firstTokenAtRef = new AtomicReference<>();
+        AtomicLong toolCallSequence = new AtomicLong(0);
+        AtomicLong toolSuccessCount = new AtomicLong(0);
+        AtomicLong toolFailureCount = new AtomicLong(0);
         Long sessionId = request.getSessionId();
+        Long callLogId = null;
         String model = null;
         try {
             ensureIdentity(userId, tenantId);
@@ -129,7 +165,20 @@
             AiChatSession session = resolveSession(request, userId, tenantId, config.getPromptCode());
             sessionId = session.getId();
             AiChatMemoryDto memory = loadMemory(userId, tenantId, config.getPromptCode(), session.getId());
-            List<AiChatMessageDto> mergedMessages = mergeMessages(memory.getPersistedMessages(), request.getMessages());
+            List<AiChatMessageDto> mergedMessages = mergeMessages(memory.getShortMemoryMessages(), request.getMessages());
+            AiCallLog callLog = aiCallLogService.startCallLog(
+                    requestId,
+                    session.getId(),
+                    userId,
+                    tenantId,
+                    config.getPromptCode(),
+                    config.getPrompt().getName(),
+                    config.getAiParam().getModel(),
+                    config.getMcpMounts().size(),
+                    config.getMcpMounts().size(),
+                    config.getMcpMounts().stream().map(item -> item.getName()).toList()
+            );
+            callLogId = callLog.getId();
             try (McpMountRuntimeFactory.McpMountRuntime runtime = createRuntime(config, userId)) {
                 emitStrict(emitter, "start", AiChatRuntimeDto.builder()
                         .requestId(requestId)
@@ -141,6 +190,9 @@
                         .mountedMcpCount(runtime.getMountedCount())
                         .mountedMcpNames(runtime.getMountedNames())
                         .mountErrors(runtime.getErrors())
+                        .memorySummary(memory.getMemorySummary())
+                        .memoryFacts(memory.getMemoryFacts())
+                        .recentMessageCount(memory.getRecentMessageCount())
                         .persistedMessages(memory.getPersistedMessages())
                         .build());
                 emitSafely(emitter, "status", AiChatStatusDto.builder()
@@ -154,9 +206,14 @@
                 log.info("AI chat started, requestId={}, userId={}, tenantId={}, sessionId={}, model={}",
                         requestId, userId, tenantId, session.getId(), resolvedModel);
 
+                ToolCallback[] observableToolCallbacks = wrapToolCallbacks(
+                        runtime.getToolCallbacks(), emitter, requestId, session.getId(), toolCallSequence,
+                        toolSuccessCount, toolFailureCount, callLogId, userId, tenantId
+                );
                 Prompt prompt = new Prompt(
-                        buildPromptMessages(mergedMessages, config.getPrompt(), request.getMetadata()),
-                        buildChatOptions(config.getAiParam(), runtime.getToolCallbacks(), userId, request.getMetadata())
+                        buildPromptMessages(memory, mergedMessages, config.getPrompt(), request.getMetadata()),
+                        buildChatOptions(config.getAiParam(), observableToolCallbacks, userId, tenantId,
+                                requestId, session.getId(), request.getMetadata())
                 );
                 OpenAiChatModel chatModel = createChatModel(config.getAiParam());
                 if (Boolean.FALSE.equals(config.getAiParam().getStreamingEnabled())) {
@@ -169,6 +226,17 @@
                     }
                     emitDone(emitter, requestId, response.getMetadata(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get());
                     emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get()));
+                    aiCallLogService.completeCallLog(
+                            callLogId,
+                            "COMPLETED",
+                            System.currentTimeMillis() - startedAt,
+                            resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()),
+                            response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getPromptTokens(),
+                            response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getCompletionTokens(),
+                            response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getTotalTokens(),
+                            toolSuccessCount.get(),
+                            toolFailureCount.get()
+                    );
                     log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}",
                             requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()));
                     emitter.complete();
@@ -196,16 +264,29 @@
                 aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), assistantContent.toString());
                 emitDone(emitter, requestId, lastMetadata.get(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get());
                 emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get()));
+                aiCallLogService.completeCallLog(
+                        callLogId,
+                        "COMPLETED",
+                        System.currentTimeMillis() - startedAt,
+                        resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()),
+                        lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getPromptTokens(),
+                        lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getCompletionTokens(),
+                        lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getTotalTokens(),
+                        toolSuccessCount.get(),
+                        toolFailureCount.get()
+                );
                 log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}",
                         requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()));
                 emitter.complete();
             }
         } catch (AiChatException e) {
-            handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e);
+            handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e,
+                    callLogId, toolSuccessCount.get(), toolFailureCount.get());
         } catch (Exception e) {
             handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(),
                     buildAiException("AI_INTERNAL_ERROR", AiErrorCategory.INTERNAL, "INTERNAL",
-                            e == null ? "AI 瀵硅瘽澶辫触" : e.getMessage(), e));
+                            e == null ? "AI 瀵硅瘽澶辫触" : e.getMessage(), e),
+                    callLogId, toolSuccessCount.get(), toolFailureCount.get());
         } finally {
             log.debug("AI chat stream finished, requestId={}", requestId);
         }
@@ -305,11 +386,24 @@
         return firstTokenAt == null ? null : Math.max(0L, firstTokenAt - startedAt);
     }
 
-    private void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt, Long firstTokenAt, AiChatException exception) {
+    private void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt,
+                                     Long firstTokenAt, AiChatException exception, Long callLogId,
+                                     long toolSuccessCount, long toolFailureCount) {
         if (isClientAbortException(exception)) {
             log.warn("AI chat aborted by client, requestId={}, sessionId={}, stage={}, message={}",
                     requestId, sessionId, exception.getStage(), exception.getMessage());
             emitSafely(emitter, "status", buildTerminalStatus(requestId, sessionId, "ABORTED", model, startedAt, firstTokenAt));
+            aiCallLogService.failCallLog(
+                    callLogId,
+                    "ABORTED",
+                    exception.getCategory().name(),
+                    exception.getStage(),
+                    exception.getMessage(),
+                    System.currentTimeMillis() - startedAt,
+                    resolveFirstTokenLatency(startedAt, firstTokenAt),
+                    toolSuccessCount,
+                    toolFailureCount
+            );
             emitter.completeWithError(exception);
             return;
         }
@@ -325,6 +419,17 @@
                 .message(exception.getMessage())
                 .timestamp(Instant.now().toEpochMilli())
                 .build());
+        aiCallLogService.failCallLog(
+                callLogId,
+                "FAILED",
+                exception.getCategory().name(),
+                exception.getStage(),
+                exception.getMessage(),
+                System.currentTimeMillis() - startedAt,
+                resolveFirstTokenLatency(startedAt, firstTokenAt),
+                toolSuccessCount,
+                toolFailureCount
+        );
         emitter.completeWithError(exception);
     }
 
@@ -364,7 +469,8 @@
                 .build();
     }
 
-    private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Map<String, Object> metadata) {
+    private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Long tenantId,
+                                               String requestId, Long sessionId, Map<String, Object> metadata) {
         if (userId == null) {
             throw buildAiException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "OPTIONS_BUILD", "褰撳墠鐧诲綍鐢ㄦ埛涓嶅瓨鍦�", null);
         }
@@ -376,25 +482,59 @@
                 .streamUsage(true)
                 .user(String.valueOf(userId));
         if (!Cools.isEmpty(toolCallbacks)) {
-            builder.toolCallbacks(Arrays.asList(toolCallbacks));
+            builder.toolCallbacks(Arrays.stream(toolCallbacks).toList());
         }
+        Map<String, Object> toolContext = new LinkedHashMap<>();
+        toolContext.put("userId", userId);
+        toolContext.put("tenantId", tenantId);
+        toolContext.put("requestId", requestId);
+        toolContext.put("sessionId", sessionId);
         Map<String, String> metadataMap = new LinkedHashMap<>();
         if (metadata != null) {
-            metadata.forEach((key, value) -> metadataMap.put(key, value == null ? "" : String.valueOf(value)));
+            metadata.forEach((key, value) -> {
+                String normalized = value == null ? "" : String.valueOf(value);
+                metadataMap.put(key, normalized);
+                toolContext.put(key, normalized);
+            });
         }
+        builder.toolContext(toolContext);
         if (!metadataMap.isEmpty()) {
             builder.metadata(metadataMap);
         }
         return builder.build();
     }
 
-    private List<Message> buildPromptMessages(List<AiChatMessageDto> sourceMessages, AiPrompt aiPrompt, Map<String, Object> metadata) {
+    private ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, SseEmitter emitter, String requestId,
+                                             Long sessionId, AtomicLong toolCallSequence,
+                                             AtomicLong toolSuccessCount, AtomicLong toolFailureCount,
+                                             Long callLogId, Long userId, Long tenantId) {
+        if (Cools.isEmpty(toolCallbacks)) {
+            return toolCallbacks;
+        }
+        List<ToolCallback> wrappedCallbacks = new ArrayList<>();
+        for (ToolCallback callback : toolCallbacks) {
+            if (callback == null) {
+                continue;
+            }
+            wrappedCallbacks.add(new ObservableToolCallback(callback, emitter, requestId, sessionId, toolCallSequence,
+                    toolSuccessCount, toolFailureCount, callLogId, userId, tenantId));
+        }
+        return wrappedCallbacks.toArray(new ToolCallback[0]);
+    }
+
+    private List<Message> buildPromptMessages(AiChatMemoryDto memory, List<AiChatMessageDto> sourceMessages, AiPrompt aiPrompt, Map<String, Object> metadata) {
         if (Cools.isEmpty(sourceMessages)) {
             throw new CoolException("瀵硅瘽娑堟伅涓嶈兘涓虹┖");
         }
         List<Message> messages = new ArrayList<>();
         if (StringUtils.hasText(aiPrompt.getSystemPrompt())) {
             messages.add(new SystemMessage(aiPrompt.getSystemPrompt()));
+        }
+        if (memory != null && StringUtils.hasText(memory.getMemorySummary())) {
+            messages.add(new SystemMessage("鍘嗗彶鎽樿:\n" + memory.getMemorySummary()));
+        }
+        if (memory != null && StringUtils.hasText(memory.getMemoryFacts())) {
+            messages.add(new SystemMessage("鍏抽敭浜嬪疄:\n" + memory.getMemoryFacts()));
         }
         int lastUserIndex = -1;
         for (int i = 0; i < sourceMessages.size(); i++) {
@@ -482,6 +622,17 @@
         return response.getResult().getOutput().getText();
     }
 
+    private String summarizeToolPayload(String content, int maxLength) {
+        if (!StringUtils.hasText(content)) {
+            return null;
+        }
+        String normalized = content.trim()
+                .replace("\r", " ")
+                .replace("\n", " ")
+                .replaceAll("\\s+", " ");
+        return normalized.length() > maxLength ? normalized.substring(0, maxLength) : normalized;
+    }
+
     private void emitDone(SseEmitter emitter, String requestId, ChatResponseMetadata metadata, String fallbackModel, Long sessionId, long startedAt, Long firstTokenAt) {
         Usage usage = metadata == null ? null : metadata.getUsage();
         emitStrict(emitter, "done", AiChatDoneDto.builder()
@@ -550,4 +701,107 @@
         }
         return false;
     }
+
+    private class ObservableToolCallback implements ToolCallback {
+
+        private final ToolCallback delegate;
+        private final SseEmitter emitter;
+        private final String requestId;
+        private final Long sessionId;
+        private final AtomicLong toolCallSequence;
+        private final AtomicLong toolSuccessCount;
+        private final AtomicLong toolFailureCount;
+        private final Long callLogId;
+        private final Long userId;
+        private final Long tenantId;
+
+        private ObservableToolCallback(ToolCallback delegate, SseEmitter emitter, String requestId,
+                                       Long sessionId, AtomicLong toolCallSequence,
+                                       AtomicLong toolSuccessCount, AtomicLong toolFailureCount,
+                                       Long callLogId, Long userId, Long tenantId) {
+            this.delegate = delegate;
+            this.emitter = emitter;
+            this.requestId = requestId;
+            this.sessionId = sessionId;
+            this.toolCallSequence = toolCallSequence;
+            this.toolSuccessCount = toolSuccessCount;
+            this.toolFailureCount = toolFailureCount;
+            this.callLogId = callLogId;
+            this.userId = userId;
+            this.tenantId = tenantId;
+        }
+
+        @Override
+        public org.springframework.ai.tool.definition.ToolDefinition getToolDefinition() {
+            return delegate.getToolDefinition();
+        }
+
+        @Override
+        public org.springframework.ai.tool.metadata.ToolMetadata getToolMetadata() {
+            return delegate.getToolMetadata();
+        }
+
+        @Override
+        public String call(String toolInput) {
+            return call(toolInput, null);
+        }
+
+        @Override
+        public String call(String toolInput, ToolContext toolContext) {
+            String toolName = delegate.getToolDefinition() == null ? "unknown" : delegate.getToolDefinition().name();
+            String mountName = delegate instanceof MountedToolCallback ? ((MountedToolCallback) delegate).getMountName() : null;
+            String toolCallId = requestId + "-tool-" + toolCallSequence.incrementAndGet();
+            long startedAt = System.currentTimeMillis();
+            emitSafely(emitter, "tool_start", AiChatToolEventDto.builder()
+                    .requestId(requestId)
+                    .sessionId(sessionId)
+                    .toolCallId(toolCallId)
+                    .toolName(toolName)
+                    .mountName(mountName)
+                    .status("STARTED")
+                    .inputSummary(summarizeToolPayload(toolInput, 400))
+                    .timestamp(startedAt)
+                    .build());
+            try {
+                String output = toolContext == null ? delegate.call(toolInput) : delegate.call(toolInput, toolContext);
+                long durationMs = System.currentTimeMillis() - startedAt;
+                emitSafely(emitter, "tool_result", AiChatToolEventDto.builder()
+                        .requestId(requestId)
+                        .sessionId(sessionId)
+                        .toolCallId(toolCallId)
+                        .toolName(toolName)
+                        .mountName(mountName)
+                        .status("COMPLETED")
+                        .inputSummary(summarizeToolPayload(toolInput, 400))
+                        .outputSummary(summarizeToolPayload(output, 600))
+                        .durationMs(durationMs)
+                        .timestamp(System.currentTimeMillis())
+                        .build());
+                toolSuccessCount.incrementAndGet();
+                aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
+                        "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(output, 600),
+                        null, durationMs, userId, tenantId);
+                return output;
+            } catch (RuntimeException e) {
+                long durationMs = System.currentTimeMillis() - startedAt;
+                emitSafely(emitter, "tool_error", AiChatToolEventDto.builder()
+                        .requestId(requestId)
+                        .sessionId(sessionId)
+                        .toolCallId(toolCallId)
+                        .toolName(toolName)
+                        .mountName(mountName)
+                        .status("FAILED")
+                        .inputSummary(summarizeToolPayload(toolInput, 400))
+                        .errorMessage(e.getMessage())
+                        .durationMs(durationMs)
+                        .timestamp(System.currentTimeMillis())
+                        .build());
+                toolFailureCount.incrementAndGet();
+                aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
+                        "FAILED", summarizeToolPayload(toolInput, 400), null, e.getMessage(),
+                        durationMs, userId, tenantId);
+                throw e;
+            }
+        }
+    }
 }

--
Gitblit v1.9.1