zhou zhou
14 小时以前 b05f094ac51dce91eb8c00235226d54a04658c6d
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiToolObservationService.java
@@ -1,7 +1,6 @@
package com.vincent.rsf.server.ai.service.impl.chat;
import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.dto.AiChatToolEventDto;
import com.vincent.rsf.server.ai.service.AiCallLogService;
import com.vincent.rsf.server.ai.service.MountedToolCallback;
import com.vincent.rsf.server.ai.store.AiCachedToolResult;
@@ -11,7 +10,6 @@
import org.springframework.ai.tool.ToolCallback;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.ArrayList;
import java.util.List;
@@ -21,15 +19,14 @@
@RequiredArgsConstructor
public class AiToolObservationService {
    private final AiSseEventPublisher aiSseEventPublisher;
    private final AiToolResultStore aiToolResultStore;
    private final AiCallLogService aiCallLogService;
    public ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, SseEmitter emitter, String requestId,
    public ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, String requestId,
                                            Long sessionId, AtomicLong toolCallSequence,
                                            AtomicLong toolSuccessCount, AtomicLong toolFailureCount,
                                            Long callLogId, Long userId, Long tenantId,
                                            AiThinkingTraceEmitter thinkingTraceEmitter) {
                                            AiChatTraceEmitter traceEmitter) {
        if (toolCallbacks == null || toolCallbacks.length == 0) {
            return toolCallbacks;
        }
@@ -38,8 +35,8 @@
            if (callback == null) {
                continue;
            }
            wrappedCallbacks.add(new ObservableToolCallback(callback, emitter, requestId, sessionId, toolCallSequence,
                    toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, thinkingTraceEmitter));
            wrappedCallbacks.add(new ObservableToolCallback(callback, requestId, sessionId, toolCallSequence,
                    toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, traceEmitter));
        }
        return wrappedCallbacks.toArray(new ToolCallback[0]);
    }
