From 1d0ab9996661fdc66037870d4b98037f2dfa079a Mon Sep 17 00:00:00 2001
From: zhou zhou <3272660260@qq.com>
Date: 星期四, 19 三月 2026 12:03:19 +0800
Subject: [PATCH] #AI.工具调用可视化
---
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java | 158 ++++++++++++++++++++++++++++++++++++++++++++++++++--
1 files changed, 151 insertions(+), 7 deletions(-)
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
index 280914f..5e3c7ab 100644
--- a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
@@ -14,6 +14,7 @@
import com.vincent.rsf.server.ai.dto.AiChatSessionDto;
import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest;
import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest;
+import com.vincent.rsf.server.ai.dto.AiChatToolEventDto;
import com.vincent.rsf.server.ai.dto.AiResolvedConfig;
import com.vincent.rsf.server.ai.entity.AiParam;
import com.vincent.rsf.server.ai.entity.AiPrompt;
@@ -34,6 +35,7 @@
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.tool.DefaultToolCallingManager;
import org.springframework.ai.model.tool.ToolCallingManager;
@@ -66,6 +68,7 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
+import java.util.concurrent.atomic.AtomicLong;
@Slf4j
@Service
@@ -95,6 +98,9 @@
.mountedMcpCount(config.getMcpMounts().size())
.mountedMcpNames(config.getMcpMounts().stream().map(item -> item.getName()).toList())
.mountErrors(List.of())
+ .memorySummary(memory.getMemorySummary())
+ .memoryFacts(memory.getMemoryFacts())
+ .recentMessageCount(memory.getRecentMessageCount())
.persistedMessages(memory.getPersistedMessages())
.build();
}
@@ -121,6 +127,16 @@
}
@Override
+ public void clearSessionMemory(Long sessionId, Long userId, Long tenantId) {
+ aiChatMemoryService.clearSessionMemory(userId, tenantId, sessionId);
+ }
+
+ @Override
+ public void retainLatestRound(Long sessionId, Long userId, Long tenantId) {
+ aiChatMemoryService.retainLatestRound(userId, tenantId, sessionId);
+ }
+
+ @Override
public SseEmitter stream(AiChatRequest request, Long userId, Long tenantId) {
SseEmitter emitter = new SseEmitter(AiDefaults.SSE_TIMEOUT_MS);
CompletableFuture.runAsync(() -> doStream(request, userId, tenantId, emitter), aiChatTaskExecutor);
@@ -131,6 +147,7 @@
String requestId = request.getRequestId();
long startedAt = System.currentTimeMillis();
AtomicReference<Long> firstTokenAtRef = new AtomicReference<>();
+ AtomicLong toolCallSequence = new AtomicLong(0);
Long sessionId = request.getSessionId();
String model = null;
try {
@@ -141,7 +158,7 @@
AiChatSession session = resolveSession(request, userId, tenantId, config.getPromptCode());
sessionId = session.getId();
AiChatMemoryDto memory = loadMemory(userId, tenantId, config.getPromptCode(), session.getId());
- List<AiChatMessageDto> mergedMessages = mergeMessages(memory.getPersistedMessages(), request.getMessages());
+ List<AiChatMessageDto> mergedMessages = mergeMessages(memory.getShortMemoryMessages(), request.getMessages());
try (McpMountRuntimeFactory.McpMountRuntime runtime = createRuntime(config, userId)) {
emitStrict(emitter, "start", AiChatRuntimeDto.builder()
.requestId(requestId)
@@ -153,6 +170,9 @@
.mountedMcpCount(runtime.getMountedCount())
.mountedMcpNames(runtime.getMountedNames())
.mountErrors(runtime.getErrors())
+ .memorySummary(memory.getMemorySummary())
+ .memoryFacts(memory.getMemoryFacts())
+ .recentMessageCount(memory.getRecentMessageCount())
.persistedMessages(memory.getPersistedMessages())
.build());
emitSafely(emitter, "status", AiChatStatusDto.builder()
@@ -166,9 +186,13 @@
log.info("AI chat started, requestId={}, userId={}, tenantId={}, sessionId={}, model={}",
requestId, userId, tenantId, session.getId(), resolvedModel);
+ ToolCallback[] observableToolCallbacks = wrapToolCallbacks(
+ runtime.getToolCallbacks(), emitter, requestId, session.getId(), toolCallSequence
+ );
Prompt prompt = new Prompt(
- buildPromptMessages(mergedMessages, config.getPrompt(), request.getMetadata()),
- buildChatOptions(config.getAiParam(), runtime.getToolCallbacks(), userId, request.getMetadata())
+ buildPromptMessages(memory, mergedMessages, config.getPrompt(), request.getMetadata()),
+ buildChatOptions(config.getAiParam(), observableToolCallbacks, userId, tenantId,
+ requestId, session.getId(), request.getMetadata())
);
OpenAiChatModel chatModel = createChatModel(config.getAiParam());
if (Boolean.FALSE.equals(config.getAiParam().getStreamingEnabled())) {
@@ -376,7 +400,8 @@
.build();
}
- private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Map<String, Object> metadata) {
+ private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Long tenantId,
+ String requestId, Long sessionId, Map<String, Object> metadata) {
if (userId == null) {
throw buildAiException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "OPTIONS_BUILD", "褰撳墠鐧诲綍鐢ㄦ埛涓嶅瓨鍦�", null);
}
@@ -388,25 +413,56 @@
.streamUsage(true)
.user(String.valueOf(userId));
if (!Cools.isEmpty(toolCallbacks)) {
- builder.toolCallbacks(Arrays.asList(toolCallbacks));
+ builder.toolCallbacks(Arrays.stream(toolCallbacks).toList());
}
+ Map<String, Object> toolContext = new LinkedHashMap<>();
+ toolContext.put("userId", userId);
+ toolContext.put("tenantId", tenantId);
+ toolContext.put("requestId", requestId);
+ toolContext.put("sessionId", sessionId);
Map<String, String> metadataMap = new LinkedHashMap<>();
if (metadata != null) {
- metadata.forEach((key, value) -> metadataMap.put(key, value == null ? "" : String.valueOf(value)));
+ metadata.forEach((key, value) -> {
+ String normalized = value == null ? "" : String.valueOf(value);
+ metadataMap.put(key, normalized);
+ toolContext.put(key, normalized);
+ });
}
+ builder.toolContext(toolContext);
if (!metadataMap.isEmpty()) {
builder.metadata(metadataMap);
}
return builder.build();
}
- private List<Message> buildPromptMessages(List<AiChatMessageDto> sourceMessages, AiPrompt aiPrompt, Map<String, Object> metadata) {
+ private ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, SseEmitter emitter, String requestId,
+ Long sessionId, AtomicLong toolCallSequence) {
+ if (Cools.isEmpty(toolCallbacks)) {
+ return toolCallbacks;
+ }
+ List<ToolCallback> wrappedCallbacks = new ArrayList<>();
+ for (ToolCallback callback : toolCallbacks) {
+ if (callback == null) {
+ continue;
+ }
+ wrappedCallbacks.add(new ObservableToolCallback(callback, emitter, requestId, sessionId, toolCallSequence));
+ }
+ return wrappedCallbacks.toArray(new ToolCallback[0]);
+ }
+
+ private List<Message> buildPromptMessages(AiChatMemoryDto memory, List<AiChatMessageDto> sourceMessages, AiPrompt aiPrompt, Map<String, Object> metadata) {
if (Cools.isEmpty(sourceMessages)) {
throw new CoolException("瀵硅瘽娑堟伅涓嶈兘涓虹┖");
}
List<Message> messages = new ArrayList<>();
if (StringUtils.hasText(aiPrompt.getSystemPrompt())) {
messages.add(new SystemMessage(aiPrompt.getSystemPrompt()));
+ }
+ if (memory != null && StringUtils.hasText(memory.getMemorySummary())) {
+ messages.add(new SystemMessage("鍘嗗彶鎽樿:\n" + memory.getMemorySummary()));
+ }
+ if (memory != null && StringUtils.hasText(memory.getMemoryFacts())) {
+ messages.add(new SystemMessage("鍏抽敭浜嬪疄:\n" + memory.getMemoryFacts()));
}
int lastUserIndex = -1;
for (int i = 0; i < sourceMessages.size(); i++) {
@@ -494,6 +550,17 @@
return response.getResult().getOutput().getText();
}
+ private String summarizeToolPayload(String content, int maxLength) {
+ if (!StringUtils.hasText(content)) {
+ return null;
+ }
+ String normalized = content.trim()
+ .replace("\r", " ")
+ .replace("\n", " ")
+ .replaceAll("\\s+", " ");
+ return normalized.length() > maxLength ? normalized.substring(0, maxLength) : normalized;
+ }
+
private void emitDone(SseEmitter emitter, String requestId, ChatResponseMetadata metadata, String fallbackModel, Long sessionId, long startedAt, Long firstTokenAt) {
Usage usage = metadata == null ? null : metadata.getUsage();
emitStrict(emitter, "done", AiChatDoneDto.builder()
@@ -562,4 +629,81 @@
}
return false;
}
+
+ private class ObservableToolCallback implements ToolCallback {
+
+ private final ToolCallback delegate;
+ private final SseEmitter emitter;
+ private final String requestId;
+ private final Long sessionId;
+ private final AtomicLong toolCallSequence;
+
+ private ObservableToolCallback(ToolCallback delegate, SseEmitter emitter, String requestId,
+ Long sessionId, AtomicLong toolCallSequence) {
+ this.delegate = delegate;
+ this.emitter = emitter;
+ this.requestId = requestId;
+ this.sessionId = sessionId;
+ this.toolCallSequence = toolCallSequence;
+ }
+
+ @Override
+ public org.springframework.ai.tool.definition.ToolDefinition getToolDefinition() {
+ return delegate.getToolDefinition();
+ }
+
+ @Override
+ public org.springframework.ai.tool.metadata.ToolMetadata getToolMetadata() {
+ return delegate.getToolMetadata();
+ }
+
+ @Override
+ public String call(String toolInput) {
+ return call(toolInput, null);
+ }
+
+ @Override
+ public String call(String toolInput, ToolContext toolContext) {
+ String toolName = delegate.getToolDefinition() == null ? "unknown" : delegate.getToolDefinition().name();
+ String toolCallId = requestId + "-tool-" + toolCallSequence.incrementAndGet();
+ long startedAt = System.currentTimeMillis();
+ emitSafely(emitter, "tool_start", AiChatToolEventDto.builder()
+ .requestId(requestId)
+ .sessionId(sessionId)
+ .toolCallId(toolCallId)
+ .toolName(toolName)
+ .status("STARTED")
+ .inputSummary(summarizeToolPayload(toolInput, 400))
+ .timestamp(startedAt)
+ .build());
+ try {
+ String output = toolContext == null ? delegate.call(toolInput) : delegate.call(toolInput, toolContext);
+ emitSafely(emitter, "tool_result", AiChatToolEventDto.builder()
+ .requestId(requestId)
+ .sessionId(sessionId)
+ .toolCallId(toolCallId)
+ .toolName(toolName)
+ .status("COMPLETED")
+ .inputSummary(summarizeToolPayload(toolInput, 400))
+ .outputSummary(summarizeToolPayload(output, 600))
+ .durationMs(System.currentTimeMillis() - startedAt)
+ .timestamp(System.currentTimeMillis())
+ .build());
+ return output;
+ } catch (RuntimeException e) {
+ emitSafely(emitter, "tool_error", AiChatToolEventDto.builder()
+ .requestId(requestId)
+ .sessionId(sessionId)
+ .toolCallId(toolCallId)
+ .toolName(toolName)
+ .status("FAILED")
+ .inputSummary(summarizeToolPayload(toolInput, 400))
+ .errorMessage(e.getMessage())
+ .durationMs(System.currentTimeMillis() - startedAt)
+ .timestamp(System.currentTimeMillis())
+ .build());
+ throw e;
+ }
+ }
+ }
}
--
Gitblit v1.9.1