package com.vincent.rsf.ai.gateway.service; import com.vincent.rsf.ai.gateway.config.AiGatewayProperties; import com.vincent.rsf.ai.gateway.dto.GatewayChatMessage; import com.vincent.rsf.ai.gateway.dto.GatewayChatRequest; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import org.springframework.stereotype.Service; import javax.annotation.Resource; import java.io.BufferedReader; import java.io.InputStream; import java.io.InputStreamReader; import java.io.OutputStream; import java.net.HttpURLConnection; import java.net.URL; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @Service public class AiGatewayService { @Resource private AiGatewayProperties aiGatewayProperties; @Resource private ObjectMapper objectMapper; public interface EventConsumer { void accept(GatewayStreamEvent event) throws Exception; } public void stream(GatewayChatRequest request, EventConsumer consumer) throws Exception { AiGatewayProperties.ModelConfig modelConfig = resolveModel(request); if (modelConfig == null || modelConfig.getChatUrl() == null || modelConfig.getChatUrl().trim().isEmpty()) { mockStream(request, modelConfig, consumer); return; } openAiCompatibleStream(request, modelConfig, consumer); } private AiGatewayProperties.ModelConfig resolveModel(GatewayChatRequest request) { String modelCode = request.getModelCode(); String targetCode = (modelCode == null || modelCode.trim().isEmpty()) ? aiGatewayProperties.getDefaultModelCode() : modelCode; for (AiGatewayProperties.ModelConfig model : aiGatewayProperties.getModels()) { if (Boolean.TRUE.equals(model.getEnabled()) && targetCode.equals(model.getCode())) { return mergeRequestOverride(model, request); } } if ((request.getChatUrl() != null && !request.getChatUrl().trim().isEmpty()) || (request.getModelName() != null && !request.getModelName().trim().isEmpty())) { AiGatewayProperties.ModelConfig modelConfig = new AiGatewayProperties.ModelConfig(); modelConfig.setCode(targetCode); modelConfig.setName(targetCode); modelConfig.setProvider("custom"); modelConfig.setEnabled(true); return mergeRequestOverride(modelConfig, request); } return null; } private AiGatewayProperties.ModelConfig mergeRequestOverride(AiGatewayProperties.ModelConfig source, GatewayChatRequest request) { AiGatewayProperties.ModelConfig target = new AiGatewayProperties.ModelConfig(); target.setCode(source.getCode()); target.setName(source.getName()); target.setProvider(source.getProvider()); target.setChatUrl(normalizeChatUrl(source.getChatUrl())); target.setApiKey(source.getApiKey()); target.setModelName(source.getModelName()); target.setEnabled(source.getEnabled()); if (request.getChatUrl() != null && !request.getChatUrl().trim().isEmpty()) { target.setChatUrl(normalizeChatUrl(request.getChatUrl().trim())); } if (request.getApiKey() != null && !request.getApiKey().trim().isEmpty()) { target.setApiKey(request.getApiKey().trim()); } if (request.getModelName() != null && !request.getModelName().trim().isEmpty()) { target.setModelName(request.getModelName().trim()); } return target; } private void mockStream(GatewayChatRequest request, AiGatewayProperties.ModelConfig modelConfig, EventConsumer consumer) throws Exception { String modelCode = modelConfig == null ? aiGatewayProperties.getDefaultModelCode() : modelConfig.getCode(); String lastQuestion = ""; List messages = request.getMessages(); for (int i = messages.size() - 1; i >= 0; i--) { GatewayChatMessage message = messages.get(i); if ("user".equalsIgnoreCase(message.getRole())) { lastQuestion = message.getContent(); break; } } String answer = "当前为演示模式,模型[" + modelCode + "]已收到你的问题:" + lastQuestion; for (char c : answer.toCharArray()) { consumer.accept(new GatewayStreamEvent() .setType("delta") .setModelCode(modelCode) .setContent(String.valueOf(c))); Thread.sleep(20L); } consumer.accept(new GatewayStreamEvent() .setType("done") .setModelCode(modelCode)); } private void openAiCompatibleStream(GatewayChatRequest request, AiGatewayProperties.ModelConfig modelConfig, EventConsumer consumer) throws Exception { HttpURLConnection connection = null; try { connection = (HttpURLConnection) new URL(modelConfig.getChatUrl()).openConnection(); connection.setConnectTimeout(aiGatewayProperties.getConnectTimeoutMillis()); connection.setReadTimeout(aiGatewayProperties.getReadTimeoutMillis()); connection.setRequestMethod("POST"); connection.setDoOutput(true); connection.setRequestProperty("Content-Type", "application/json"); connection.setRequestProperty("Accept", "text/event-stream"); if (modelConfig.getApiKey() != null && !modelConfig.getApiKey().trim().isEmpty()) { connection.setRequestProperty("Authorization", "Bearer " + modelConfig.getApiKey().trim()); } Map body = new LinkedHashMap<>(); body.put("model", modelConfig.getModelName()); body.put("stream", true); body.put("messages", buildMessages(request)); try (OutputStream outputStream = connection.getOutputStream()) { outputStream.write(objectMapper.writeValueAsBytes(body)); outputStream.flush(); } int statusCode = connection.getResponseCode(); InputStream inputStream = statusCode >= 400 ? connection.getErrorStream() : connection.getInputStream(); if (inputStream == null) { consumer.accept(new GatewayStreamEvent() .setType("error") .setModelCode(modelConfig.getCode()) .setMessage("模型服务无响应")); return; } if (statusCode >= 400) { consumer.accept(new GatewayStreamEvent() .setType("error") .setModelCode(modelConfig.getCode()) .setMessage(readErrorMessage(inputStream, statusCode))); return; } try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) { String line; while ((line = reader.readLine()) != null) { if (line.trim().isEmpty() || !line.startsWith("data:")) { continue; } String payload = line.substring(5).trim(); if ("[DONE]".equals(payload)) { consumer.accept(new GatewayStreamEvent() .setType("done") .setModelCode(modelConfig.getCode())); break; } JsonNode root = objectMapper.readTree(payload); JsonNode choice = root.path("choices").path(0); JsonNode delta = choice.path("delta"); JsonNode contentNode = delta.path("content"); if (!contentNode.isMissingNode() && !contentNode.isNull()) { consumer.accept(new GatewayStreamEvent() .setType("delta") .setModelCode(modelConfig.getCode()) .setContent(contentNode.asText())); } JsonNode finishReason = choice.path("finish_reason"); if (!finishReason.isMissingNode() && !finishReason.isNull()) { consumer.accept(new GatewayStreamEvent() .setType("done") .setModelCode(modelConfig.getCode())); break; } } } } catch (Exception e) { consumer.accept(new GatewayStreamEvent() .setType("error") .setModelCode(modelConfig.getCode()) .setMessage(e.getMessage())); } finally { if (connection != null) { connection.disconnect(); } } } private List> buildMessages(GatewayChatRequest request) { List> output = new ArrayList<>(); if (request.getSystemPrompt() != null && !request.getSystemPrompt().trim().isEmpty()) { Map systemMessage = new LinkedHashMap<>(); systemMessage.put("role", "system"); systemMessage.put("content", request.getSystemPrompt()); output.add(systemMessage); } for (GatewayChatMessage message : request.getMessages()) { Map item = new LinkedHashMap<>(); item.put("role", message.getRole()); item.put("content", message.getContent()); output.add(item); } return output; } private String normalizeChatUrl(String chatUrl) { if (chatUrl == null) { return null; } String normalized = chatUrl.trim(); if (normalized.isEmpty()) { return normalized; } if (normalized.endsWith("/chat/completions") || normalized.endsWith("/v1/chat/completions")) { return normalized; } if (normalized.endsWith("/v1")) { return normalized + "/chat/completions"; } if (normalized.contains("/v1/")) { return normalized; } if (normalized.endsWith("/")) { return normalized + "v1/chat/completions"; } return normalized + "/v1/chat/completions"; } private String readErrorMessage(InputStream inputStream, int statusCode) { try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) { StringBuilder builder = new StringBuilder(); String line; while ((line = reader.readLine()) != null) { builder.append(line); } String body = builder.toString(); if (body.isEmpty()) { return "模型服务调用失败,状态码:" + statusCode; } JsonNode root = objectMapper.readTree(body); JsonNode errorNode = root.path("error"); if (!errorNode.isMissingNode() && !errorNode.isNull()) { String message = errorNode.path("message").asText(""); if (!message.isEmpty()) { return message; } } if (root.path("message").isTextual()) { return root.path("message").asText(); } return body; } catch (Exception ignore) { return "模型服务调用失败,状态码:" + statusCode; } } }