package com.vincent.rsf.server.ai.service.impl.chat;
|
|
import com.vincent.rsf.server.ai.dto.AiChatTraceEventDto;
|
import org.springframework.util.StringUtils;
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
|
import java.time.Instant;
|
import java.util.Map;
|
import java.util.Objects;
|
import java.util.concurrent.ConcurrentHashMap;
|
import java.util.concurrent.atomic.AtomicLong;
|
|
public class AiChatTraceEmitter {
|
|
private static final String TRACE_EVENT_NAME = "trace";
|
|
private final AiSseEventPublisher aiSseEventPublisher;
|
private final SseEmitter emitter;
|
private final String requestId;
|
private final Long sessionId;
|
private final AtomicLong traceSequence;
|
private final Map<String, Long> traceOrderMap = new ConcurrentHashMap<>();
|
private String currentPhase;
|
private String currentStatus;
|
|
public AiChatTraceEmitter(AiSseEventPublisher aiSseEventPublisher, SseEmitter emitter, String requestId,
|
Long sessionId, AtomicLong traceSequence) {
|
this.aiSseEventPublisher = aiSseEventPublisher;
|
this.emitter = emitter;
|
this.requestId = requestId;
|
this.sessionId = sessionId;
|
this.traceSequence = traceSequence;
|
}
|
|
public void startAnalyze() {
|
if (currentPhase != null) {
|
return;
|
}
|
currentPhase = "ANALYZE";
|
currentStatus = "STARTED";
|
emitThinkingEvent("ANALYZE", "STARTED", "正在分析问题",
|
"已接收你的问题,正在理解意图并判断是否需要调用工具。", null);
|
}
|
|
public void onToolStart(String toolName, String mountName, String toolCallId, String inputSummary, long timestamp) {
|
completeCurrentPhase();
|
emitToolEvent("STARTED", "开始调用工具", null, toolCallId, toolName, mountName, inputSummary,
|
null, null, null, timestamp);
|
}
|
|
public void onToolResult(String toolName, String mountName, String toolCallId, String inputSummary,
|
String outputSummary, String errorMessage, Long durationMs, long timestamp,
|
boolean failed) {
|
emitToolEvent(failed ? "FAILED" : "COMPLETED",
|
failed ? "工具调用失败" : "工具调用完成",
|
null,
|
toolCallId,
|
toolName,
|
mountName,
|
inputSummary,
|
outputSummary,
|
errorMessage,
|
durationMs,
|
timestamp);
|
}
|
|
public void startAnswer() {
|
switchPhase("ANSWER", "STARTED", "正在整理答案", "已完成分析,正在组织最终回复内容。", null);
|
}
|
|
public void completeCurrentPhase() {
|
if (!StringUtils.hasText(currentPhase) || isTerminalStatus(currentStatus)) {
|
return;
|
}
|
currentStatus = "COMPLETED";
|
emitThinkingEvent(currentPhase, "COMPLETED", resolveCompleteTitle(currentPhase),
|
resolveCompleteContent(currentPhase), null);
|
}
|
|
public void markTerminated(String terminalStatus) {
|
if (!StringUtils.hasText(currentPhase) || isTerminalStatus(currentStatus)) {
|
return;
|
}
|
currentStatus = terminalStatus;
|
emitThinkingEvent(currentPhase, terminalStatus,
|
"ABORTED".equals(terminalStatus) ? "思考已中止" : "思考失败",
|
"ABORTED".equals(terminalStatus)
|
? "本轮对话已被中止,思考过程提前结束。"
|
: "本轮对话在生成答案前失败,当前思考过程已停止。",
|
null);
|
}
|
|
private void switchPhase(String nextPhase, String nextStatus, String title, String content, String toolCallId) {
|
if (!Objects.equals(currentPhase, nextPhase)) {
|
completeCurrentPhase();
|
}
|
currentPhase = nextPhase;
|
currentStatus = nextStatus;
|
emitThinkingEvent(nextPhase, nextStatus, title, content, toolCallId);
|
}
|
|
private void emitThinkingEvent(String phase, String status, String title, String content, String toolCallId) {
|
emitTraceEvent(AiChatTraceEventDto.builder()
|
.requestId(requestId)
|
.sessionId(sessionId)
|
.traceType("thinking")
|
.phase(phase)
|
.status(status)
|
.title(title)
|
.content(content)
|
.toolCallId(toolCallId)
|
.timestamp(Instant.now().toEpochMilli())
|
.build(), buildThinkingTraceId(phase));
|
}
|
|
private void emitToolEvent(String status, String title, String content, String toolCallId, String toolName,
|
String mountName, String inputSummary, String outputSummary, String errorMessage,
|
Long durationMs, long timestamp) {
|
emitTraceEvent(AiChatTraceEventDto.builder()
|
.requestId(requestId)
|
.sessionId(sessionId)
|
.traceType("tool")
|
.status(status)
|
.title(title)
|
.content(content)
|
.toolCallId(toolCallId)
|
.toolName(toolName)
|
.mountName(mountName)
|
.inputSummary(inputSummary)
|
.outputSummary(outputSummary)
|
.errorMessage(errorMessage)
|
.durationMs(durationMs)
|
.timestamp(timestamp)
|
.build(), buildToolTraceId(toolCallId));
|
}
|
|
private void emitTraceEvent(AiChatTraceEventDto payload, String traceId) {
|
long sequence = traceOrderMap.computeIfAbsent(traceId, ignored -> traceSequence.incrementAndGet());
|
payload.setSequence(sequence);
|
payload.setTraceId(traceId);
|
aiSseEventPublisher.emitSafely(emitter, TRACE_EVENT_NAME, payload);
|
}
|
|
private String buildThinkingTraceId(String phase) {
|
return requestId + "-thinking-" + phase;
|
}
|
|
private String buildToolTraceId(String toolCallId) {
|
return toolCallId;
|
}
|
|
private boolean isTerminalStatus(String status) {
|
return "COMPLETED".equals(status) || "FAILED".equals(status) || "ABORTED".equals(status);
|
}
|
|
private String resolveCompleteTitle(String phase) {
|
if ("ANSWER".equals(phase)) {
|
return "答案整理完成";
|
}
|
if ("TOOL_CALL".equals(phase)) {
|
return "工具分析完成";
|
}
|
return "问题分析完成";
|
}
|
|
private String resolveCompleteContent(String phase) {
|
if ("ANSWER".equals(phase)) {
|
return "最终答复已生成完成。";
|
}
|
if ("TOOL_CALL".equals(phase)) {
|
return "工具调用阶段已结束,相关信息已整理完毕。";
|
}
|
return "问题意图和处理方向已分析完成。";
|
}
|
|
private String safeLabel(String value, String fallback) {
|
return StringUtils.hasText(value) ? value : fallback;
|
}
|
}
|