zhou zhou
16 小时以前 b05f094ac51dce91eb8c00235226d54a04658c6d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
package com.vincent.rsf.server.ai.service.impl.chat;
 
import com.vincent.rsf.server.ai.dto.AiChatTraceEventDto;
import org.springframework.util.StringUtils;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
 
import java.time.Instant;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
 
public class AiChatTraceEmitter {
 
    private static final String TRACE_EVENT_NAME = "trace";
 
    private final AiSseEventPublisher aiSseEventPublisher;
    private final SseEmitter emitter;
    private final String requestId;
    private final Long sessionId;
    private final AtomicLong traceSequence;
    private final Map<String, Long> traceOrderMap = new ConcurrentHashMap<>();
    private String currentPhase;
    private String currentStatus;
 
    public AiChatTraceEmitter(AiSseEventPublisher aiSseEventPublisher, SseEmitter emitter, String requestId,
                              Long sessionId, AtomicLong traceSequence) {
        this.aiSseEventPublisher = aiSseEventPublisher;
        this.emitter = emitter;
        this.requestId = requestId;
        this.sessionId = sessionId;
        this.traceSequence = traceSequence;
    }
 
    public void startAnalyze() {
        if (currentPhase != null) {
            return;
        }
        currentPhase = "ANALYZE";
        currentStatus = "STARTED";
        emitThinkingEvent("ANALYZE", "STARTED", "正在分析问题",
                "已接收你的问题,正在理解意图并判断是否需要调用工具。", null);
    }
 
    public void onToolStart(String toolName, String mountName, String toolCallId, String inputSummary, long timestamp) {
        completeCurrentPhase();
        emitToolEvent("STARTED", "开始调用工具", null, toolCallId, toolName, mountName, inputSummary,
                null, null, null, timestamp);
    }
 
    public void onToolResult(String toolName, String mountName, String toolCallId, String inputSummary,
                             String outputSummary, String errorMessage, Long durationMs, long timestamp,
                             boolean failed) {
        emitToolEvent(failed ? "FAILED" : "COMPLETED",
                failed ? "工具调用失败" : "工具调用完成",
                null,
                toolCallId,
                toolName,
                mountName,
                inputSummary,
                outputSummary,
                errorMessage,
                durationMs,
                timestamp);
    }
 
    public void startAnswer() {
        switchPhase("ANSWER", "STARTED", "正在整理答案", "已完成分析,正在组织最终回复内容。", null);
    }
 
    public void completeCurrentPhase() {
        if (!StringUtils.hasText(currentPhase) || isTerminalStatus(currentStatus)) {
            return;
        }
        currentStatus = "COMPLETED";
        emitThinkingEvent(currentPhase, "COMPLETED", resolveCompleteTitle(currentPhase),
                resolveCompleteContent(currentPhase), null);
    }
 
    public void markTerminated(String terminalStatus) {
        if (!StringUtils.hasText(currentPhase) || isTerminalStatus(currentStatus)) {
            return;
        }
        currentStatus = terminalStatus;
        emitThinkingEvent(currentPhase, terminalStatus,
                "ABORTED".equals(terminalStatus) ? "思考已中止" : "思考失败",
                "ABORTED".equals(terminalStatus)
                        ? "本轮对话已被中止,思考过程提前结束。"
                        : "本轮对话在生成答案前失败,当前思考过程已停止。",
                null);
    }
 
    private void switchPhase(String nextPhase, String nextStatus, String title, String content, String toolCallId) {
        if (!Objects.equals(currentPhase, nextPhase)) {
            completeCurrentPhase();
        }
        currentPhase = nextPhase;
        currentStatus = nextStatus;
        emitThinkingEvent(nextPhase, nextStatus, title, content, toolCallId);
    }
 
    private void emitThinkingEvent(String phase, String status, String title, String content, String toolCallId) {
        emitTraceEvent(AiChatTraceEventDto.builder()
                .requestId(requestId)
                .sessionId(sessionId)
                .traceType("thinking")
                .phase(phase)
                .status(status)
                .title(title)
                .content(content)
                .toolCallId(toolCallId)
                .timestamp(Instant.now().toEpochMilli())
                .build(), buildThinkingTraceId(phase));
    }
 
    private void emitToolEvent(String status, String title, String content, String toolCallId, String toolName,
                               String mountName, String inputSummary, String outputSummary, String errorMessage,
                               Long durationMs, long timestamp) {
        emitTraceEvent(AiChatTraceEventDto.builder()
                .requestId(requestId)
                .sessionId(sessionId)
                .traceType("tool")
                .status(status)
                .title(title)
                .content(content)
                .toolCallId(toolCallId)
                .toolName(toolName)
                .mountName(mountName)
                .inputSummary(inputSummary)
                .outputSummary(outputSummary)
                .errorMessage(errorMessage)
                .durationMs(durationMs)
                .timestamp(timestamp)
                .build(), buildToolTraceId(toolCallId));
    }
 
    private void emitTraceEvent(AiChatTraceEventDto payload, String traceId) {
        long sequence = traceOrderMap.computeIfAbsent(traceId, ignored -> traceSequence.incrementAndGet());
        payload.setSequence(sequence);
        payload.setTraceId(traceId);
        aiSseEventPublisher.emitSafely(emitter, TRACE_EVENT_NAME, payload);
    }
 
    private String buildThinkingTraceId(String phase) {
        return requestId + "-thinking-" + phase;
    }
 
    private String buildToolTraceId(String toolCallId) {
        return toolCallId;
    }
 
    private boolean isTerminalStatus(String status) {
        return "COMPLETED".equals(status) || "FAILED".equals(status) || "ABORTED".equals(status);
    }
 
    private String resolveCompleteTitle(String phase) {
        if ("ANSWER".equals(phase)) {
            return "答案整理完成";
        }
        if ("TOOL_CALL".equals(phase)) {
            return "工具分析完成";
        }
        return "问题分析完成";
    }
 
    private String resolveCompleteContent(String phase) {
        if ("ANSWER".equals(phase)) {
            return "最终答复已生成完成。";
        }
        if ("TOOL_CALL".equals(phase)) {
            return "工具调用阶段已结束,相关信息已整理完毕。";
        }
        return "问题意图和处理方向已分析完成。";
    }
 
    private String safeLabel(String value, String fallback) {
        return StringUtils.hasText(value) ? value : fallback;
    }
}