#
zhou zhou
10 小时以前 66d766c88ec5d1ab4715fd9f2c22ce42b459d957
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
package com.vincent.rsf.server.ai.service.impl.chat;
 
import com.vincent.rsf.server.ai.dto.AiChatErrorDto;
import com.vincent.rsf.server.ai.enums.AiErrorCategory;
import com.vincent.rsf.server.ai.exception.AiChatException;
import com.vincent.rsf.server.ai.service.AiCallLogService;
import com.vincent.rsf.server.ai.store.AiStreamStateStore;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.function.client.WebClientResponseException;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.time.Instant;
import java.util.Locale;
 
@Slf4j
@Component
@RequiredArgsConstructor
public class AiChatFailureHandler {
 
    private final AiSseEventPublisher aiSseEventPublisher;
    private final AiCallLogService aiCallLogService;
    private final AiStreamStateStore aiStreamStateStore;
 
    public AiChatException buildAiException(String code, AiErrorCategory category, String stage, String message, Throwable cause) {
        return new AiChatException(code, category, stage, resolveExceptionMessage(message, cause), cause);
    }
 
    public void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt,
                                    Long firstTokenAt, AiChatException exception, Long callLogId,
                                    long toolSuccessCount, long toolFailureCount,
                                    AiChatTraceEmitter traceEmitter,
                                    Long tenantId, Long userId, String promptCode) {
        if (isClientAbortException(exception)) {
            log.warn("AI chat aborted by client, requestId={}, sessionId={}, stage={}, message={}",
                    requestId, sessionId, exception.getStage(), exception.getMessage());
            if (traceEmitter != null) {
                traceEmitter.markTerminated("ABORTED");
            }
            aiSseEventPublisher.emitSafely(emitter, "status",
                    aiSseEventPublisher.buildTerminalStatus(requestId, sessionId, "ABORTED", model, startedAt, firstTokenAt));
            aiCallLogService.failCallLog(
                    callLogId,
                    "ABORTED",
                    exception.getCategory().name(),
                    exception.getStage(),
                    exception.getMessage(),
                    System.currentTimeMillis() - startedAt,
                    aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAt),
                    toolSuccessCount,
                    toolFailureCount
            );
            aiStreamStateStore.markStreamState(requestId, tenantId, userId, sessionId, promptCode, "ABORTED", exception.getMessage());
            emitter.complete();
            return;
        }
        log.error("AI chat failed, requestId={}, sessionId={}, category={}, stage={}, message={}",
                requestId, sessionId, exception.getCategory(), exception.getStage(), exception.getMessage(), exception);
        if (traceEmitter != null) {
            traceEmitter.markTerminated("FAILED");
        }
        aiSseEventPublisher.emitSafely(emitter, "status",
                aiSseEventPublisher.buildTerminalStatus(requestId, sessionId, "FAILED", model, startedAt, firstTokenAt));
        aiSseEventPublisher.emitSafely(emitter, "error", AiChatErrorDto.builder()
                .requestId(requestId)
                .sessionId(sessionId)
                .code(exception.getCode())
                .category(exception.getCategory().name())
                .stage(exception.getStage())
                .message(exception.getMessage())
                .timestamp(Instant.now().toEpochMilli())
                .build());
        aiCallLogService.failCallLog(
                callLogId,
                "FAILED",
                exception.getCategory().name(),
                exception.getStage(),
                exception.getMessage(),
                System.currentTimeMillis() - startedAt,
                aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAt),
                toolSuccessCount,
                toolFailureCount
        );
        aiStreamStateStore.markStreamState(requestId, tenantId, userId, sessionId, promptCode, "FAILED", exception.getMessage());
        emitter.complete();
    }
 
    private String resolveExceptionMessage(String message, Throwable cause) {
        String upstreamMessage = extractUpstreamResponseBody(cause);
        if (StringUtils.hasText(upstreamMessage)) {
            return truncateMessage(upstreamMessage);
        }
        return truncateMessage(message);
    }
 
    private String extractUpstreamResponseBody(Throwable throwable) {
        Throwable current = throwable;
        while (current != null) {
            if (current instanceof WebClientResponseException webClientResponseException) {
                String responseBody = webClientResponseException.getResponseBodyAsString();
                if (StringUtils.hasText(responseBody)) {
                    return responseBody.replace('\n', ' ').replace('\r', ' ').trim();
                }
            }
            current = current.getCause();
        }
        return null;
    }
 
    private String truncateMessage(String message) {
        if (!StringUtils.hasText(message)) {
            return message;
        }
        return message.length() > 900 ? message.substring(0, 900) : message;
    }
 
    private boolean isClientAbortException(Throwable throwable) {
        Throwable current = throwable;
        while (current != null) {
            String message = current.getMessage();
            if (message != null) {
                String normalized = message.toLowerCase();
                if (normalized.contains("broken pipe")
                        || normalized.contains("connection reset")
                        || normalized.contains("forcibly closed")
                        || normalized.contains("abort")) {
                    return true;
                }
            }
            current = current.getCause();
        }
        return false;
    }
}