package com.vincent.rsf.server.ai.service.impl.chat;

import com.vincent.rsf.server.ai.dto.AiChatTraceEventDto;
import org.springframework.util.StringUtils;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.time.Instant;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Publishes {@code "trace"} SSE events describing the assistant's progress for one chat request.
 *
 * <p>Two kinds of traces are emitted:
 * <ul>
 *   <li>{@code thinking} traces, one per phase ({@code ANALYZE}, {@code ANSWER}), whose status moves
 *       from {@code STARTED} to {@code COMPLETED}, {@code FAILED} or {@code ABORTED};</li>
 *   <li>{@code tool} traces, one per tool call, keyed by the tool call id.</li>
 * </ul>
 *
 * <p>Each trace id receives a sequence number (from the shared {@code traceSequence} counter) the first
 * time it is emitted; later updates to the same trace id reuse that number, so a client can keep traces
 * in first-seen order while updating them in place.
 *
 * <p>NOTE(review): {@code traceOrderMap} is a concurrent map, but {@code currentPhase} and
 * {@code currentStatus} are plain fields read and written without synchronization. If the tool
 * callbacks can be invoked from several threads at once, the phase transitions may race — confirm the
 * callers' threading model.
 */
public class AiChatTraceEmitter {

    /** SSE event name used for every event published by this class. */
    private static final String TRACE_EVENT_NAME = "trace";

    private final AiSseEventPublisher aiSseEventPublisher;
    private final SseEmitter emitter;
    private final String requestId;
    private final Long sessionId;
    /** Counter supplied by the caller; incremented once per newly seen trace id. */
    private final AtomicLong traceSequence;
    /** Trace id -> sequence number assigned when that trace id was first emitted. */
    private final Map<String, Long> traceOrderMap = new ConcurrentHashMap<>();

    /** Name of the thinking phase currently open (or last opened); {@code null} before any phase. */
    private String currentPhase;
    /** Status last emitted for {@link #currentPhase}. */
    private String currentStatus;

    /**
     * Creates an emitter bound to one SSE connection and one chat request.
     *
     * @param aiSseEventPublisher publisher used to send events to the emitter
     * @param emitter             target SSE connection
     * @param requestId           id of the chat request, copied into every payload and used to build
     *                            thinking trace ids
     * @param sessionId           chat session id copied into every payload
     * @param traceSequence       counter used to assign a sequence number to each new trace id
     */
    public AiChatTraceEmitter(AiSseEventPublisher aiSseEventPublisher,
                              SseEmitter emitter,
                              String requestId,
                              Long sessionId,
                              AtomicLong traceSequence) {
        this.aiSseEventPublisher = aiSseEventPublisher;
        this.emitter = emitter;
        this.requestId = requestId;
        this.sessionId = sessionId;
        this.traceSequence = traceSequence;
    }

    /**
     * Opens the {@code ANALYZE} thinking phase and emits its {@code STARTED} event.
     * Does nothing if any phase has already been opened.
     */
    public void startAnalyze() {
        if (currentPhase != null) {
            return;
        }
        currentPhase = "ANALYZE";
        currentStatus = "STARTED";
        emitThinkingEvent("ANALYZE", "STARTED", "正在分析问题",
                "已接收你的问题,正在理解意图并判断是否需要调用工具。", null);
    }

    /**
     * Reports that a tool call has started. Completes the current thinking phase first (if it is still
     * open), then emits a {@code STARTED} tool trace.
     *
     * @param toolName     name of the invoked tool
     * @param mountName    name of the mount the tool belongs to
     * @param toolCallId   id of the tool call; when blank a fallback trace id is derived
     * @param inputSummary summary of the tool input
     * @param timestamp    event time in epoch milliseconds
     */
    public void onToolStart(String toolName, String mountName, String toolCallId, String inputSummary,
                            long timestamp) {
        completeCurrentPhase();
        emitToolEvent("STARTED", "开始调用工具", null, toolCallId, toolName, mountName, inputSummary,
                null, null, null, timestamp);
    }

    /**
     * Reports the outcome of a tool call as a {@code COMPLETED} or {@code FAILED} tool trace.
     * Does not change the thinking-phase state.
     *
     * @param toolName      name of the invoked tool
     * @param mountName     name of the mount the tool belongs to
     * @param toolCallId    id of the tool call; when blank a fallback trace id is derived
     * @param inputSummary  summary of the tool input
     * @param outputSummary summary of the tool output
     * @param errorMessage  error message when the call failed
     * @param durationMs    call duration in milliseconds
     * @param timestamp     event time in epoch milliseconds
     * @param failed        {@code true} to emit {@code FAILED}, otherwise {@code COMPLETED}
     */
    public void onToolResult(String toolName, String mountName, String toolCallId, String inputSummary,
                             String outputSummary, String errorMessage, Long durationMs, long timestamp,
                             boolean failed) {
        emitToolEvent(failed ? "FAILED" : "COMPLETED",
                failed ? "工具调用失败" : "工具调用完成",
                null, toolCallId, toolName, mountName, inputSummary, outputSummary, errorMessage,
                durationMs, timestamp);
    }

    /**
     * Switches to the {@code ANSWER} thinking phase, completing the previous phase if it was a
     * different one, and emits a {@code STARTED} event for it.
     */
    public void startAnswer() {
        switchPhase("ANSWER", "STARTED", "正在整理答案", "已完成分析,正在组织最终回复内容。", null);
    }

