package com.vincent.rsf.server.ai.service.impl.chat; import com.fasterxml.jackson.databind.ObjectMapper; import com.vincent.rsf.framework.exception.CoolException; import com.vincent.rsf.server.ai.dto.AiChatDoneDto; import com.vincent.rsf.server.ai.dto.AiChatStatusDto; import com.vincent.rsf.server.ai.enums.AiErrorCategory; import com.vincent.rsf.server.ai.exception.AiChatException; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.Usage; import org.springframework.http.MediaType; import org.springframework.stereotype.Component; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import java.io.IOException; import java.time.Instant; import java.util.LinkedHashMap; import java.util.Map; import java.util.concurrent.atomic.AtomicReference; @Slf4j @Component @RequiredArgsConstructor public class AiSseEventPublisher { private final ObjectMapper objectMapper; public void markFirstToken(AtomicReference firstTokenAtRef, SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt, AiThinkingTraceEmitter thinkingTraceEmitter) { if (!firstTokenAtRef.compareAndSet(null, System.currentTimeMillis())) { return; } if (thinkingTraceEmitter != null) { thinkingTraceEmitter.startAnswer(); } emitSafely(emitter, "status", AiChatStatusDto.builder() .requestId(requestId) .sessionId(sessionId) .status("FIRST_TOKEN") .model(model) .timestamp(Instant.now().toEpochMilli()) .elapsedMs(System.currentTimeMillis() - startedAt) .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAtRef.get())) .build()); } public void emitDone(SseEmitter emitter, String requestId, ChatResponseMetadata metadata, String fallbackModel, Long sessionId, long startedAt, Long firstTokenAt) { Usage usage = metadata == null ? null : metadata.getUsage(); emitStrict(emitter, "done", AiChatDoneDto.builder() .requestId(requestId) .sessionId(sessionId) .model(metadata != null && metadata.getModel() != null && !metadata.getModel().isBlank() ? metadata.getModel() : fallbackModel) .elapsedMs(System.currentTimeMillis() - startedAt) .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAt)) .promptTokens(usage == null ? null : usage.getPromptTokens()) .completionTokens(usage == null ? null : usage.getCompletionTokens()) .totalTokens(usage == null ? null : usage.getTotalTokens()) .build()); } public AiChatStatusDto buildTerminalStatus(String requestId, Long sessionId, String status, String model, long startedAt, Long firstTokenAt) { return AiChatStatusDto.builder() .requestId(requestId) .sessionId(sessionId) .status(status) .model(model) .timestamp(Instant.now().toEpochMilli()) .elapsedMs(System.currentTimeMillis() - startedAt) .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAt)) .build(); } public Long resolveFirstTokenLatency(long startedAt, Long firstTokenAt) { return firstTokenAt == null ? null : Math.max(0L, firstTokenAt - startedAt); } public Map buildMessagePayload(String... keyValues) { Map payload = new LinkedHashMap<>(); if (keyValues == null || keyValues.length == 0) { return payload; } if (keyValues.length % 2 != 0) { throw new CoolException("消息载荷参数必须成对出现"); } for (int i = 0; i < keyValues.length; i += 2) { payload.put(keyValues[i], keyValues[i + 1] == null ? "" : keyValues[i + 1]); } return payload; } public void emitStrict(SseEmitter emitter, String eventName, Object payload) { try { String data = objectMapper.writeValueAsString(payload); emitter.send(SseEmitter.event() .name(eventName) .data(data, MediaType.APPLICATION_JSON)); } catch (IOException e) { throw new AiChatException("AI_SSE_EMIT_ERROR", AiErrorCategory.STREAM, "SSE_EMIT", "SSE 输出失败: " + e.getMessage(), e); } } public void emitSafely(SseEmitter emitter, String eventName, Object payload) { try { emitStrict(emitter, eventName, payload); } catch (Exception e) { log.warn("AI SSE event emit skipped, eventName={}, message={}", eventName, e.getMessage()); } } }