| | |
| | | import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest; |
| | | import com.vincent.rsf.server.ai.dto.AiChatToolEventDto; |
| | | import com.vincent.rsf.server.ai.dto.AiResolvedConfig; |
| | | import com.vincent.rsf.server.ai.entity.AiCallLog; |
| | | import com.vincent.rsf.server.ai.entity.AiParam; |
| | | import com.vincent.rsf.server.ai.entity.AiPrompt; |
| | | import com.vincent.rsf.server.ai.entity.AiChatSession; |
| | | import com.vincent.rsf.server.ai.enums.AiErrorCategory; |
| | | import com.vincent.rsf.server.ai.exception.AiChatException; |
| | | import com.vincent.rsf.server.ai.service.AiCallLogService; |
| | | import com.vincent.rsf.server.ai.service.AiChatService; |
| | | import com.vincent.rsf.server.ai.service.AiChatMemoryService; |
| | | import com.vincent.rsf.server.ai.service.AiConfigResolverService; |
| | | import com.vincent.rsf.server.ai.service.MountedToolCallback; |
| | | import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory; |
| | | import io.micrometer.observation.ObservationRegistry; |
| | | import lombok.RequiredArgsConstructor; |
| | |
| | | private final AiConfigResolverService aiConfigResolverService; |
| | | private final AiChatMemoryService aiChatMemoryService; |
| | | private final McpMountRuntimeFactory mcpMountRuntimeFactory; |
| | | private final AiCallLogService aiCallLogService; |
| | | private final GenericApplicationContext applicationContext; |
| | | private final ObservationRegistry observationRegistry; |
| | | private final ObjectMapper objectMapper; |
| | |
| | | long startedAt = System.currentTimeMillis(); |
| | | AtomicReference<Long> firstTokenAtRef = new AtomicReference<>(); |
| | | AtomicLong toolCallSequence = new AtomicLong(0); |
| | | AtomicLong toolSuccessCount = new AtomicLong(0); |
| | | AtomicLong toolFailureCount = new AtomicLong(0); |
| | | Long sessionId = request.getSessionId(); |
| | | Long callLogId = null; |
| | | String model = null; |
| | | try { |
| | | ensureIdentity(userId, tenantId); |
| | |
| | | sessionId = session.getId(); |
| | | AiChatMemoryDto memory = loadMemory(userId, tenantId, config.getPromptCode(), session.getId()); |
| | | List<AiChatMessageDto> mergedMessages = mergeMessages(memory.getShortMemoryMessages(), request.getMessages()); |
| | | AiCallLog callLog = aiCallLogService.startCallLog( |
| | | requestId, |
| | | session.getId(), |
| | | userId, |
| | | tenantId, |
| | | config.getPromptCode(), |
| | | config.getPrompt().getName(), |
| | | config.getAiParam().getModel(), |
| | | config.getMcpMounts().size(), |
| | | config.getMcpMounts().size(), |
| | | config.getMcpMounts().stream().map(item -> item.getName()).toList() |
| | | ); |
| | | callLogId = callLog.getId(); |
| | | try (McpMountRuntimeFactory.McpMountRuntime runtime = createRuntime(config, userId)) { |
| | | emitStrict(emitter, "start", AiChatRuntimeDto.builder() |
| | | .requestId(requestId) |
| | |
| | | requestId, userId, tenantId, session.getId(), resolvedModel); |
| | | |
| | | ToolCallback[] observableToolCallbacks = wrapToolCallbacks( |
| | | runtime.getToolCallbacks(), emitter, requestId, session.getId(), toolCallSequence |
| | | runtime.getToolCallbacks(), emitter, requestId, session.getId(), toolCallSequence, |
| | | toolSuccessCount, toolFailureCount, callLogId, userId, tenantId |
| | | ); |
| | | Prompt prompt = new Prompt( |
| | | buildPromptMessages(memory, mergedMessages, config.getPrompt(), request.getMetadata()), |
| | |
| | | } |
| | | emitDone(emitter, requestId, response.getMetadata(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get()); |
| | | emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get())); |
| | | aiCallLogService.completeCallLog( |
| | | callLogId, |
| | | "COMPLETED", |
| | | System.currentTimeMillis() - startedAt, |
| | | resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()), |
| | | response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getPromptTokens(), |
| | | response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getCompletionTokens(), |
| | | response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getTotalTokens(), |
| | | toolSuccessCount.get(), |
| | | toolFailureCount.get() |
| | | ); |
| | | log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}", |
| | | requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get())); |
| | | emitter.complete(); |
| | |
| | | aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), assistantContent.toString()); |
| | | emitDone(emitter, requestId, lastMetadata.get(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get()); |
| | | emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get())); |
| | | aiCallLogService.completeCallLog( |
| | | callLogId, |
| | | "COMPLETED", |
| | | System.currentTimeMillis() - startedAt, |
| | | resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()), |
| | | lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getPromptTokens(), |
| | | lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getCompletionTokens(), |
| | | lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getTotalTokens(), |
| | | toolSuccessCount.get(), |
| | | toolFailureCount.get() |
| | | ); |
| | | log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}", |
| | | requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get())); |
| | | emitter.complete(); |
| | | } |
| | | } catch (AiChatException e) { |
| | | handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e); |
| | | handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e, |
| | | callLogId, toolSuccessCount.get(), toolFailureCount.get()); |
| | | } catch (Exception e) { |
| | | handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), |
| | | buildAiException("AI_INTERNAL_ERROR", AiErrorCategory.INTERNAL, "INTERNAL", |
| | | e == null ? "AI 对话失败" : e.getMessage(), e)); |
| | | e == null ? "AI 对话失败" : e.getMessage(), e), |
| | | callLogId, toolSuccessCount.get(), toolFailureCount.get()); |
| | | } finally { |
| | | log.debug("AI chat stream finished, requestId={}", requestId); |
| | | } |
| | |
| | | return firstTokenAt == null ? null : Math.max(0L, firstTokenAt - startedAt); |
| | | } |
| | | |
| | | private void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt, Long firstTokenAt, AiChatException exception) { |
| | | private void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt, |
| | | Long firstTokenAt, AiChatException exception, Long callLogId, |
| | | long toolSuccessCount, long toolFailureCount) { |
| | | if (isClientAbortException(exception)) { |
| | | log.warn("AI chat aborted by client, requestId={}, sessionId={}, stage={}, message={}", |
| | | requestId, sessionId, exception.getStage(), exception.getMessage()); |
| | | emitSafely(emitter, "status", buildTerminalStatus(requestId, sessionId, "ABORTED", model, startedAt, firstTokenAt)); |
| | | aiCallLogService.failCallLog( |
| | | callLogId, |
| | | "ABORTED", |
| | | exception.getCategory().name(), |
| | | exception.getStage(), |
| | | exception.getMessage(), |
| | | System.currentTimeMillis() - startedAt, |
| | | resolveFirstTokenLatency(startedAt, firstTokenAt), |
| | | toolSuccessCount, |
| | | toolFailureCount |
| | | ); |
| | | emitter.completeWithError(exception); |
| | | return; |
| | | } |
| | |
| | | .message(exception.getMessage()) |
| | | .timestamp(Instant.now().toEpochMilli()) |
| | | .build()); |
| | | aiCallLogService.failCallLog( |
| | | callLogId, |
| | | "FAILED", |
| | | exception.getCategory().name(), |
| | | exception.getStage(), |
| | | exception.getMessage(), |
| | | System.currentTimeMillis() - startedAt, |
| | | resolveFirstTokenLatency(startedAt, firstTokenAt), |
| | | toolSuccessCount, |
| | | toolFailureCount |
| | | ); |
| | | emitter.completeWithError(exception); |
| | | } |
| | | |
| | |
| | | } |
| | | |
| | | private ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, SseEmitter emitter, String requestId, |
| | | Long sessionId, AtomicLong toolCallSequence) { |
| | | Long sessionId, AtomicLong toolCallSequence, |
| | | AtomicLong toolSuccessCount, AtomicLong toolFailureCount, |
| | | Long callLogId, Long userId, Long tenantId) { |
| | | if (Cools.isEmpty(toolCallbacks)) { |
| | | return toolCallbacks; |
| | | } |
| | |
| | | if (callback == null) { |
| | | continue; |
| | | } |
| | | wrappedCallbacks.add(new ObservableToolCallback(callback, emitter, requestId, sessionId, toolCallSequence)); |
| | | wrappedCallbacks.add(new ObservableToolCallback(callback, emitter, requestId, sessionId, toolCallSequence, |
| | | toolSuccessCount, toolFailureCount, callLogId, userId, tenantId)); |
| | | } |
| | | return wrappedCallbacks.toArray(new ToolCallback[0]); |
| | | } |
| | |
| | | private final String requestId; |
| | | private final Long sessionId; |
| | | private final AtomicLong toolCallSequence; |
| | | private final AtomicLong toolSuccessCount; |
| | | private final AtomicLong toolFailureCount; |
| | | private final Long callLogId; |
| | | private final Long userId; |
| | | private final Long tenantId; |
| | | |
| | | private ObservableToolCallback(ToolCallback delegate, SseEmitter emitter, String requestId, |
| | | Long sessionId, AtomicLong toolCallSequence) { |
| | | Long sessionId, AtomicLong toolCallSequence, |
| | | AtomicLong toolSuccessCount, AtomicLong toolFailureCount, |
| | | Long callLogId, Long userId, Long tenantId) { |
| | | this.delegate = delegate; |
| | | this.emitter = emitter; |
| | | this.requestId = requestId; |
| | | this.sessionId = sessionId; |
| | | this.toolCallSequence = toolCallSequence; |
| | | this.toolSuccessCount = toolSuccessCount; |
| | | this.toolFailureCount = toolFailureCount; |
| | | this.callLogId = callLogId; |
| | | this.userId = userId; |
| | | this.tenantId = tenantId; |
| | | } |
| | | |
| | | @Override |
| | |
| | | @Override |
| | | public String call(String toolInput, ToolContext toolContext) { |
| | | String toolName = delegate.getToolDefinition() == null ? "unknown" : delegate.getToolDefinition().name(); |
| | | String mountName = delegate instanceof MountedToolCallback ? ((MountedToolCallback) delegate).getMountName() : null; |
| | | String toolCallId = requestId + "-tool-" + toolCallSequence.incrementAndGet(); |
| | | long startedAt = System.currentTimeMillis(); |
| | | emitSafely(emitter, "tool_start", AiChatToolEventDto.builder() |
| | |
| | | .sessionId(sessionId) |
| | | .toolCallId(toolCallId) |
| | | .toolName(toolName) |
| | | .mountName(mountName) |
| | | .status("STARTED") |
| | | .inputSummary(summarizeToolPayload(toolInput, 400)) |
| | | .timestamp(startedAt) |
| | | .build()); |
| | | try { |
| | | String output = toolContext == null ? delegate.call(toolInput) : delegate.call(toolInput, toolContext); |
| | | long durationMs = System.currentTimeMillis() - startedAt; |
| | | emitSafely(emitter, "tool_result", AiChatToolEventDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .toolCallId(toolCallId) |
| | | .toolName(toolName) |
| | | .mountName(mountName) |
| | | .status("COMPLETED") |
| | | .inputSummary(summarizeToolPayload(toolInput, 400)) |
| | | .outputSummary(summarizeToolPayload(output, 600)) |
| | | .durationMs(System.currentTimeMillis() - startedAt) |
| | | .durationMs(durationMs) |
| | | .timestamp(System.currentTimeMillis()) |
| | | .build()); |
| | | toolSuccessCount.incrementAndGet(); |
| | | aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, |
| | | "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(output, 600), |
| | | null, durationMs, userId, tenantId); |
| | | return output; |
| | | } catch (RuntimeException e) { |
| | | long durationMs = System.currentTimeMillis() - startedAt; |
| | | emitSafely(emitter, "tool_error", AiChatToolEventDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .toolCallId(toolCallId) |
| | | .toolName(toolName) |
| | | .mountName(mountName) |
| | | .status("FAILED") |
| | | .inputSummary(summarizeToolPayload(toolInput, 400)) |
| | | .errorMessage(e.getMessage()) |
| | | .durationMs(System.currentTimeMillis() - startedAt) |
| | | .durationMs(durationMs) |
| | | .timestamp(System.currentTimeMillis()) |
| | | .build()); |
| | | toolFailureCount.incrementAndGet(); |
| | | aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, |
| | | "FAILED", summarizeToolPayload(toolInput, 400), null, e.getMessage(), |
| | | durationMs, userId, tenantId); |
| | | throw e; |
| | | } |
| | | } |