| | |
| | | package com.vincent.rsf.server.ai.service.impl.chat; |
| | | |
| | | import com.vincent.rsf.framework.exception.CoolException; |
| | | import com.vincent.rsf.server.ai.dto.AiChatToolEventDto; |
| | | import com.vincent.rsf.server.ai.service.AiCallLogService; |
| | | import com.vincent.rsf.server.ai.service.MountedToolCallback; |
| | | import com.vincent.rsf.server.ai.store.AiCachedToolResult; |
| | |
| | | import org.springframework.ai.tool.ToolCallback; |
| | | import org.springframework.stereotype.Component; |
| | | import org.springframework.util.StringUtils; |
| | | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; |
| | | |
| | | import java.util.ArrayList; |
| | | import java.util.List; |
| | |
| | | @RequiredArgsConstructor |
| | | public class AiToolObservationService { |
| | | |
| | | private final AiSseEventPublisher aiSseEventPublisher; |
| | | private final AiToolResultStore aiToolResultStore; |
| | | private final AiCallLogService aiCallLogService; |
| | | |
| | | public ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, SseEmitter emitter, String requestId, |
| | | public ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, String requestId, |
| | | Long sessionId, AtomicLong toolCallSequence, |
| | | AtomicLong toolSuccessCount, AtomicLong toolFailureCount, |
| | | Long callLogId, Long userId, Long tenantId, |
| | | AiThinkingTraceEmitter thinkingTraceEmitter) { |
| | | AiChatTraceEmitter traceEmitter) { |
| | | if (toolCallbacks == null || toolCallbacks.length == 0) { |
| | | return toolCallbacks; |
| | | } |
| | |
| | | if (callback == null) { |
| | | continue; |
| | | } |
| | | wrappedCallbacks.add(new ObservableToolCallback(callback, emitter, requestId, sessionId, toolCallSequence, |
| | | toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, thinkingTraceEmitter)); |
| | | wrappedCallbacks.add(new ObservableToolCallback(callback, requestId, sessionId, toolCallSequence, |
| | | toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, traceEmitter)); |
| | | } |
| | | return wrappedCallbacks.toArray(new ToolCallback[0]); |
| | | } |
| | |
| | | private class ObservableToolCallback implements ToolCallback { |
| | | |
| | | private final ToolCallback delegate; |
| | | private final SseEmitter emitter; |
| | | private final String requestId; |
| | | private final Long sessionId; |
| | | private final AtomicLong toolCallSequence; |
| | |
| | | private final Long callLogId; |
| | | private final Long userId; |
| | | private final Long tenantId; |
| | | private final AiThinkingTraceEmitter thinkingTraceEmitter; |
| | | private final AiChatTraceEmitter traceEmitter; |
| | | |
| | | private ObservableToolCallback(ToolCallback delegate, SseEmitter emitter, String requestId, |
| | | private ObservableToolCallback(ToolCallback delegate, String requestId, |
| | | Long sessionId, AtomicLong toolCallSequence, |
| | | AtomicLong toolSuccessCount, AtomicLong toolFailureCount, |
| | | Long callLogId, Long userId, Long tenantId, |
| | | AiThinkingTraceEmitter thinkingTraceEmitter) { |
| | | AiChatTraceEmitter traceEmitter) { |
| | | this.delegate = delegate; |
| | | this.emitter = emitter; |
| | | this.requestId = requestId; |
| | | this.sessionId = sessionId; |
| | | this.toolCallSequence = toolCallSequence; |
| | |
| | | this.callLogId = callLogId; |
| | | this.userId = userId; |
| | | this.tenantId = tenantId; |
| | | this.thinkingTraceEmitter = thinkingTraceEmitter; |
| | | this.traceEmitter = traceEmitter; |
| | | } |
| | | |
| | | @Override |
| | |
| | | String mountName = delegate instanceof MountedToolCallback ? ((MountedToolCallback) delegate).getMountName() : null; |
| | | String toolCallId = requestId + "-tool-" + toolCallSequence.incrementAndGet(); |
| | | long startedAt = System.currentTimeMillis(); |
| | | String inputSummary = summarizeToolPayload(toolInput, 400); |
| | | AiCachedToolResult cachedToolResult = aiToolResultStore.getToolResult(tenantId, requestId, toolName, toolInput); |
| | | if (cachedToolResult != null) { |
| | | aiSseEventPublisher.emitSafely(emitter, "tool_result", AiChatToolEventDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .toolCallId(toolCallId) |
| | | .toolName(toolName) |
| | | .mountName(mountName) |
| | | .status(cachedToolResult.isSuccess() ? "COMPLETED" : "FAILED") |
| | | .inputSummary(summarizeToolPayload(toolInput, 400)) |
| | | .outputSummary(summarizeToolPayload(cachedToolResult.getOutput(), 600)) |
| | | .errorMessage(cachedToolResult.getErrorMessage()) |
| | | .durationMs(0L) |
| | | .timestamp(System.currentTimeMillis()) |
| | | .build()); |
| | | if (thinkingTraceEmitter != null) { |
| | | thinkingTraceEmitter.onToolResult(toolName, toolCallId, !cachedToolResult.isSuccess()); |
| | | String outputSummary = summarizeToolPayload(cachedToolResult.getOutput(), 600); |
| | | String errorMessage = cachedToolResult.getErrorMessage(); |
| | | if (traceEmitter != null) { |
| | | traceEmitter.onToolResult(toolName, mountName, toolCallId, inputSummary, outputSummary, |
| | | errorMessage, 0L, System.currentTimeMillis(), !cachedToolResult.isSuccess()); |
| | | } |
| | | if (cachedToolResult.isSuccess()) { |
| | | toolSuccessCount.incrementAndGet(); |
| | | aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, |
| | | "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(cachedToolResult.getOutput(), 600), |
| | | "COMPLETED", inputSummary, outputSummary, |
| | | null, 0L, userId, tenantId); |
| | | return cachedToolResult.getOutput(); |
| | | } |
| | | toolFailureCount.incrementAndGet(); |
| | | aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, |
| | | "FAILED", summarizeToolPayload(toolInput, 400), null, cachedToolResult.getErrorMessage(), |
| | | "FAILED", inputSummary, null, errorMessage, |
| | | 0L, userId, tenantId); |
| | | throw new CoolException(cachedToolResult.getErrorMessage()); |
| | | throw new CoolException(errorMessage); |
| | | } |
| | | if (thinkingTraceEmitter != null) { |
| | | thinkingTraceEmitter.onToolStart(toolName, toolCallId); |
| | | if (traceEmitter != null) { |
| | | traceEmitter.onToolStart(toolName, mountName, toolCallId, inputSummary, startedAt); |
| | | } |
| | | aiSseEventPublisher.emitSafely(emitter, "tool_start", AiChatToolEventDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .toolCallId(toolCallId) |
| | | .toolName(toolName) |
| | | .mountName(mountName) |
| | | .status("STARTED") |
| | | .inputSummary(summarizeToolPayload(toolInput, 400)) |
| | | .timestamp(startedAt) |
| | | .build()); |
| | | try { |
| | | String output = toolContext == null ? delegate.call(toolInput) : delegate.call(toolInput, toolContext); |
| | | long durationMs = System.currentTimeMillis() - startedAt; |
| | | aiSseEventPublisher.emitSafely(emitter, "tool_result", AiChatToolEventDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .toolCallId(toolCallId) |
| | | .toolName(toolName) |
| | | .mountName(mountName) |
| | | .status("COMPLETED") |
| | | .inputSummary(summarizeToolPayload(toolInput, 400)) |
| | | .outputSummary(summarizeToolPayload(output, 600)) |
| | | .durationMs(durationMs) |
| | | .timestamp(System.currentTimeMillis()) |
| | | .build()); |
| | | if (thinkingTraceEmitter != null) { |
| | | thinkingTraceEmitter.onToolResult(toolName, toolCallId, false); |
| | | String outputSummary = summarizeToolPayload(output, 600); |
| | | if (traceEmitter != null) { |
| | | traceEmitter.onToolResult(toolName, mountName, toolCallId, inputSummary, outputSummary, |
| | | null, durationMs, System.currentTimeMillis(), false); |
| | | } |
| | | aiToolResultStore.cacheToolResult(tenantId, requestId, toolName, toolInput, true, output, null); |
| | | toolSuccessCount.incrementAndGet(); |
| | | aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, |
| | | "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(output, 600), |
| | | "COMPLETED", inputSummary, outputSummary, |
| | | null, durationMs, userId, tenantId); |
| | | return output; |
| | | } catch (RuntimeException e) { |
| | | long durationMs = System.currentTimeMillis() - startedAt; |
| | | aiSseEventPublisher.emitSafely(emitter, "tool_error", AiChatToolEventDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .toolCallId(toolCallId) |
| | | .toolName(toolName) |
| | | .mountName(mountName) |
| | | .status("FAILED") |
| | | .inputSummary(summarizeToolPayload(toolInput, 400)) |
| | | .errorMessage(e.getMessage()) |
| | | .durationMs(durationMs) |
| | | .timestamp(System.currentTimeMillis()) |
| | | .build()); |
| | | if (thinkingTraceEmitter != null) { |
| | | thinkingTraceEmitter.onToolResult(toolName, toolCallId, true); |
| | | String errorMessage = e.getMessage(); |
| | | if (traceEmitter != null) { |
| | | traceEmitter.onToolResult(toolName, mountName, toolCallId, inputSummary, null, |
| | | errorMessage, durationMs, System.currentTimeMillis(), true); |
| | | } |
| | | aiToolResultStore.cacheToolResult(tenantId, requestId, toolName, toolInput, false, null, e.getMessage()); |
| | | aiToolResultStore.cacheToolResult(tenantId, requestId, toolName, toolInput, false, null, errorMessage); |
| | | toolFailureCount.incrementAndGet(); |
| | | aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, |
| | | "FAILED", summarizeToolPayload(toolInput, 400), null, e.getMessage(), |
| | | "FAILED", inputSummary, null, errorMessage, |
| | | durationMs, userId, tenantId); |
| | | throw e; |
| | | } |