    /**
     * Emits a {@code COMPLETED} event for the current phase. Does nothing when no phase has been opened
     * or the current phase already reached a terminal status.
     */
    public void completeCurrentPhase() {
        if (!StringUtils.hasText(currentPhase) || isTerminalStatus(currentStatus)) {
            return;
        }
        currentStatus = "COMPLETED";
        emitThinkingEvent(currentPhase, "COMPLETED", resolveCompleteTitle(currentPhase),
                resolveCompleteContent(currentPhase), null);
    }

    /**
     * Ends the current phase with the given status. {@code "ABORTED"} produces the "aborted" wording;
     * any other value produces the "failed" wording. Does nothing when no phase has been opened or the
     * current phase already reached a terminal status.
     *
     * @param terminalStatus status to emit; presumably {@code "ABORTED"} or {@code "FAILED"} — TODO
     *                       confirm with callers
     */
    public void markTerminated(String terminalStatus) {
        if (!StringUtils.hasText(currentPhase) || isTerminalStatus(currentStatus)) {
            return;
        }
        currentStatus = terminalStatus;
        boolean aborted = "ABORTED".equals(terminalStatus);
        emitThinkingEvent(currentPhase, terminalStatus,
                aborted ? "思考已中止" : "思考失败",
                aborted ? "本轮对话已被中止,思考过程提前结束。" : "本轮对话在生成答案前失败,当前思考过程已停止。",
                null);
    }

    /** Completes the current phase if it differs from {@code nextPhase}, then opens {@code nextPhase}. */
    private void switchPhase(String nextPhase, String nextStatus, String title, String content,
                             String toolCallId) {
        if (!Objects.equals(currentPhase, nextPhase)) {
            completeCurrentPhase();
        }
        currentPhase = nextPhase;
        currentStatus = nextStatus;
        emitThinkingEvent(nextPhase, nextStatus, title, content, toolCallId);
    }

    /** Builds and emits a {@code thinking} trace event stamped with the current time. */
    private void emitThinkingEvent(String phase, String status, String title, String content,
                                   String toolCallId) {
        emitTraceEvent(AiChatTraceEventDto.builder()
                .requestId(requestId)
                .sessionId(sessionId)
                .traceType("thinking")
                .phase(phase)
                .status(status)
                .title(title)
                .content(content)
                .toolCallId(toolCallId)
                .timestamp(Instant.now().toEpochMilli())
                .build(), buildThinkingTraceId(phase));
    }

    /** Builds and emits a {@code tool} trace event using the caller-supplied timestamp. */
    private void emitToolEvent(String status, String title, String content, String toolCallId,
                               String toolName, String mountName, String inputSummary,
                               String outputSummary, String errorMessage, Long durationMs,
                               long timestamp) {
        emitTraceEvent(AiChatTraceEventDto.builder()
                .requestId(requestId)
                .sessionId(sessionId)
                .traceType("tool")
                .status(status)
                .title(title)
                .content(content)
                .toolCallId(toolCallId)
                .toolName(toolName)
                .mountName(mountName)
                .inputSummary(inputSummary)
                .outputSummary(outputSummary)
                .errorMessage(errorMessage)
                .durationMs(durationMs)
                .timestamp(timestamp)
                .build(), buildToolTraceId(toolCallId, toolName));
    }

    /**
     * Stamps the payload with its trace id and sequence number and sends it through the publisher.
     * The sequence is allocated only the first time a trace id is seen, so updates keep their slot.
     */
    private void emitTraceEvent(AiChatTraceEventDto payload, String traceId) {
        long sequence = traceOrderMap.computeIfAbsent(traceId, ignored -> traceSequence.incrementAndGet());
        payload.setSequence(sequence);
        payload.setTraceId(traceId);
        aiSseEventPublisher.emitSafely(emitter, TRACE_EVENT_NAME, payload);
    }

    /** Returns the trace id shared by all events of one thinking phase in this request. */
    private String buildThinkingTraceId(String phase) {
        return requestId + "-thinking-" + phase;
    }

    /**
     * Returns the trace id for a tool call: the tool call id itself when it has text. Otherwise a
     * request-scoped id derived from the tool name is used, because {@link ConcurrentHashMap} rejects
     * null keys and an empty id would merge unrelated tools into one trace. Calls to the same tool that
     * all lack an id will share this fallback trace.
     */
    private String buildToolTraceId(String toolCallId, String toolName) {
        if (StringUtils.hasText(toolCallId)) {
            return toolCallId;
        }
        return requestId + "-tool-" + safeLabel(toolName, "unknown");
    }

    /** Returns {@code true} for statuses after which a phase is not updated again. */
    private boolean isTerminalStatus(String status) {
        return "COMPLETED".equals(status) || "FAILED".equals(status) || "ABORTED".equals(status);
    }

    /** Title shown when {@code phase} completes; unknown phases use the ANALYZE wording. */
    private String resolveCompleteTitle(String phase) {
        if ("ANSWER".equals(phase)) {
            return "答案整理完成";
        }
        if ("TOOL_CALL".equals(phase)) {
            return "工具分析完成";
        }
        return "问题分析完成";
    }

    /** Body text shown when {@code phase} completes; unknown phases use the ANALYZE wording. */
    private String resolveCompleteContent(String phase) {
        if ("ANSWER".equals(phase)) {
            return "最终答复已生成完成。";
        }
        if ("TOOL_CALL".equals(phase)) {
            return "工具调用阶段已结束,相关信息已整理完毕。";
        }
        return "问题意图和处理方向已分析完成。";
    }

    /** Returns {@code value} when it has text, otherwise {@code fallback}. */
    private String safeLabel(String value, String fallback) {
        return StringUtils.hasText(value) ? value : fallback;
    }
}