1
22 小时以前 b2deb1cc93b3d2c3fb9dc795e3589e1c62329a8f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
package com.vincent.rsf.server.ai.service.impl.chat;
 
import com.fasterxml.jackson.databind.ObjectMapper;
import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.dto.AiChatDoneDto;
import com.vincent.rsf.server.ai.dto.AiChatStatusDto;
import com.vincent.rsf.server.ai.enums.AiErrorCategory;
import com.vincent.rsf.server.ai.exception.AiChatException;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
 
import java.io.IOException;
import java.time.Instant;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
 
@Slf4j
@Component
@RequiredArgsConstructor
public class AiSseEventPublisher {
 
    private final ObjectMapper objectMapper;
 
    public void markFirstToken(AtomicReference<Long> firstTokenAtRef, SseEmitter emitter, String requestId,
                               Long sessionId, String model, long startedAt, AiThinkingTraceEmitter thinkingTraceEmitter) {
        if (!firstTokenAtRef.compareAndSet(null, System.currentTimeMillis())) {
            return;
        }
        if (thinkingTraceEmitter != null) {
            thinkingTraceEmitter.startAnswer();
        }
        emitSafely(emitter, "status", AiChatStatusDto.builder()
                .requestId(requestId)
                .sessionId(sessionId)
                .status("FIRST_TOKEN")
                .model(model)
                .timestamp(Instant.now().toEpochMilli())
                .elapsedMs(System.currentTimeMillis() - startedAt)
                .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()))
                .build());
    }
 
    public void emitDone(SseEmitter emitter, String requestId, ChatResponseMetadata metadata, String fallbackModel,
                         Long sessionId, long startedAt, Long firstTokenAt) {
        Usage usage = metadata == null ? null : metadata.getUsage();
        emitStrict(emitter, "done", AiChatDoneDto.builder()
                .requestId(requestId)
                .sessionId(sessionId)
                .model(metadata != null && metadata.getModel() != null && !metadata.getModel().isBlank() ? metadata.getModel() : fallbackModel)
                .elapsedMs(System.currentTimeMillis() - startedAt)
                .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAt))
                .promptTokens(usage == null ? null : usage.getPromptTokens())
                .completionTokens(usage == null ? null : usage.getCompletionTokens())
                .totalTokens(usage == null ? null : usage.getTotalTokens())
                .build());
    }
 
    public AiChatStatusDto buildTerminalStatus(String requestId, Long sessionId, String status, String model, long startedAt, Long firstTokenAt) {
        return AiChatStatusDto.builder()
                .requestId(requestId)
                .sessionId(sessionId)
                .status(status)
                .model(model)
                .timestamp(Instant.now().toEpochMilli())
                .elapsedMs(System.currentTimeMillis() - startedAt)
                .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAt))
                .build();
    }
 
    public Long resolveFirstTokenLatency(long startedAt, Long firstTokenAt) {
        return firstTokenAt == null ? null : Math.max(0L, firstTokenAt - startedAt);
    }
 
    public Map<String, String> buildMessagePayload(String... keyValues) {
        Map<String, String> payload = new LinkedHashMap<>();
        if (keyValues == null || keyValues.length == 0) {
            return payload;
        }
        if (keyValues.length % 2 != 0) {
            throw new CoolException("消息载荷参数必须成对出现");
        }
        for (int i = 0; i < keyValues.length; i += 2) {
            payload.put(keyValues[i], keyValues[i + 1] == null ? "" : keyValues[i + 1]);
        }
        return payload;
    }
 
    public void emitStrict(SseEmitter emitter, String eventName, Object payload) {
        try {
            String data = objectMapper.writeValueAsString(payload);
            emitter.send(SseEmitter.event()
                    .name(eventName)
                    .data(data, MediaType.APPLICATION_JSON));
        } catch (IOException e) {
            throw new AiChatException("AI_SSE_EMIT_ERROR", AiErrorCategory.STREAM, "SSE_EMIT", "SSE 输出失败: " + e.getMessage(), e);
        }
    }
 
    public void emitSafely(SseEmitter emitter, String eventName, Object payload) {
        try {
            emitStrict(emitter, eventName, payload);
        } catch (Exception e) {
            log.warn("AI SSE event emit skipped, eventName={}, message={}", eventName, e.getMessage());
        }
    }
}