package com.vincent.rsf.server.ai.service.impl.chat;
|
|
import com.vincent.rsf.framework.exception.CoolException;
|
import com.vincent.rsf.server.ai.dto.AiChatToolEventDto;
|
import com.vincent.rsf.server.ai.service.AiCallLogService;
|
import com.vincent.rsf.server.ai.service.MountedToolCallback;
|
import com.vincent.rsf.server.ai.store.AiCachedToolResult;
|
import com.vincent.rsf.server.ai.store.AiToolResultStore;
|
import lombok.RequiredArgsConstructor;
|
import org.springframework.ai.chat.model.ToolContext;
|
import org.springframework.ai.tool.ToolCallback;
|
import org.springframework.stereotype.Component;
|
import org.springframework.util.StringUtils;
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
|
import java.util.ArrayList;
|
import java.util.List;
|
import java.util.concurrent.atomic.AtomicLong;
|
|
@Component
|
@RequiredArgsConstructor
|
public class AiToolObservationService {
|
|
private final AiSseEventPublisher aiSseEventPublisher;
|
private final AiToolResultStore aiToolResultStore;
|
private final AiCallLogService aiCallLogService;
|
|
public ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, SseEmitter emitter, String requestId,
|
Long sessionId, AtomicLong toolCallSequence,
|
AtomicLong toolSuccessCount, AtomicLong toolFailureCount,
|
Long callLogId, Long userId, Long tenantId,
|
AiThinkingTraceEmitter thinkingTraceEmitter) {
|
if (toolCallbacks == null || toolCallbacks.length == 0) {
|
return toolCallbacks;
|
}
|
List<ToolCallback> wrappedCallbacks = new ArrayList<>();
|
for (ToolCallback callback : toolCallbacks) {
|
if (callback == null) {
|
continue;
|
}
|
wrappedCallbacks.add(new ObservableToolCallback(callback, emitter, requestId, sessionId, toolCallSequence,
|
toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, thinkingTraceEmitter));
|
}
|
return wrappedCallbacks.toArray(new ToolCallback[0]);
|
}
|
|
private String summarizeToolPayload(String content, int maxLength) {
|
if (!StringUtils.hasText(content)) {
|
return null;
|
}
|
String normalized = content.trim()
|
.replace("\r", " ")
|
.replace("\n", " ")
|
.replaceAll("\\s+", " ");
|
return normalized.length() > maxLength ? normalized.substring(0, maxLength) : normalized;
|
}
|
|
private class ObservableToolCallback implements ToolCallback {
|
|
private final ToolCallback delegate;
|
private final SseEmitter emitter;
|
private final String requestId;
|
private final Long sessionId;
|
private final AtomicLong toolCallSequence;
|
private final AtomicLong toolSuccessCount;
|
private final AtomicLong toolFailureCount;
|
private final Long callLogId;
|
private final Long userId;
|
private final Long tenantId;
|
private final AiThinkingTraceEmitter thinkingTraceEmitter;
|
|
private ObservableToolCallback(ToolCallback delegate, SseEmitter emitter, String requestId,
|
Long sessionId, AtomicLong toolCallSequence,
|
AtomicLong toolSuccessCount, AtomicLong toolFailureCount,
|
Long callLogId, Long userId, Long tenantId,
|
AiThinkingTraceEmitter thinkingTraceEmitter) {
|
this.delegate = delegate;
|
this.emitter = emitter;
|
this.requestId = requestId;
|
this.sessionId = sessionId;
|
this.toolCallSequence = toolCallSequence;
|
this.toolSuccessCount = toolSuccessCount;
|
this.toolFailureCount = toolFailureCount;
|
this.callLogId = callLogId;
|
this.userId = userId;
|
this.tenantId = tenantId;
|
this.thinkingTraceEmitter = thinkingTraceEmitter;
|
}
|
|
@Override
|
public org.springframework.ai.tool.definition.ToolDefinition getToolDefinition() {
|
return delegate.getToolDefinition();
|
}
|
|
@Override
|
public org.springframework.ai.tool.metadata.ToolMetadata getToolMetadata() {
|
return delegate.getToolMetadata();
|
}
|
|
@Override
|
public String call(String toolInput) {
|
return call(toolInput, null);
|
}
|
|
@Override
|
public String call(String toolInput, ToolContext toolContext) {
|
String toolName = delegate.getToolDefinition() == null ? "unknown" : delegate.getToolDefinition().name();
|
String mountName = delegate instanceof MountedToolCallback ? ((MountedToolCallback) delegate).getMountName() : null;
|
String toolCallId = requestId + "-tool-" + toolCallSequence.incrementAndGet();
|
long startedAt = System.currentTimeMillis();
|
AiCachedToolResult cachedToolResult = aiToolResultStore.getToolResult(tenantId, requestId, toolName, toolInput);
|
if (cachedToolResult != null) {
|
aiSseEventPublisher.emitSafely(emitter, "tool_result", AiChatToolEventDto.builder()
|
.requestId(requestId)
|
.sessionId(sessionId)
|
.toolCallId(toolCallId)
|
.toolName(toolName)
|
.mountName(mountName)
|
.status(cachedToolResult.isSuccess() ? "COMPLETED" : "FAILED")
|
.inputSummary(summarizeToolPayload(toolInput, 400))
|
.outputSummary(summarizeToolPayload(cachedToolResult.getOutput(), 600))
|
.errorMessage(cachedToolResult.getErrorMessage())
|
.durationMs(0L)
|
.timestamp(System.currentTimeMillis())
|
.build());
|
if (thinkingTraceEmitter != null) {
|
thinkingTraceEmitter.onToolResult(toolName, toolCallId, !cachedToolResult.isSuccess());
|
}
|
if (cachedToolResult.isSuccess()) {
|
toolSuccessCount.incrementAndGet();
|
aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
|
"COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(cachedToolResult.getOutput(), 600),
|
null, 0L, userId, tenantId);
|
return cachedToolResult.getOutput();
|
}
|
toolFailureCount.incrementAndGet();
|
aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
|
"FAILED", summarizeToolPayload(toolInput, 400), null, cachedToolResult.getErrorMessage(),
|
0L, userId, tenantId);
|
throw new CoolException(cachedToolResult.getErrorMessage());
|
}
|
if (thinkingTraceEmitter != null) {
|
thinkingTraceEmitter.onToolStart(toolName, toolCallId);
|
}
|
aiSseEventPublisher.emitSafely(emitter, "tool_start", AiChatToolEventDto.builder()
|
.requestId(requestId)
|
.sessionId(sessionId)
|
.toolCallId(toolCallId)
|
.toolName(toolName)
|
.mountName(mountName)
|
.status("STARTED")
|
.inputSummary(summarizeToolPayload(toolInput, 400))
|
.timestamp(startedAt)
|
.build());
|
try {
|
String output = toolContext == null ? delegate.call(toolInput) : delegate.call(toolInput, toolContext);
|
long durationMs = System.currentTimeMillis() - startedAt;
|
aiSseEventPublisher.emitSafely(emitter, "tool_result", AiChatToolEventDto.builder()
|
.requestId(requestId)
|
.sessionId(sessionId)
|
.toolCallId(toolCallId)
|
.toolName(toolName)
|
.mountName(mountName)
|
.status("COMPLETED")
|
.inputSummary(summarizeToolPayload(toolInput, 400))
|
.outputSummary(summarizeToolPayload(output, 600))
|
.durationMs(durationMs)
|
.timestamp(System.currentTimeMillis())
|
.build());
|
if (thinkingTraceEmitter != null) {
|
thinkingTraceEmitter.onToolResult(toolName, toolCallId, false);
|
}
|
aiToolResultStore.cacheToolResult(tenantId, requestId, toolName, toolInput, true, output, null);
|
toolSuccessCount.incrementAndGet();
|
aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
|
"COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(output, 600),
|
null, durationMs, userId, tenantId);
|
return output;
|
} catch (RuntimeException e) {
|
long durationMs = System.currentTimeMillis() - startedAt;
|
aiSseEventPublisher.emitSafely(emitter, "tool_error", AiChatToolEventDto.builder()
|
.requestId(requestId)
|
.sessionId(sessionId)
|
.toolCallId(toolCallId)
|
.toolName(toolName)
|
.mountName(mountName)
|
.status("FAILED")
|
.inputSummary(summarizeToolPayload(toolInput, 400))
|
.errorMessage(e.getMessage())
|
.durationMs(durationMs)
|
.timestamp(System.currentTimeMillis())
|
.build());
|
if (thinkingTraceEmitter != null) {
|
thinkingTraceEmitter.onToolResult(toolName, toolCallId, true);
|
}
|
aiToolResultStore.cacheToolResult(tenantId, requestId, toolName, toolInput, false, null, e.getMessage());
|
toolFailureCount.incrementAndGet();
|
aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
|
"FAILED", summarizeToolPayload(toolInput, 400), null, e.getMessage(),
|
durationMs, userId, tenantId);
|
throw e;
|
}
|
}
|
}
|
}
|