package com.vincent.rsf.server.ai.service.impl.chat;
|
|
import com.vincent.rsf.server.ai.dto.AiChatErrorDto;
import com.vincent.rsf.server.ai.enums.AiErrorCategory;
import com.vincent.rsf.server.ai.exception.AiChatException;
import com.vincent.rsf.server.ai.service.AiCallLogService;
import com.vincent.rsf.server.ai.store.AiStreamStateStore;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.time.Instant;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Locale;
import java.util.Set;
|
|
@Slf4j
|
@Component
|
@RequiredArgsConstructor
|
public class AiChatFailureHandler {
|
|
private final AiSseEventPublisher aiSseEventPublisher;
|
private final AiCallLogService aiCallLogService;
|
private final AiStreamStateStore aiStreamStateStore;
|
|
public AiChatException buildAiException(String code, AiErrorCategory category, String stage, String message, Throwable cause) {
|
return new AiChatException(code, category, stage, message, cause);
|
}
|
|
public void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt,
|
Long firstTokenAt, AiChatException exception, Long callLogId,
|
long toolSuccessCount, long toolFailureCount,
|
AiThinkingTraceEmitter thinkingTraceEmitter,
|
Long tenantId, Long userId, String promptCode) {
|
if (isClientAbortException(exception)) {
|
log.warn("AI chat aborted by client, requestId={}, sessionId={}, stage={}, message={}",
|
requestId, sessionId, exception.getStage(), exception.getMessage());
|
if (thinkingTraceEmitter != null) {
|
thinkingTraceEmitter.markTerminated("ABORTED");
|
}
|
aiSseEventPublisher.emitSafely(emitter, "status",
|
aiSseEventPublisher.buildTerminalStatus(requestId, sessionId, "ABORTED", model, startedAt, firstTokenAt));
|
aiCallLogService.failCallLog(
|
callLogId,
|
"ABORTED",
|
exception.getCategory().name(),
|
exception.getStage(),
|
exception.getMessage(),
|
System.currentTimeMillis() - startedAt,
|
aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAt),
|
toolSuccessCount,
|
toolFailureCount
|
);
|
aiStreamStateStore.markStreamState(requestId, tenantId, userId, sessionId, promptCode, "ABORTED", exception.getMessage());
|
emitter.completeWithError(exception);
|
return;
|
}
|
log.error("AI chat failed, requestId={}, sessionId={}, category={}, stage={}, message={}",
|
requestId, sessionId, exception.getCategory(), exception.getStage(), exception.getMessage(), exception);
|
if (thinkingTraceEmitter != null) {
|
thinkingTraceEmitter.markTerminated("FAILED");
|
}
|
aiSseEventPublisher.emitSafely(emitter, "status",
|
aiSseEventPublisher.buildTerminalStatus(requestId, sessionId, "FAILED", model, startedAt, firstTokenAt));
|
aiSseEventPublisher.emitSafely(emitter, "error", AiChatErrorDto.builder()
|
.requestId(requestId)
|
.sessionId(sessionId)
|
.code(exception.getCode())
|
.category(exception.getCategory().name())
|
.stage(exception.getStage())
|
.message(exception.getMessage())
|
.timestamp(Instant.now().toEpochMilli())
|
.build());
|
aiCallLogService.failCallLog(
|
callLogId,
|
"FAILED",
|
exception.getCategory().name(),
|
exception.getStage(),
|
exception.getMessage(),
|
System.currentTimeMillis() - startedAt,
|
aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAt),
|
toolSuccessCount,
|
toolFailureCount
|
);
|
aiStreamStateStore.markStreamState(requestId, tenantId, userId, sessionId, promptCode, "FAILED", exception.getMessage());
|
emitter.completeWithError(exception);
|
}
|
|
private boolean isClientAbortException(Throwable throwable) {
|
Throwable current = throwable;
|
while (current != null) {
|
String message = current.getMessage();
|
if (message != null) {
|
String normalized = message.toLowerCase();
|
if (normalized.contains("broken pipe")
|
|| normalized.contains("connection reset")
|
|| normalized.contains("forcibly closed")
|
|| normalized.contains("abort")) {
|
return true;
|
}
|
}
|
current = current.getCause();
|
}
|
return false;
|
}
|
}
|