package com.vincent.rsf.server.ai.service.impl.chat;
|
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.vincent.rsf.framework.exception.CoolException;
|
import com.vincent.rsf.server.ai.dto.AiChatDoneDto;
|
import com.vincent.rsf.server.ai.dto.AiChatStatusDto;
|
import com.vincent.rsf.server.ai.enums.AiErrorCategory;
|
import com.vincent.rsf.server.ai.exception.AiChatException;
|
import lombok.RequiredArgsConstructor;
|
import lombok.extern.slf4j.Slf4j;
|
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
import org.springframework.ai.chat.metadata.Usage;
|
import org.springframework.http.MediaType;
|
import org.springframework.stereotype.Component;
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
|
import java.io.IOException;
|
import java.time.Instant;
|
import java.util.LinkedHashMap;
|
import java.util.Map;
|
import java.util.concurrent.atomic.AtomicReference;
|
|
@Slf4j
|
@Component
|
@RequiredArgsConstructor
|
public class AiSseEventPublisher {
|
|
private final ObjectMapper objectMapper;
|
|
public void markFirstToken(AtomicReference<Long> firstTokenAtRef, SseEmitter emitter, String requestId,
|
Long sessionId, String model, long startedAt, AiChatTraceEmitter traceEmitter) {
|
if (!firstTokenAtRef.compareAndSet(null, System.currentTimeMillis())) {
|
return;
|
}
|
if (traceEmitter != null) {
|
traceEmitter.startAnswer();
|
}
|
emitSafely(emitter, "status", AiChatStatusDto.builder()
|
.requestId(requestId)
|
.sessionId(sessionId)
|
.status("FIRST_TOKEN")
|
.model(model)
|
.timestamp(Instant.now().toEpochMilli())
|
.elapsedMs(System.currentTimeMillis() - startedAt)
|
.firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()))
|
.build());
|
}
|
|
public void emitDone(SseEmitter emitter, String requestId, ChatResponseMetadata metadata, String fallbackModel,
|
Long sessionId, long startedAt, Long firstTokenAt) {
|
Usage usage = metadata == null ? null : metadata.getUsage();
|
emitStrict(emitter, "done", AiChatDoneDto.builder()
|
.requestId(requestId)
|
.sessionId(sessionId)
|
.model(metadata != null && metadata.getModel() != null && !metadata.getModel().isBlank() ? metadata.getModel() : fallbackModel)
|
.elapsedMs(System.currentTimeMillis() - startedAt)
|
.firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAt))
|
.promptTokens(usage == null ? null : usage.getPromptTokens())
|
.completionTokens(usage == null ? null : usage.getCompletionTokens())
|
.totalTokens(usage == null ? null : usage.getTotalTokens())
|
.build());
|
}
|
|
public AiChatStatusDto buildTerminalStatus(String requestId, Long sessionId, String status, String model, long startedAt, Long firstTokenAt) {
|
return AiChatStatusDto.builder()
|
.requestId(requestId)
|
.sessionId(sessionId)
|
.status(status)
|
.model(model)
|
.timestamp(Instant.now().toEpochMilli())
|
.elapsedMs(System.currentTimeMillis() - startedAt)
|
.firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAt))
|
.build();
|
}
|
|
public Long resolveFirstTokenLatency(long startedAt, Long firstTokenAt) {
|
return firstTokenAt == null ? null : Math.max(0L, firstTokenAt - startedAt);
|
}
|
|
public Map<String, String> buildMessagePayload(String... keyValues) {
|
Map<String, String> payload = new LinkedHashMap<>();
|
if (keyValues == null || keyValues.length == 0) {
|
return payload;
|
}
|
if (keyValues.length % 2 != 0) {
|
throw new CoolException("消息载荷参数必须成对出现");
|
}
|
for (int i = 0; i < keyValues.length; i += 2) {
|
payload.put(keyValues[i], keyValues[i + 1] == null ? "" : keyValues[i + 1]);
|
}
|
return payload;
|
}
|
|
public void emitStrict(SseEmitter emitter, String eventName, Object payload) {
|
try {
|
String data = objectMapper.writeValueAsString(payload);
|
emitter.send(SseEmitter.event()
|
.name(eventName)
|
.data(data, MediaType.APPLICATION_JSON));
|
} catch (IOException e) {
|
throw new AiChatException("AI_SSE_EMIT_ERROR", AiErrorCategory.STREAM, "SSE_EMIT", "SSE 输出失败: " + e.getMessage(), e);
|
}
|
}
|
|
public void emitSafely(SseEmitter emitter, String eventName, Object payload) {
|
try {
|
emitStrict(emitter, eventName, payload);
|
} catch (Exception e) {
|
log.warn("AI SSE event emit skipped, eventName={}, message={}", eventName, e.getMessage());
|
}
|
}
|
}
|