package com.vincent.rsf.server.ai.service.diagnosis;
|
|
import com.fasterxml.jackson.databind.JsonNode;
import com.vincent.rsf.server.ai.constant.AiSceneCode;
import com.vincent.rsf.server.ai.dto.AiChatStreamRequest;
import com.vincent.rsf.server.ai.dto.GatewayChatMessage;
import com.vincent.rsf.server.ai.dto.GatewayChatRequest;
import com.vincent.rsf.server.ai.model.AiChatMessage;
import com.vincent.rsf.server.ai.model.AiChatSession;
import com.vincent.rsf.server.ai.model.AiDiagnosticToolResult;
import com.vincent.rsf.server.ai.model.AiPromptContext;
import com.vincent.rsf.server.ai.service.AiGatewayClient;
import com.vincent.rsf.server.ai.service.AiModelRouteRuntimeService;
import com.vincent.rsf.server.ai.service.AiPromptRuntimeService;
import com.vincent.rsf.server.ai.service.AiSessionService;
import com.vincent.rsf.server.system.entity.AiDiagnosisRecord;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import javax.annotation.Resource;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.net.SocketTimeoutException;
import java.util.ArrayList;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
|
|
@Service
|
public class AiChatStreamOrchestrator {
|
|
@Resource
|
private AiSessionService aiSessionService;
|
@Resource
|
private AiGatewayClient aiGatewayClient;
|
@Resource
|
private AiPromptRuntimeService aiPromptRuntimeService;
|
@Resource
|
private AiDiagnosticToolService aiDiagnosticToolService;
|
@Resource
|
private AiModelRouteRuntimeService aiModelRouteRuntimeService;
|
@Resource
|
private AiDiagnosisRuntimeService aiDiagnosisRuntimeService;
|
@Resource
|
private AiDiagnosisMcpRuntimeService aiDiagnosisMcpRuntimeService;
|
|
/**
|
* 从候选模型列表中挑出最大的上下文窗口,用于提前截断会话历史。
|
*/
|
public int resolveContextSize(List<AiModelRouteRuntimeService.RouteCandidate> candidates) {
|
int max = 12;
|
for (AiModelRouteRuntimeService.RouteCandidate item : candidates) {
|
if (item != null && item.getRuntimeConfig() != null
|
&& item.getRuntimeConfig().getMaxContextMessages() != null
|
&& item.getRuntimeConfig().getMaxContextMessages() > max) {
|
max = item.getRuntimeConfig().getMaxContextMessages();
|
}
|
}
|
return max;
|
}
|
|
/**
|
* 执行一次完整的流式聊天/诊断编排。
|
* 这里统一负责 MCP 结果准备、模型重试、事件转发、调用日志和诊断收尾。
|
*/
|
public void executeStream(SseEmitter emitter,
|
Long tenantId,
|
Long userId,
|
AiChatSession session,
|
AiChatStreamRequest request,
|
AiPromptContext promptContext,
|
List<AiChatMessage> contextMessages,
|
AiDiagnosisRecord diagnosisRecord,
|
List<AiModelRouteRuntimeService.RouteCandidate> candidates) {
|
StringBuilder assistantReply = new StringBuilder();
|
String finalModelCode = session.getModelCode();
|
String finalErrorMessage = null;
|
boolean stopped = false;
|
boolean success = false;
|
boolean assistantSaved = false;
|
boolean errorSent = false;
|
List<AiDiagnosticToolResult> runtimeDiagnosticResults = new ArrayList<>();
|
String toolSummary = "[]";
|
try {
|
emitter.send(SseEmitter.event().name("session").data(buildSessionPayload(session), MediaType.APPLICATION_JSON));
|
if (aiDiagnosisMcpRuntimeService.shouldUseMcp(promptContext)) {
|
runtimeDiagnosticResults = aiDiagnosisMcpRuntimeService.resolveToolResults(
|
tenantId,
|
promptContext,
|
contextMessages,
|
candidates.isEmpty() ? null : candidates.get(0)
|
);
|
toolSummary = aiDiagnosticToolService.serializeResults(runtimeDiagnosticResults);
|
}
|
int attemptNo = 1;
|
for (AiModelRouteRuntimeService.RouteCandidate candidate : candidates) {
|
AttemptState attemptState = new AttemptState();
|
Date requestTime = new Date();
|
try {
|
GatewayChatRequest gatewayChatRequest = buildGatewayRequest(
|
session,
|
contextMessages,
|
candidate,
|
request,
|
promptContext,
|
runtimeDiagnosticResults,
|
attemptNo
|
);
|
aiGatewayClient.stream(gatewayChatRequest, event -> handleGatewayEvent(
|
emitter,
|
event,
|
session,
|
assistantReply,
|
attemptState
|
));
|
} catch (Exception e) {
|
attemptState.setSuccess(false);
|
attemptState.setErrorMessage(e.getMessage());
|
attemptState.setInterrupted(isInterruptedError(e));
|
attemptState.setResponseTime(new Date());
|
}
|
if (attemptState.getResponseTime() == null) {
|
attemptState.setResponseTime(new Date());
|
}
|
String actualModelCode = attemptState.getActualModelCode() == null
|
? candidate.getAttemptModelCode()
|
: attemptState.getActualModelCode();
|
finalModelCode = actualModelCode;
|
aiDiagnosisRuntimeService.saveCallLog(
|
tenantId,
|
userId,
|
session.getId(),
|
diagnosisRecord == null ? null : diagnosisRecord.getId(),
|
resolveRouteCode(candidate, request),
|
actualModelCode,
|
attemptNo,
|
requestTime,
|
attemptState.getResponseTime(),
|
Boolean.TRUE.equals(attemptState.getSuccess()) ? 1 : 0,
|
attemptState.getErrorMessage()
|
);
|
if (Boolean.TRUE.equals(attemptState.getSuccess())) {
|
aiModelRouteRuntimeService.markSuccess(candidate.getRouteId());
|
success = assistantReply.length() > 0;
|
if (!success) {
|
finalErrorMessage = "模型未返回有效内容";
|
}
|
break;
|
}
|
if (attemptState.isStopped() || aiSessionService.isStopRequested(session.getId())) {
|
stopped = true;
|
break;
|
}
|
if (!attemptState.isInterrupted()) {
|
aiModelRouteRuntimeService.markFailure(candidate.getRouteId());
|
}
|
finalErrorMessage = attemptState.getErrorMessage();
|
if (attemptState.isReceivedDelta() || attemptNo >= candidates.size()) {
|
if (!attemptState.isInterrupted() && finalErrorMessage != null && !finalErrorMessage.trim().isEmpty()) {
|
emitter.send(SseEmitter.event().name("error").data(buildErrorPayload(finalErrorMessage), MediaType.APPLICATION_JSON));
|
errorSent = true;
|
}
|
break;
|
}
|
attemptNo++;
|
}
|
|
if (aiSessionService.isStopRequested(session.getId())) {
|
stopped = true;
|
}
|
|
if (stopped) {
|
if (assistantReply.length() > 0) {
|
aiSessionService.appendMessage(tenantId, userId, session.getId(), "assistant", assistantReply.toString(), finalModelCode);
|
assistantSaved = true;
|
}
|
emitter.send(SseEmitter.event().name("done").data(buildDonePayload(session, finalModelCode, true), MediaType.APPLICATION_JSON));
|
if (diagnosisRecord != null) {
|
aiDiagnosisRuntimeService.finishDiagnosisFailure(diagnosisRecord, assistantReply.toString(), "用户已停止生成", toolSummary);
|
}
|
return;
|
}
|
|
if (success) {
|
aiSessionService.appendMessage(tenantId, userId, session.getId(), "assistant", assistantReply.toString(), finalModelCode);
|
assistantSaved = true;
|
emitter.send(SseEmitter.event().name("done").data(buildDonePayload(session, finalModelCode, false), MediaType.APPLICATION_JSON));
|
if (diagnosisRecord != null) {
|
aiDiagnosisRuntimeService.finishDiagnosisSuccess(diagnosisRecord, assistantReply.toString(), finalModelCode, toolSummary);
|
}
|
return;
|
}
|
|
if (assistantReply.length() > 0 && !assistantSaved) {
|
aiSessionService.appendMessage(tenantId, userId, session.getId(), "assistant", assistantReply.toString(), finalModelCode);
|
assistantSaved = true;
|
}
|
if (diagnosisRecord != null) {
|
aiDiagnosisRuntimeService.finishDiagnosisFailure(diagnosisRecord, assistantReply.toString(), finalErrorMessage, toolSummary);
|
}
|
if (!errorSent && finalErrorMessage != null && !finalErrorMessage.trim().isEmpty()) {
|
emitter.send(SseEmitter.event().name("error").data(buildErrorPayload(finalErrorMessage), MediaType.APPLICATION_JSON));
|
}
|
} catch (Exception e) {
|
if (diagnosisRecord != null) {
|
aiDiagnosisRuntimeService.finishDiagnosisFailure(diagnosisRecord, assistantReply.toString(), e.getMessage(), toolSummary);
|
}
|
if (!isInterruptedError(e)) {
|
try {
|
emitter.send(SseEmitter.event().name("error").data(buildErrorPayload(e.getMessage()), MediaType.APPLICATION_JSON));
|
} catch (IOException ignore) {
|
}
|
} else {
|
Thread.currentThread().interrupt();
|
}
|
} finally {
|
emitter.complete();
|
aiSessionService.clearStopFlag(session.getId());
|
}
|
}
|
|
/**
|
* 消费网关返回的单条流式事件,并把状态写回本次尝试上下文。
|
*/
|
private boolean handleGatewayEvent(SseEmitter emitter, JsonNode event, AiChatSession session,
|
StringBuilder assistantReply, AttemptState attemptState) throws Exception {
|
if (aiSessionService.isStopRequested(session.getId())) {
|
attemptState.setStopped(true);
|
attemptState.setResponseTime(new Date());
|
return false;
|
}
|
String type = event.path("type").asText();
|
String modelCode = event.path("modelCode").asText(session.getModelCode());
|
if ("delta".equals(type)) {
|
String content = event.path("content").asText("");
|
assistantReply.append(content);
|
attemptState.setReceivedDelta(true);
|
attemptState.setActualModelCode(modelCode);
|
emitter.send(SseEmitter.event().name("delta").data(buildDeltaPayload(session, modelCode, content), MediaType.APPLICATION_JSON));
|
return true;
|
}
|
if ("error".equals(type)) {
|
String message = event.path("message").asText("模型调用失败");
|
attemptState.setSuccess(false);
|
attemptState.setErrorMessage(message);
|
attemptState.setActualModelCode(modelCode);
|
attemptState.setResponseTime(parseResponseTime(event));
|
attemptState.setInterrupted(isInterruptedMessage(message));
|
return false;
|
}
|
if ("done".equals(type)) {
|
attemptState.setSuccess(true);
|
attemptState.setActualModelCode(modelCode);
|
attemptState.setResponseTime(parseResponseTime(event));
|
return false;
|
}
|
if ("ping".equals(type)) {
|
emitter.send(SseEmitter.event().name("ping").data(buildPingPayload(modelCode), MediaType.APPLICATION_JSON));
|
return true;
|
}
|
return true;
|
}
|
|
/**
|
* 将当前尝试需要的上下文、系统 Prompt 和模型信息组装成网关请求。
|
*/
|
private GatewayChatRequest buildGatewayRequest(AiChatSession session,
|
List<AiChatMessage> contextMessages,
|
AiModelRouteRuntimeService.RouteCandidate candidate,
|
AiChatStreamRequest chatRequest,
|
AiPromptContext promptContext,
|
List<AiDiagnosticToolResult> diagnosticResults,
|
Integer attemptNo) {
|
GatewayChatRequest request = new GatewayChatRequest();
|
request.setSessionId(session.getId());
|
request.setModelCode(candidate.getAttemptModelCode());
|
request.setRouteCode(resolveRouteCode(candidate, chatRequest));
|
request.setAttemptNo(attemptNo);
|
request.setSystemPrompt(aiPromptRuntimeService.buildSystemPrompt(
|
chatRequest.getSceneCode(),
|
candidate.getRuntimeConfig().getSystemPrompt(),
|
promptContext,
|
diagnosticResults
|
));
|
request.setChatUrl(candidate.getRuntimeConfig().getChatUrl());
|
request.setApiKey(candidate.getRuntimeConfig().getApiKey());
|
request.setModelName(candidate.getRuntimeConfig().getModelName());
|
for (AiChatMessage contextMessage : contextMessages) {
|
GatewayChatMessage item = new GatewayChatMessage();
|
item.setRole(contextMessage.getRole());
|
item.setContent(contextMessage.getContent());
|
request.getMessages().add(item);
|
}
|
return request;
|
}
|
|
/**
|
* 解析当前尝试应落到哪个路由编码,优先使用路由候选自带的 routeCode。
|
*/
|
private String resolveRouteCode(AiModelRouteRuntimeService.RouteCandidate candidate, AiChatStreamRequest request) {
|
if (candidate != null && candidate.getRouteCode() != null && !candidate.getRouteCode().trim().isEmpty()) {
|
return candidate.getRouteCode();
|
}
|
return AiSceneCode.SYSTEM_DIAGNOSE.equals(request.getSceneCode())
|
? AiSceneCode.SYSTEM_DIAGNOSE
|
: AiSceneCode.GENERAL_CHAT;
|
}
|
|
/**
|
* 从网关事件中解析响应时间,缺省时回退为当前时间。
|
*/
|
private Date parseResponseTime(JsonNode event) {
|
long millis = event.path("responseTime").asLong(0L);
|
return millis <= 0L ? new Date() : new Date(millis);
|
}
|
|
/**
|
* 构造会话初始化事件给前端。
|
*/
|
private Map<String, Object> buildSessionPayload(AiChatSession session) {
|
Map<String, Object> payload = new LinkedHashMap<>();
|
payload.put("sessionId", session.getId());
|
payload.put("title", session.getTitle());
|
payload.put("modelCode", session.getModelCode());
|
return payload;
|
}
|
|
/**
|
* 构造增量输出事件给前端。
|
*/
|
private Map<String, Object> buildDeltaPayload(AiChatSession session, String modelCode, String content) {
|
Map<String, Object> payload = new LinkedHashMap<>();
|
payload.put("sessionId", session.getId());
|
payload.put("modelCode", modelCode);
|
payload.put("content", content);
|
payload.put("timestamp", new Date().getTime());
|
return payload;
|
}
|
|
/**
|
* 构造流式完成事件给前端。
|
*/
|
private Map<String, Object> buildDonePayload(AiChatSession session, String modelCode, boolean stopped) {
|
Map<String, Object> payload = new LinkedHashMap<>();
|
payload.put("sessionId", session.getId());
|
payload.put("modelCode", modelCode);
|
payload.put("stopped", stopped);
|
return payload;
|
}
|
|
/**
|
* 构造错误事件给前端。
|
*/
|
private Map<String, Object> buildErrorPayload(String message) {
|
Map<String, Object> payload = new LinkedHashMap<>();
|
payload.put("message", message == null ? "AI服务异常" : message);
|
return payload;
|
}
|
|
/**
|
* 构造心跳事件给前端,用于维持长连接活性。
|
*/
|
private Map<String, Object> buildPingPayload(String modelCode) {
|
Map<String, Object> payload = new LinkedHashMap<>();
|
payload.put("modelCode", modelCode);
|
payload.put("timestamp", new Date().getTime());
|
return payload;
|
}
|
|
/**
|
* 判断异常链中是否包含线程中断类错误。
|
*/
|
private boolean isInterruptedError(Throwable throwable) {
|
Throwable current = throwable;
|
while (current != null) {
|
if (current instanceof InterruptedException || current instanceof InterruptedIOException) {
|
return true;
|
}
|
if (isInterruptedMessage(current.getMessage())) {
|
return true;
|
}
|
current = current.getCause();
|
}
|
return false;
|
}
|
|
private boolean isInterruptedMessage(String message) {
|
if (message == null || message.trim().isEmpty()) {
|
return false;
|
}
|
String normalized = message.toLowerCase();
|
return normalized.contains("interrupted")
|
|| normalized.contains("broken pipe")
|
|| normalized.contains("connection reset")
|
|| normalized.contains("forcibly closed");
|
}
|
|
private static class AttemptState {
|
private Boolean success;
|
private String actualModelCode;
|
private String errorMessage;
|
private boolean receivedDelta;
|
private boolean interrupted;
|
private boolean stopped;
|
private Date responseTime;
|
|
public Boolean getSuccess() {
|
return success;
|
}
|
|
public void setSuccess(Boolean success) {
|
this.success = success;
|
}
|
|
public String getActualModelCode() {
|
return actualModelCode;
|
}
|
|
public void setActualModelCode(String actualModelCode) {
|
this.actualModelCode = actualModelCode;
|
}
|
|
public String getErrorMessage() {
|
return errorMessage;
|
}
|
|
public void setErrorMessage(String errorMessage) {
|
this.errorMessage = errorMessage;
|
}
|
|
public boolean isReceivedDelta() {
|
return receivedDelta;
|
}
|
|
public void setReceivedDelta(boolean receivedDelta) {
|
this.receivedDelta = receivedDelta;
|
}
|
|
public boolean isInterrupted() {
|
return interrupted;
|
}
|
|
public void setInterrupted(boolean interrupted) {
|
this.interrupted = interrupted;
|
}
|
|
public boolean isStopped() {
|
return stopped;
|
}
|
|
public void setStopped(boolean stopped) {
|
this.stopped = stopped;
|
}
|
|
public Date getResponseTime() {
|
return responseTime;
|
}
|
|
public void setResponseTime(Date responseTime) {
|
this.responseTime = responseTime;
|
}
|
}
|
}
|