package com.vincent.rsf.server.ai.service.impl.chat; import com.vincent.rsf.framework.exception.CoolException; import com.vincent.rsf.server.ai.dto.AiChatToolEventDto; import com.vincent.rsf.server.ai.service.AiCallLogService; import com.vincent.rsf.server.ai.service.MountedToolCallback; import com.vincent.rsf.server.ai.store.AiCachedToolResult; import com.vincent.rsf.server.ai.store.AiToolResultStore; import lombok.RequiredArgsConstructor; import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.tool.ToolCallback; import org.springframework.stereotype.Component; import org.springframework.util.StringUtils; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicLong; @Component @RequiredArgsConstructor public class AiToolObservationService { private final AiSseEventPublisher aiSseEventPublisher; private final AiToolResultStore aiToolResultStore; private final AiCallLogService aiCallLogService; public ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, SseEmitter emitter, String requestId, Long sessionId, AtomicLong toolCallSequence, AtomicLong toolSuccessCount, AtomicLong toolFailureCount, Long callLogId, Long userId, Long tenantId, AiThinkingTraceEmitter thinkingTraceEmitter) { if (toolCallbacks == null || toolCallbacks.length == 0) { return toolCallbacks; } List wrappedCallbacks = new ArrayList<>(); for (ToolCallback callback : toolCallbacks) { if (callback == null) { continue; } wrappedCallbacks.add(new ObservableToolCallback(callback, emitter, requestId, sessionId, toolCallSequence, toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, thinkingTraceEmitter)); } return wrappedCallbacks.toArray(new ToolCallback[0]); } private String summarizeToolPayload(String content, int maxLength) { if (!StringUtils.hasText(content)) { return null; } String normalized = content.trim() .replace("\r", " ") .replace("\n", " ") .replaceAll("\\s+", " "); return normalized.length() > maxLength ? normalized.substring(0, maxLength) : normalized; } private class ObservableToolCallback implements ToolCallback { private final ToolCallback delegate; private final SseEmitter emitter; private final String requestId; private final Long sessionId; private final AtomicLong toolCallSequence; private final AtomicLong toolSuccessCount; private final AtomicLong toolFailureCount; private final Long callLogId; private final Long userId; private final Long tenantId; private final AiThinkingTraceEmitter thinkingTraceEmitter; private ObservableToolCallback(ToolCallback delegate, SseEmitter emitter, String requestId, Long sessionId, AtomicLong toolCallSequence, AtomicLong toolSuccessCount, AtomicLong toolFailureCount, Long callLogId, Long userId, Long tenantId, AiThinkingTraceEmitter thinkingTraceEmitter) { this.delegate = delegate; this.emitter = emitter; this.requestId = requestId; this.sessionId = sessionId; this.toolCallSequence = toolCallSequence; this.toolSuccessCount = toolSuccessCount; this.toolFailureCount = toolFailureCount; this.callLogId = callLogId; this.userId = userId; this.tenantId = tenantId; this.thinkingTraceEmitter = thinkingTraceEmitter; } @Override public org.springframework.ai.tool.definition.ToolDefinition getToolDefinition() { return delegate.getToolDefinition(); } @Override public org.springframework.ai.tool.metadata.ToolMetadata getToolMetadata() { return delegate.getToolMetadata(); } @Override public String call(String toolInput) { return call(toolInput, null); } @Override public String call(String toolInput, ToolContext toolContext) { String toolName = delegate.getToolDefinition() == null ? "unknown" : delegate.getToolDefinition().name(); String mountName = delegate instanceof MountedToolCallback ? ((MountedToolCallback) delegate).getMountName() : null; String toolCallId = requestId + "-tool-" + toolCallSequence.incrementAndGet(); long startedAt = System.currentTimeMillis(); AiCachedToolResult cachedToolResult = aiToolResultStore.getToolResult(tenantId, requestId, toolName, toolInput); if (cachedToolResult != null) { aiSseEventPublisher.emitSafely(emitter, "tool_result", AiChatToolEventDto.builder() .requestId(requestId) .sessionId(sessionId) .toolCallId(toolCallId) .toolName(toolName) .mountName(mountName) .status(cachedToolResult.isSuccess() ? "COMPLETED" : "FAILED") .inputSummary(summarizeToolPayload(toolInput, 400)) .outputSummary(summarizeToolPayload(cachedToolResult.getOutput(), 600)) .errorMessage(cachedToolResult.getErrorMessage()) .durationMs(0L) .timestamp(System.currentTimeMillis()) .build()); if (thinkingTraceEmitter != null) { thinkingTraceEmitter.onToolResult(toolName, toolCallId, !cachedToolResult.isSuccess()); } if (cachedToolResult.isSuccess()) { toolSuccessCount.incrementAndGet(); aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(cachedToolResult.getOutput(), 600), null, 0L, userId, tenantId); return cachedToolResult.getOutput(); } toolFailureCount.incrementAndGet(); aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, "FAILED", summarizeToolPayload(toolInput, 400), null, cachedToolResult.getErrorMessage(), 0L, userId, tenantId); throw new CoolException(cachedToolResult.getErrorMessage()); } if (thinkingTraceEmitter != null) { thinkingTraceEmitter.onToolStart(toolName, toolCallId); } aiSseEventPublisher.emitSafely(emitter, "tool_start", AiChatToolEventDto.builder() .requestId(requestId) .sessionId(sessionId) .toolCallId(toolCallId) .toolName(toolName) .mountName(mountName) .status("STARTED") .inputSummary(summarizeToolPayload(toolInput, 400)) .timestamp(startedAt) .build()); try { String output = toolContext == null ? delegate.call(toolInput) : delegate.call(toolInput, toolContext); long durationMs = System.currentTimeMillis() - startedAt; aiSseEventPublisher.emitSafely(emitter, "tool_result", AiChatToolEventDto.builder() .requestId(requestId) .sessionId(sessionId) .toolCallId(toolCallId) .toolName(toolName) .mountName(mountName) .status("COMPLETED") .inputSummary(summarizeToolPayload(toolInput, 400)) .outputSummary(summarizeToolPayload(output, 600)) .durationMs(durationMs) .timestamp(System.currentTimeMillis()) .build()); if (thinkingTraceEmitter != null) { thinkingTraceEmitter.onToolResult(toolName, toolCallId, false); } aiToolResultStore.cacheToolResult(tenantId, requestId, toolName, toolInput, true, output, null); toolSuccessCount.incrementAndGet(); aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(output, 600), null, durationMs, userId, tenantId); return output; } catch (RuntimeException e) { long durationMs = System.currentTimeMillis() - startedAt; aiSseEventPublisher.emitSafely(emitter, "tool_error", AiChatToolEventDto.builder() .requestId(requestId) .sessionId(sessionId) .toolCallId(toolCallId) .toolName(toolName) .mountName(mountName) .status("FAILED") .inputSummary(summarizeToolPayload(toolInput, 400)) .errorMessage(e.getMessage()) .durationMs(durationMs) .timestamp(System.currentTimeMillis()) .build()); if (thinkingTraceEmitter != null) { thinkingTraceEmitter.onToolResult(toolName, toolCallId, true); } aiToolResultStore.cacheToolResult(tenantId, requestId, toolName, toolInput, false, null, e.getMessage()); toolFailureCount.incrementAndGet(); aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, "FAILED", summarizeToolPayload(toolInput, 400), null, e.getMessage(), durationMs, userId, tenantId); throw e; } } } }