package com.vincent.rsf.server.ai.service.impl.chat; import com.vincent.rsf.server.ai.dto.AiChatErrorDto; import com.vincent.rsf.server.ai.enums.AiErrorCategory; import com.vincent.rsf.server.ai.exception.AiChatException; import com.vincent.rsf.server.ai.service.AiCallLogService; import com.vincent.rsf.server.ai.store.AiStreamStateStore; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Component; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import java.time.Instant; @Slf4j @Component @RequiredArgsConstructor public class AiChatFailureHandler { private final AiSseEventPublisher aiSseEventPublisher; private final AiCallLogService aiCallLogService; private final AiStreamStateStore aiStreamStateStore; public AiChatException buildAiException(String code, AiErrorCategory category, String stage, String message, Throwable cause) { return new AiChatException(code, category, stage, message, cause); } public void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt, Long firstTokenAt, AiChatException exception, Long callLogId, long toolSuccessCount, long toolFailureCount, AiThinkingTraceEmitter thinkingTraceEmitter, Long tenantId, Long userId, String promptCode) { if (isClientAbortException(exception)) { log.warn("AI chat aborted by client, requestId={}, sessionId={}, stage={}, message={}", requestId, sessionId, exception.getStage(), exception.getMessage()); if (thinkingTraceEmitter != null) { thinkingTraceEmitter.markTerminated("ABORTED"); } aiSseEventPublisher.emitSafely(emitter, "status", aiSseEventPublisher.buildTerminalStatus(requestId, sessionId, "ABORTED", model, startedAt, firstTokenAt)); aiCallLogService.failCallLog( callLogId, "ABORTED", exception.getCategory().name(), exception.getStage(), exception.getMessage(), System.currentTimeMillis() - startedAt, aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAt), toolSuccessCount, toolFailureCount ); aiStreamStateStore.markStreamState(requestId, tenantId, userId, sessionId, promptCode, "ABORTED", exception.getMessage()); emitter.completeWithError(exception); return; } log.error("AI chat failed, requestId={}, sessionId={}, category={}, stage={}, message={}", requestId, sessionId, exception.getCategory(), exception.getStage(), exception.getMessage(), exception); if (thinkingTraceEmitter != null) { thinkingTraceEmitter.markTerminated("FAILED"); } aiSseEventPublisher.emitSafely(emitter, "status", aiSseEventPublisher.buildTerminalStatus(requestId, sessionId, "FAILED", model, startedAt, firstTokenAt)); aiSseEventPublisher.emitSafely(emitter, "error", AiChatErrorDto.builder() .requestId(requestId) .sessionId(sessionId) .code(exception.getCode()) .category(exception.getCategory().name()) .stage(exception.getStage()) .message(exception.getMessage()) .timestamp(Instant.now().toEpochMilli()) .build()); aiCallLogService.failCallLog( callLogId, "FAILED", exception.getCategory().name(), exception.getStage(), exception.getMessage(), System.currentTimeMillis() - startedAt, aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAt), toolSuccessCount, toolFailureCount ); aiStreamStateStore.markStreamState(requestId, tenantId, userId, sessionId, promptCode, "FAILED", exception.getMessage()); emitter.completeWithError(exception); } private boolean isClientAbortException(Throwable throwable) { Throwable current = throwable; while (current != null) { String message = current.getMessage(); if (message != null) { String normalized = message.toLowerCase(); if (normalized.contains("broken pipe") || normalized.contains("connection reset") || normalized.contains("forcibly closed") || normalized.contains("abort")) { return true; } } current = current.getCause(); } return false; } }