package com.vincent.rsf.server.ai.service.diagnosis; import com.fasterxml.jackson.databind.JsonNode; import com.vincent.rsf.server.ai.constant.AiSceneCode; import com.vincent.rsf.server.ai.dto.AiChatStreamRequest; import com.vincent.rsf.server.ai.dto.GatewayChatMessage; import com.vincent.rsf.server.ai.dto.GatewayChatRequest; import com.vincent.rsf.server.ai.model.AiChatMessage; import com.vincent.rsf.server.ai.model.AiChatSession; import com.vincent.rsf.server.ai.model.AiDiagnosticToolResult; import com.vincent.rsf.server.ai.model.AiPromptContext; import com.vincent.rsf.server.ai.service.AiGatewayClient; import com.vincent.rsf.server.ai.service.AiModelRouteRuntimeService; import com.vincent.rsf.server.ai.service.AiPromptRuntimeService; import com.vincent.rsf.server.ai.service.AiSessionService; import com.vincent.rsf.server.system.entity.AiDiagnosisRecord; import org.springframework.http.MediaType; import org.springframework.stereotype.Service; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import javax.annotation.Resource; import java.io.IOException; import java.io.InterruptedIOException; import java.util.ArrayList; import java.util.Date; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @Service public class AiChatStreamOrchestrator { @Resource private AiSessionService aiSessionService; @Resource private AiGatewayClient aiGatewayClient; @Resource private AiPromptRuntimeService aiPromptRuntimeService; @Resource private AiDiagnosticToolService aiDiagnosticToolService; @Resource private AiModelRouteRuntimeService aiModelRouteRuntimeService; @Resource private AiDiagnosisRuntimeService aiDiagnosisRuntimeService; @Resource private AiDiagnosisMcpRuntimeService aiDiagnosisMcpRuntimeService; /** * 从候选模型列表中挑出最大的上下文窗口,用于提前截断会话历史。 */ public int resolveContextSize(List candidates) { int max = 12; for (AiModelRouteRuntimeService.RouteCandidate item : candidates) { if (item != null && item.getRuntimeConfig() != null && item.getRuntimeConfig().getMaxContextMessages() != null && item.getRuntimeConfig().getMaxContextMessages() > max) { max = item.getRuntimeConfig().getMaxContextMessages(); } } return max; } /** * 执行一次完整的流式聊天/诊断编排。 * 这里统一负责 MCP 结果准备、模型重试、事件转发、调用日志和诊断收尾。 */ public void executeStream(SseEmitter emitter, Long tenantId, Long userId, AiChatSession session, AiChatStreamRequest request, AiPromptContext promptContext, List contextMessages, AiDiagnosisRecord diagnosisRecord, List candidates) { StringBuilder assistantReply = new StringBuilder(); String finalModelCode = session.getModelCode(); String finalErrorMessage = null; boolean stopped = false; boolean success = false; boolean assistantSaved = false; boolean errorSent = false; List runtimeDiagnosticResults = new ArrayList<>(); String toolSummary = "[]"; try { emitter.send(SseEmitter.event().name("session").data(buildSessionPayload(session), MediaType.APPLICATION_JSON)); if (aiDiagnosisMcpRuntimeService.shouldUseMcp(promptContext)) { runtimeDiagnosticResults = aiDiagnosisMcpRuntimeService.resolveToolResults( tenantId, promptContext, contextMessages, candidates.isEmpty() ? null : candidates.get(0) ); toolSummary = aiDiagnosticToolService.serializeResults(runtimeDiagnosticResults); } int attemptNo = 1; for (AiModelRouteRuntimeService.RouteCandidate candidate : candidates) { AttemptState attemptState = new AttemptState(); Date requestTime = new Date(); try { GatewayChatRequest gatewayChatRequest = buildGatewayRequest( session, contextMessages, candidate, request, promptContext, runtimeDiagnosticResults, attemptNo ); aiGatewayClient.stream(gatewayChatRequest, event -> handleGatewayEvent( emitter, event, session, assistantReply, attemptState )); } catch (Exception e) { attemptState.setSuccess(false); attemptState.setErrorMessage(e.getMessage()); attemptState.setInterrupted(isInterruptedError(e)); attemptState.setResponseTime(new Date()); } if (attemptState.getResponseTime() == null) { attemptState.setResponseTime(new Date()); } String actualModelCode = attemptState.getActualModelCode() == null ? candidate.getAttemptModelCode() : attemptState.getActualModelCode(); finalModelCode = actualModelCode; aiDiagnosisRuntimeService.saveCallLog( tenantId, userId, session.getId(), diagnosisRecord == null ? null : diagnosisRecord.getId(), resolveRouteCode(candidate, request), actualModelCode, attemptNo, requestTime, attemptState.getResponseTime(), Boolean.TRUE.equals(attemptState.getSuccess()) ? 1 : 0, attemptState.getErrorMessage() ); if (Boolean.TRUE.equals(attemptState.getSuccess())) { aiModelRouteRuntimeService.markSuccess(candidate.getRouteId()); success = assistantReply.length() > 0; if (!success) { finalErrorMessage = "模型未返回有效内容"; } break; } if (attemptState.isStopped() || aiSessionService.isStopRequested(session.getId())) { stopped = true; break; } if (!attemptState.isInterrupted()) { aiModelRouteRuntimeService.markFailure(candidate.getRouteId()); } finalErrorMessage = attemptState.getErrorMessage(); if (attemptState.isReceivedDelta() || attemptNo >= candidates.size()) { if (!attemptState.isInterrupted() && finalErrorMessage != null && !finalErrorMessage.trim().isEmpty()) { emitter.send(SseEmitter.event().name("error").data(buildErrorPayload(finalErrorMessage), MediaType.APPLICATION_JSON)); errorSent = true; } break; } attemptNo++; } if (aiSessionService.isStopRequested(session.getId())) { stopped = true; } if (stopped) { if (assistantReply.length() > 0) { aiSessionService.appendMessage(tenantId, userId, session.getId(), "assistant", assistantReply.toString(), finalModelCode); assistantSaved = true; } emitter.send(SseEmitter.event().name("done").data(buildDonePayload(session, finalModelCode, true), MediaType.APPLICATION_JSON)); if (diagnosisRecord != null) { aiDiagnosisRuntimeService.finishDiagnosisFailure(diagnosisRecord, assistantReply.toString(), "用户已停止生成", toolSummary); } return; } if (success) { aiSessionService.appendMessage(tenantId, userId, session.getId(), "assistant", assistantReply.toString(), finalModelCode); assistantSaved = true; emitter.send(SseEmitter.event().name("done").data(buildDonePayload(session, finalModelCode, false), MediaType.APPLICATION_JSON)); if (diagnosisRecord != null) { aiDiagnosisRuntimeService.finishDiagnosisSuccess(diagnosisRecord, assistantReply.toString(), finalModelCode, toolSummary); } return; } if (assistantReply.length() > 0 && !assistantSaved) { aiSessionService.appendMessage(tenantId, userId, session.getId(), "assistant", assistantReply.toString(), finalModelCode); assistantSaved = true; } if (diagnosisRecord != null) { aiDiagnosisRuntimeService.finishDiagnosisFailure(diagnosisRecord, assistantReply.toString(), finalErrorMessage, toolSummary); } if (!errorSent && finalErrorMessage != null && !finalErrorMessage.trim().isEmpty()) { emitter.send(SseEmitter.event().name("error").data(buildErrorPayload(finalErrorMessage), MediaType.APPLICATION_JSON)); } } catch (Exception e) { if (diagnosisRecord != null) { aiDiagnosisRuntimeService.finishDiagnosisFailure(diagnosisRecord, assistantReply.toString(), e.getMessage(), toolSummary); } if (!isInterruptedError(e)) { try { emitter.send(SseEmitter.event().name("error").data(buildErrorPayload(e.getMessage()), MediaType.APPLICATION_JSON)); } catch (IOException ignore) { } } else { Thread.currentThread().interrupt(); } } finally { emitter.complete(); aiSessionService.clearStopFlag(session.getId()); } } /** * 消费网关返回的单条流式事件,并把状态写回本次尝试上下文。 */ private boolean handleGatewayEvent(SseEmitter emitter, JsonNode event, AiChatSession session, StringBuilder assistantReply, AttemptState attemptState) throws Exception { if (aiSessionService.isStopRequested(session.getId())) { attemptState.setStopped(true); attemptState.setResponseTime(new Date()); return false; } String type = event.path("type").asText(); String modelCode = event.path("modelCode").asText(session.getModelCode()); if ("delta".equals(type)) { String content = event.path("content").asText(""); assistantReply.append(content); attemptState.setReceivedDelta(true); attemptState.setActualModelCode(modelCode); emitter.send(SseEmitter.event().name("delta").data(buildDeltaPayload(session, modelCode, content), MediaType.APPLICATION_JSON)); return true; } if ("error".equals(type)) { String message = event.path("message").asText("模型调用失败"); attemptState.setSuccess(false); attemptState.setErrorMessage(message); attemptState.setActualModelCode(modelCode); attemptState.setResponseTime(parseResponseTime(event)); attemptState.setInterrupted(isInterruptedMessage(message)); return false; } if ("done".equals(type)) { attemptState.setSuccess(true); attemptState.setActualModelCode(modelCode); attemptState.setResponseTime(parseResponseTime(event)); return false; } if ("ping".equals(type)) { emitter.send(SseEmitter.event().name("ping").data(buildPingPayload(modelCode), MediaType.APPLICATION_JSON)); return true; } return true; } /** * 将当前尝试需要的上下文、系统 Prompt 和模型信息组装成网关请求。 */ private GatewayChatRequest buildGatewayRequest(AiChatSession session, List contextMessages, AiModelRouteRuntimeService.RouteCandidate candidate, AiChatStreamRequest chatRequest, AiPromptContext promptContext, List diagnosticResults, Integer attemptNo) { GatewayChatRequest request = new GatewayChatRequest(); request.setSessionId(session.getId()); request.setModelCode(candidate.getAttemptModelCode()); request.setRouteCode(resolveRouteCode(candidate, chatRequest)); request.setAttemptNo(attemptNo); request.setSystemPrompt(aiPromptRuntimeService.buildSystemPrompt( chatRequest.getSceneCode(), candidate.getRuntimeConfig().getSystemPrompt(), promptContext, diagnosticResults )); request.setChatUrl(candidate.getRuntimeConfig().getChatUrl()); request.setApiKey(candidate.getRuntimeConfig().getApiKey()); request.setModelName(candidate.getRuntimeConfig().getModelName()); for (AiChatMessage contextMessage : contextMessages) { GatewayChatMessage item = new GatewayChatMessage(); item.setRole(contextMessage.getRole()); item.setContent(contextMessage.getContent()); request.getMessages().add(item); } return request; } /** * 解析当前尝试应落到哪个路由编码,优先使用路由候选自带的 routeCode。 */ private String resolveRouteCode(AiModelRouteRuntimeService.RouteCandidate candidate, AiChatStreamRequest request) { if (candidate != null && candidate.getRouteCode() != null && !candidate.getRouteCode().trim().isEmpty()) { return candidate.getRouteCode(); } return AiSceneCode.SYSTEM_DIAGNOSE.equals(request.getSceneCode()) ? AiSceneCode.SYSTEM_DIAGNOSE : AiSceneCode.GENERAL_CHAT; } /** * 从网关事件中解析响应时间,缺省时回退为当前时间。 */ private Date parseResponseTime(JsonNode event) { long millis = event.path("responseTime").asLong(0L); return millis <= 0L ? new Date() : new Date(millis); } /** * 构造会话初始化事件给前端。 */ private Map buildSessionPayload(AiChatSession session) { Map payload = new LinkedHashMap<>(); payload.put("sessionId", session.getId()); payload.put("title", session.getTitle()); payload.put("modelCode", session.getModelCode()); return payload; } /** * 构造增量输出事件给前端。 */ private Map buildDeltaPayload(AiChatSession session, String modelCode, String content) { Map payload = new LinkedHashMap<>(); payload.put("sessionId", session.getId()); payload.put("modelCode", modelCode); payload.put("content", content); payload.put("timestamp", new Date().getTime()); return payload; } /** * 构造流式完成事件给前端。 */ private Map buildDonePayload(AiChatSession session, String modelCode, boolean stopped) { Map payload = new LinkedHashMap<>(); payload.put("sessionId", session.getId()); payload.put("modelCode", modelCode); payload.put("stopped", stopped); return payload; } /** * 构造错误事件给前端。 */ private Map buildErrorPayload(String message) { Map payload = new LinkedHashMap<>(); payload.put("message", message == null ? "AI服务异常" : message); return payload; } /** * 构造心跳事件给前端,用于维持长连接活性。 */ private Map buildPingPayload(String modelCode) { Map payload = new LinkedHashMap<>(); payload.put("modelCode", modelCode); payload.put("timestamp", new Date().getTime()); return payload; } /** * 判断异常链中是否包含线程中断类错误。 */ private boolean isInterruptedError(Throwable throwable) { Throwable current = throwable; while (current != null) { if (current instanceof InterruptedException || current instanceof InterruptedIOException) { return true; } if (isInterruptedMessage(current.getMessage())) { return true; } current = current.getCause(); } return false; } private boolean isInterruptedMessage(String message) { if (message == null || message.trim().isEmpty()) { return false; } String normalized = message.toLowerCase(); return normalized.contains("interrupted") || normalized.contains("broken pipe") || normalized.contains("connection reset") || normalized.contains("forcibly closed"); } private static class AttemptState { private Boolean success; private String actualModelCode; private String errorMessage; private boolean receivedDelta; private boolean interrupted; private boolean stopped; private Date responseTime; public Boolean getSuccess() { return success; } public void setSuccess(Boolean success) { this.success = success; } public String getActualModelCode() { return actualModelCode; } public void setActualModelCode(String actualModelCode) { this.actualModelCode = actualModelCode; } public String getErrorMessage() { return errorMessage; } public void setErrorMessage(String errorMessage) { this.errorMessage = errorMessage; } public boolean isReceivedDelta() { return receivedDelta; } public void setReceivedDelta(boolean receivedDelta) { this.receivedDelta = receivedDelta; } public boolean isInterrupted() { return interrupted; } public void setInterrupted(boolean interrupted) { this.interrupted = interrupted; } public boolean isStopped() { return stopped; } public void setStopped(boolean stopped) { this.stopped = stopped; } public Date getResponseTime() { return responseTime; } public void setResponseTime(Date responseTime) { this.responseTime = responseTime; } } }