@@ -58,7 +55,6 @@
    private class ObservableToolCallback implements ToolCallback {
        private final ToolCallback delegate;
        private final SseEmitter emitter;
        private final String requestId;
        private final Long sessionId;
        private final AtomicLong toolCallSequence;
@@ -67,15 +63,14 @@
        private final Long callLogId;
        private final Long userId;
        private final Long tenantId;
        private final AiThinkingTraceEmitter thinkingTraceEmitter;
        private final AiChatTraceEmitter traceEmitter;
        private ObservableToolCallback(ToolCallback delegate, SseEmitter emitter, String requestId,
        private ObservableToolCallback(ToolCallback delegate, String requestId,
                                       Long sessionId, AtomicLong toolCallSequence,
                                       AtomicLong toolSuccessCount, AtomicLong toolFailureCount,
                                       Long callLogId, Long userId, Long tenantId,
                                       AiThinkingTraceEmitter thinkingTraceEmitter) {
                                       AiChatTraceEmitter traceEmitter) {
            this.delegate = delegate;
            this.emitter = emitter;
            this.requestId = requestId;
            this.sessionId = sessionId;
            this.toolCallSequence = toolCallSequence;
@@ -84,7 +79,7 @@
            this.callLogId = callLogId;
            this.userId = userId;
            this.tenantId = tenantId;
            this.thinkingTraceEmitter = thinkingTraceEmitter;
            this.traceEmitter = traceEmitter;
        }
        @Override
@@ -108,95 +103,56 @@
            String mountName = delegate instanceof MountedToolCallback ? ((MountedToolCallback) delegate).getMountName() : null;
            String toolCallId = requestId + "-tool-" + toolCallSequence.incrementAndGet();
            long startedAt = System.currentTimeMillis();
            String inputSummary = summarizeToolPayload(toolInput, 400);
            AiCachedToolResult cachedToolResult = aiToolResultStore.getToolResult(tenantId, requestId, toolName, toolInput);
            if (cachedToolResult != null) {
                aiSseEventPublisher.emitSafely(emitter, "tool_result", AiChatToolEventDto.builder()
                        .requestId(requestId)
                        .sessionId(sessionId)
                        .toolCallId(toolCallId)
                        .toolName(toolName)
                        .mountName(mountName)
                        .status(cachedToolResult.isSuccess() ? "COMPLETED" : "FAILED")
                        .inputSummary(summarizeToolPayload(toolInput, 400))
                        .outputSummary(summarizeToolPayload(cachedToolResult.getOutput(), 600))
                        .errorMessage(cachedToolResult.getErrorMessage())
                        .durationMs(0L)
                        .timestamp(System.currentTimeMillis())
                        .build());
                if (thinkingTraceEmitter != null) {
                    thinkingTraceEmitter.onToolResult(toolName, toolCallId, !cachedToolResult.isSuccess());
                String outputSummary = summarizeToolPayload(cachedToolResult.getOutput(), 600);
                String errorMessage = cachedToolResult.getErrorMessage();
                if (traceEmitter != null) {
                    traceEmitter.onToolResult(toolName, mountName, toolCallId, inputSummary, outputSummary,
                            errorMessage, 0L, System.currentTimeMillis(), !cachedToolResult.isSuccess());
                }
                if (cachedToolResult.isSuccess()) {
                    toolSuccessCount.incrementAndGet();
                    aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
                            "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(cachedToolResult.getOutput(), 600),
                            "COMPLETED", inputSummary, outputSummary,
                            null, 0L, userId, tenantId);
                    return cachedToolResult.getOutput();
                }
                toolFailureCount.incrementAndGet();
                aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
                        "FAILED", summarizeToolPayload(toolInput, 400), null, cachedToolResult.getErrorMessage(),
                        "FAILED", inputSummary, null, errorMessage,
                        0L, userId, tenantId);
                throw new CoolException(cachedToolResult.getErrorMessage());
                throw new CoolException(errorMessage);
            }
            if (thinkingTraceEmitter != null) {
                thinkingTraceEmitter.onToolStart(toolName, toolCallId);
            if (traceEmitter != null) {
                traceEmitter.onToolStart(toolName, mountName, toolCallId, inputSummary, startedAt);
            }
            aiSseEventPublisher.emitSafely(emitter, "tool_start", AiChatToolEventDto.builder()
                    .requestId(requestId)
                    .sessionId(sessionId)
                    .toolCallId(toolCallId)
                    .toolName(toolName)
                    .mountName(mountName)
                    .status("STARTED")
                    .inputSummary(summarizeToolPayload(toolInput, 400))
                    .timestamp(startedAt)
                    .build());
            try {
                String output = toolContext == null ? delegate.call(toolInput) : delegate.call(toolInput, toolContext);
                long durationMs = System.currentTimeMillis() - startedAt;
                aiSseEventPublisher.emitSafely(emitter, "tool_result", AiChatToolEventDto.builder()
                        .requestId(requestId)
                        .sessionId(sessionId)
                        .toolCallId(toolCallId)
                        .toolName(toolName)
                        .mountName(mountName)
                        .status("COMPLETED")
                        .inputSummary(summarizeToolPayload(toolInput, 400))
                        .outputSummary(summarizeToolPayload(output, 600))
                        .durationMs(durationMs)
                        .timestamp(System.currentTimeMillis())
                        .build());
                if (thinkingTraceEmitter != null) {
                    thinkingTraceEmitter.onToolResult(toolName, toolCallId, false);
                String outputSummary = summarizeToolPayload(output, 600);
                if (traceEmitter != null) {
                    traceEmitter.onToolResult(toolName, mountName, toolCallId, inputSummary, outputSummary,
                            null, durationMs, System.currentTimeMillis(), false);
                }
                aiToolResultStore.cacheToolResult(tenantId, requestId, toolName, toolInput, true, output, null);
                toolSuccessCount.incrementAndGet();
                aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
                        "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(output, 600),
                        "COMPLETED", inputSummary, outputSummary,
                        null, durationMs, userId, tenantId);
                return output;
            } catch (RuntimeException e) {
                long durationMs = System.currentTimeMillis() - startedAt;
                aiSseEventPublisher.emitSafely(emitter, "tool_error", AiChatToolEventDto.builder()
                        .requestId(requestId)
                        .sessionId(sessionId)
                        .toolCallId(toolCallId)
                        .toolName(toolName)
                        .mountName(mountName)
                        .status("FAILED")
                        .inputSummary(summarizeToolPayload(toolInput, 400))
                        .errorMessage(e.getMessage())
                        .durationMs(durationMs)
                        .timestamp(System.currentTimeMillis())
                        .build());
                if (thinkingTraceEmitter != null) {
                    thinkingTraceEmitter.onToolResult(toolName, toolCallId, true);
                String errorMessage = e.getMessage();
                if (traceEmitter != null) {
                    traceEmitter.onToolResult(toolName, mountName, toolCallId, inputSummary, null,
                            errorMessage, durationMs, System.currentTimeMillis(), true);
                }
                aiToolResultStore.cacheToolResult(tenantId, requestId, toolName, toolInput, false, null, e.getMessage());
                aiToolResultStore.cacheToolResult(tenantId, requestId, toolName, toolInput, false, null, errorMessage);
                toolFailureCount.incrementAndGet();
                aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
                        "FAILED", summarizeToolPayload(toolInput, 400), null, e.getMessage(),
                        "FAILED", inputSummary, null, errorMessage,
                        durationMs, userId, tenantId);
                throw e;
            }