package com.vincent.rsf.ai.gateway.service; import com.vincent.rsf.ai.gateway.config.AiGatewayProperties; import com.vincent.rsf.ai.gateway.dto.GatewayChatMessage; import com.vincent.rsf.ai.gateway.dto.GatewayChatRequest; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.stereotype.Service; import javax.annotation.Resource; import java.io.BufferedReader; import java.io.InputStream; import java.io.InputStreamReader; import java.io.InterruptedIOException; import java.io.OutputStream; import java.net.HttpURLConnection; import java.net.URL; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @Service public class AiGatewayService { private static final Logger logger = LoggerFactory.getLogger(AiGatewayService.class); @Resource private AiGatewayProperties aiGatewayProperties; @Resource private ObjectMapper objectMapper; public interface EventConsumer { void accept(GatewayStreamEvent event) throws Exception; } public void stream(GatewayChatRequest request, EventConsumer consumer) throws Exception { AiGatewayProperties.ModelConfig modelConfig = resolveModel(request); logger.info("AI gateway stream start: sessionId={}, routeCode={}, attemptNo={}, requestModelCode={}, resolvedModelCode={}, provider={}", request.getSessionId(), request.getRouteCode(), request.getAttemptNo(), request.getModelCode(), modelConfig == null ? null : modelConfig.getCode(), modelConfig == null ? null : modelConfig.getProvider()); if (modelConfig == null || modelConfig.getChatUrl() == null || modelConfig.getChatUrl().trim().isEmpty()) { logger.info("AI gateway use mock stream: sessionId={}, routeCode={}, attemptNo={}, modelCode={}, provider={}", request.getSessionId(), request.getRouteCode(), request.getAttemptNo(), modelConfig == null ? request.getModelCode() : modelConfig.getCode(), modelConfig == null ? "mock" : modelConfig.getProvider()); mockStream(request, modelConfig, consumer); return; } openAiCompatibleStream(request, modelConfig, consumer); } private AiGatewayProperties.ModelConfig resolveModel(GatewayChatRequest request) { String modelCode = request.getModelCode(); String targetCode = (modelCode == null || modelCode.trim().isEmpty()) ? aiGatewayProperties.getDefaultModelCode() : modelCode; for (AiGatewayProperties.ModelConfig model : aiGatewayProperties.getModels()) { if (Boolean.TRUE.equals(model.getEnabled()) && targetCode.equals(model.getCode())) { return mergeRequestOverride(model, request); } } if ((request.getChatUrl() != null && !request.getChatUrl().trim().isEmpty()) || (request.getModelName() != null && !request.getModelName().trim().isEmpty())) { AiGatewayProperties.ModelConfig modelConfig = new AiGatewayProperties.ModelConfig(); modelConfig.setCode(targetCode); modelConfig.setName(targetCode); modelConfig.setProvider("custom"); modelConfig.setEnabled(true); return mergeRequestOverride(modelConfig, request); } return null; } private AiGatewayProperties.ModelConfig mergeRequestOverride(AiGatewayProperties.ModelConfig source, GatewayChatRequest request) { AiGatewayProperties.ModelConfig target = new AiGatewayProperties.ModelConfig(); target.setCode(source.getCode()); target.setName(source.getName()); target.setProvider(source.getProvider()); target.setChatUrl(normalizeChatUrl(source.getChatUrl())); target.setApiKey(source.getApiKey()); target.setModelName(source.getModelName()); target.setEnabled(source.getEnabled()); if (request.getChatUrl() != null && !request.getChatUrl().trim().isEmpty()) { target.setChatUrl(normalizeChatUrl(request.getChatUrl().trim())); } if (request.getApiKey() != null && !request.getApiKey().trim().isEmpty()) { target.setApiKey(request.getApiKey().trim()); } if (request.getModelName() != null && !request.getModelName().trim().isEmpty()) { target.setModelName(request.getModelName().trim()); } return target; } private void mockStream(GatewayChatRequest request, AiGatewayProperties.ModelConfig modelConfig, EventConsumer consumer) throws Exception { long requestTime = System.currentTimeMillis(); String modelCode = modelConfig == null ? aiGatewayProperties.getDefaultModelCode() : modelConfig.getCode(); String lastQuestion = ""; List messages = request.getMessages(); for (int i = messages.size() - 1; i >= 0; i--) { GatewayChatMessage message = messages.get(i); if ("user".equalsIgnoreCase(message.getRole())) { lastQuestion = message.getContent(); break; } } String answer = "当前为演示模式,模型[" + modelCode + "]已收到你的问题:" + lastQuestion; logger.info("AI gateway mock stream emitting response: sessionId={}, routeCode={}, attemptNo={}, modelCode={}, answerLength={}", request.getSessionId(), request.getRouteCode(), request.getAttemptNo(), modelCode, answer.length()); for (char c : answer.toCharArray()) { consumer.accept(new GatewayStreamEvent() .setType("delta") .setModelCode(modelCode) .setContent(String.valueOf(c))); try { Thread.sleep(20L); } catch (InterruptedException e) { Thread.currentThread().interrupt(); return; } } consumer.accept(new GatewayStreamEvent() .setType("done") .setModelCode(modelCode) .setSuccess(true) .setRequestTime(requestTime) .setResponseTime(System.currentTimeMillis()) .setDurationMs(System.currentTimeMillis() - requestTime)); logger.info("AI gateway mock stream completed: sessionId={}, routeCode={}, attemptNo={}, modelCode={}, durationMs={}", request.getSessionId(), request.getRouteCode(), request.getAttemptNo(), modelCode, System.currentTimeMillis() - requestTime); } private void openAiCompatibleStream(GatewayChatRequest request, AiGatewayProperties.ModelConfig modelConfig, EventConsumer consumer) throws Exception { HttpURLConnection connection = null; long requestTime = System.currentTimeMillis(); boolean terminalEventSent = false; int eventLineCount = 0; int deltaCount = 0; int contentChars = 0; boolean firstDeltaLogged = false; String normalizedUrl = modelConfig == null ? null : modelConfig.getChatUrl(); try { logger.info("AI gateway opening upstream stream: sessionId={}, routeCode={}, attemptNo={}, modelCode={}, provider={}, url={}, modelName={}", request.getSessionId(), request.getRouteCode(), request.getAttemptNo(), modelConfig == null ? null : modelConfig.getCode(), modelConfig == null ? null : modelConfig.getProvider(), normalizedUrl, modelConfig == null ? null : modelConfig.getModelName()); connection = (HttpURLConnection) new URL(modelConfig.getChatUrl()).openConnection(); connection.setConnectTimeout(aiGatewayProperties.getConnectTimeoutMillis()); connection.setReadTimeout(aiGatewayProperties.getReadTimeoutMillis()); connection.setRequestMethod("POST"); connection.setDoOutput(true); connection.setRequestProperty("Content-Type", "application/json"); connection.setRequestProperty("Accept", "text/event-stream"); if (modelConfig.getApiKey() != null && !modelConfig.getApiKey().trim().isEmpty()) { connection.setRequestProperty("Authorization", "Bearer " + modelConfig.getApiKey().trim()); } Map body = new LinkedHashMap<>(); body.put("model", modelConfig.getModelName()); body.put("stream", true); body.put("messages", buildMessages(request)); try (OutputStream outputStream = connection.getOutputStream()) { outputStream.write(objectMapper.writeValueAsBytes(body)); outputStream.flush(); } int statusCode = connection.getResponseCode(); logger.info("AI gateway upstream response received: sessionId={}, routeCode={}, attemptNo={}, modelCode={}, statusCode={}", request.getSessionId(), request.getRouteCode(), request.getAttemptNo(), modelConfig.getCode(), statusCode); InputStream inputStream = statusCode >= 400 ? connection.getErrorStream() : connection.getInputStream(); if (inputStream == null) { logger.warn("AI gateway upstream returned empty stream: sessionId={}, routeCode={}, attemptNo={}, modelCode={}, url={}, statusCode={}", request.getSessionId(), request.getRouteCode(), request.getAttemptNo(), modelConfig.getCode(), normalizedUrl, statusCode); consumer.accept(new GatewayStreamEvent() .setType("error") .setModelCode(modelConfig.getCode()) .setMessage("模型服务无响应") .setSuccess(false) .setRequestTime(requestTime) .setResponseTime(System.currentTimeMillis()) .setDurationMs(System.currentTimeMillis() - requestTime)); terminalEventSent = true; return; } if (statusCode >= 400) { logger.warn("AI gateway upstream http error: sessionId={}, routeCode={}, attemptNo={}, modelCode={}, url={}, statusCode={}", request.getSessionId(), request.getRouteCode(), request.getAttemptNo(), modelConfig.getCode(), normalizedUrl, statusCode); consumer.accept(new GatewayStreamEvent() .setType("error") .setModelCode(modelConfig.getCode()) .setMessage(readErrorMessage(inputStream, statusCode)) .setSuccess(false) .setRequestTime(requestTime) .setResponseTime(System.currentTimeMillis()) .setDurationMs(System.currentTimeMillis() - requestTime)); terminalEventSent = true; return; } try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) { String line; while ((line = reader.readLine()) != null) { if (line.trim().isEmpty() || !line.startsWith("data:")) { continue; } eventLineCount++; String payload = line.substring(5).trim(); if ("[DONE]".equals(payload)) { long responseTime = System.currentTimeMillis(); logger.info("AI gateway upstream done marker received: sessionId={}, routeCode={}, attemptNo={}, modelCode={}, eventLines={}, deltaCount={}, contentChars={}, durationMs={}", request.getSessionId(), request.getRouteCode(), request.getAttemptNo(), modelConfig.getCode(), eventLineCount, deltaCount, contentChars, responseTime - requestTime); consumer.accept(new GatewayStreamEvent() .setType("done") .setModelCode(modelConfig.getCode()) .setSuccess(true) .setRequestTime(requestTime) .setResponseTime(responseTime) .setDurationMs(responseTime - requestTime)); terminalEventSent = true; break; } JsonNode root = objectMapper.readTree(payload); JsonNode choice = root.path("choices").path(0); JsonNode delta = choice.path("delta"); JsonNode contentNode = delta.path("content"); if (!contentNode.isMissingNode() && !contentNode.isNull()) { String content = contentNode.asText(); deltaCount++; contentChars += content.length(); if (!firstDeltaLogged) { logger.info("AI gateway upstream first delta received: sessionId={}, routeCode={}, attemptNo={}, modelCode={}, afterMs={}, sampleLength={}", request.getSessionId(), request.getRouteCode(), request.getAttemptNo(), modelConfig.getCode(), System.currentTimeMillis() - requestTime, content.length()); firstDeltaLogged = true; } consumer.accept(new GatewayStreamEvent() .setType("delta") .setModelCode(modelConfig.getCode()) .setContent(content)); } JsonNode finishReason = choice.path("finish_reason"); if (!finishReason.isMissingNode() && !finishReason.isNull()) { long responseTime = System.currentTimeMillis(); logger.info("AI gateway upstream finish_reason received: sessionId={}, routeCode={}, attemptNo={}, modelCode={}, finishReason={}, eventLines={}, deltaCount={}, contentChars={}, durationMs={}", request.getSessionId(), request.getRouteCode(), request.getAttemptNo(), modelConfig.getCode(), finishReason.asText(), eventLineCount, deltaCount, contentChars, responseTime - requestTime); consumer.accept(new GatewayStreamEvent() .setType("done") .setModelCode(modelConfig.getCode()) .setSuccess(true) .setRequestTime(requestTime) .setResponseTime(responseTime) .setDurationMs(responseTime - requestTime)); terminalEventSent = true; break; } } } if (!terminalEventSent) { long responseTime = System.currentTimeMillis(); logger.warn("AI gateway upstream ended without terminal event: sessionId={}, routeCode={}, attemptNo={}, modelCode={}, url={}, eventLines={}, deltaCount={}, contentChars={}, durationMs={}", request.getSessionId(), request.getRouteCode(), request.getAttemptNo(), modelConfig.getCode(), normalizedUrl, eventLineCount, deltaCount, contentChars, responseTime - requestTime); consumer.accept(new GatewayStreamEvent() .setType("error") .setModelCode(modelConfig.getCode()) .setMessage("模型流异常中断") .setSuccess(false) .setRequestTime(requestTime) .setResponseTime(responseTime) .setDurationMs(responseTime - requestTime)); } } catch (Exception e) { if (isInterruptedError(e)) { logger.warn("AI gateway upstream interrupted: sessionId={}, routeCode={}, attemptNo={}, modelCode={}, url={}, stage={}, message={}", request.getSessionId(), request.getRouteCode(), request.getAttemptNo(), modelConfig == null ? null : modelConfig.getCode(), normalizedUrl, terminalEventSent ? "after_terminal" : "streaming", e.getMessage()); if (e instanceof InterruptedException || e instanceof InterruptedIOException) { Thread.currentThread().interrupt(); } return; } logger.error("AI gateway upstream exception: sessionId={}, routeCode={}, attemptNo={}, modelCode={}, url={}, eventLines={}, deltaCount={}, contentChars={}, message={}", request.getSessionId(), request.getRouteCode(), request.getAttemptNo(), modelConfig == null ? null : modelConfig.getCode(), normalizedUrl, eventLineCount, deltaCount, contentChars, e.getMessage(), e); consumer.accept(new GatewayStreamEvent() .setType("error") .setModelCode(modelConfig.getCode()) .setMessage(e.getMessage()) .setSuccess(false) .setRequestTime(requestTime) .setResponseTime(System.currentTimeMillis()) .setDurationMs(System.currentTimeMillis() - requestTime)); } finally { if (connection != null) { connection.disconnect(); } logger.info("AI gateway upstream stream closed: sessionId={}, routeCode={}, attemptNo={}, modelCode={}, terminalEventSent={}, eventLines={}, deltaCount={}, contentChars={}", request.getSessionId(), request.getRouteCode(), request.getAttemptNo(), modelConfig == null ? null : modelConfig.getCode(), terminalEventSent, eventLineCount, deltaCount, contentChars); } } private List> buildMessages(GatewayChatRequest request) { List> output = new ArrayList<>(); if (request.getSystemPrompt() != null && !request.getSystemPrompt().trim().isEmpty()) { Map systemMessage = new LinkedHashMap<>(); systemMessage.put("role", "system"); systemMessage.put("content", request.getSystemPrompt()); output.add(systemMessage); } for (GatewayChatMessage message : request.getMessages()) { Map item = new LinkedHashMap<>(); item.put("role", message.getRole()); item.put("content", message.getContent()); output.add(item); } return output; } private String normalizeChatUrl(String chatUrl) { if (chatUrl == null) { return null; } String normalized = chatUrl.trim(); if (normalized.isEmpty()) { return normalized; } if (normalized.endsWith("/chat/completions") || normalized.endsWith("/v1/chat/completions")) { return normalized; } if (normalized.endsWith("/v1")) { return normalized + "/chat/completions"; } if (normalized.contains("/v1/")) { return normalized; } if (normalized.endsWith("/")) { return normalized + "v1/chat/completions"; } return normalized + "/v1/chat/completions"; } private String readErrorMessage(InputStream inputStream, int statusCode) { try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) { StringBuilder builder = new StringBuilder(); String line; while ((line = reader.readLine()) != null) { builder.append(line); } String body = builder.toString(); if (body.isEmpty()) { return "模型服务调用失败,状态码:" + statusCode; } JsonNode root = objectMapper.readTree(body); JsonNode errorNode = root.path("error"); if (!errorNode.isMissingNode() && !errorNode.isNull()) { String message = errorNode.path("message").asText(""); if (!message.isEmpty()) { return message; } } if (root.path("message").isTextual()) { return root.path("message").asText(); } return body; } catch (Exception ignore) { return "模型服务调用失败,状态码:" + statusCode; } } private boolean isInterruptedError(Throwable throwable) { Throwable current = throwable; while (current != null) { if (current instanceof InterruptedException || current instanceof InterruptedIOException) { return true; } String message = current.getMessage(); if (message != null) { String normalized = message.toLowerCase(); if (normalized.contains("interrupted") || normalized.contains("broken pipe") || normalized.contains("connection reset") || normalized.contains("forcibly closed")) { return true; } } current = current.getCause(); } return false; } }