package com.vincent.rsf.ai.gateway.service;
|
|
import com.vincent.rsf.ai.gateway.config.AiGatewayProperties;
|
import com.vincent.rsf.ai.gateway.dto.GatewayChatMessage;
|
import com.vincent.rsf.ai.gateway.dto.GatewayChatRequest;
|
import com.fasterxml.jackson.databind.JsonNode;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
import org.springframework.stereotype.Service;
|
|
import javax.annotation.Resource;
|
import java.io.BufferedReader;
|
import java.io.InputStream;
|
import java.io.InputStreamReader;
|
import java.io.OutputStream;
|
import java.net.HttpURLConnection;
|
import java.net.URL;
|
import java.nio.charset.StandardCharsets;
|
import java.util.ArrayList;
|
import java.util.LinkedHashMap;
|
import java.util.List;
|
import java.util.Map;
|
|
@Service
|
public class AiGatewayService {
|
|
@Resource
|
private AiGatewayProperties aiGatewayProperties;
|
@Resource
|
private ObjectMapper objectMapper;
|
|
public interface EventConsumer {
|
void accept(GatewayStreamEvent event) throws Exception;
|
}
|
|
public void stream(GatewayChatRequest request, EventConsumer consumer) throws Exception {
|
AiGatewayProperties.ModelConfig modelConfig = resolveModel(request);
|
if (modelConfig == null || modelConfig.getChatUrl() == null || modelConfig.getChatUrl().trim().isEmpty()) {
|
mockStream(request, modelConfig, consumer);
|
return;
|
}
|
openAiCompatibleStream(request, modelConfig, consumer);
|
}
|
|
private AiGatewayProperties.ModelConfig resolveModel(GatewayChatRequest request) {
|
String modelCode = request.getModelCode();
|
String targetCode = (modelCode == null || modelCode.trim().isEmpty())
|
? aiGatewayProperties.getDefaultModelCode()
|
: modelCode;
|
for (AiGatewayProperties.ModelConfig model : aiGatewayProperties.getModels()) {
|
if (Boolean.TRUE.equals(model.getEnabled()) && targetCode.equals(model.getCode())) {
|
return mergeRequestOverride(model, request);
|
}
|
}
|
if ((request.getChatUrl() != null && !request.getChatUrl().trim().isEmpty())
|
|| (request.getModelName() != null && !request.getModelName().trim().isEmpty())) {
|
AiGatewayProperties.ModelConfig modelConfig = new AiGatewayProperties.ModelConfig();
|
modelConfig.setCode(targetCode);
|
modelConfig.setName(targetCode);
|
modelConfig.setProvider("custom");
|
modelConfig.setEnabled(true);
|
return mergeRequestOverride(modelConfig, request);
|
}
|
return null;
|
}
|
|
private AiGatewayProperties.ModelConfig mergeRequestOverride(AiGatewayProperties.ModelConfig source,
|
GatewayChatRequest request) {
|
AiGatewayProperties.ModelConfig target = new AiGatewayProperties.ModelConfig();
|
target.setCode(source.getCode());
|
target.setName(source.getName());
|
target.setProvider(source.getProvider());
|
target.setChatUrl(normalizeChatUrl(source.getChatUrl()));
|
target.setApiKey(source.getApiKey());
|
target.setModelName(source.getModelName());
|
target.setEnabled(source.getEnabled());
|
if (request.getChatUrl() != null && !request.getChatUrl().trim().isEmpty()) {
|
target.setChatUrl(normalizeChatUrl(request.getChatUrl().trim()));
|
}
|
if (request.getApiKey() != null && !request.getApiKey().trim().isEmpty()) {
|
target.setApiKey(request.getApiKey().trim());
|
}
|
if (request.getModelName() != null && !request.getModelName().trim().isEmpty()) {
|
target.setModelName(request.getModelName().trim());
|
}
|
return target;
|
}
|
|
private void mockStream(GatewayChatRequest request, AiGatewayProperties.ModelConfig modelConfig,
|
EventConsumer consumer) throws Exception {
|
String modelCode = modelConfig == null ? aiGatewayProperties.getDefaultModelCode() : modelConfig.getCode();
|
String lastQuestion = "";
|
List<GatewayChatMessage> messages = request.getMessages();
|
for (int i = messages.size() - 1; i >= 0; i--) {
|
GatewayChatMessage message = messages.get(i);
|
if ("user".equalsIgnoreCase(message.getRole())) {
|
lastQuestion = message.getContent();
|
break;
|
}
|
}
|
String answer = "当前为演示模式,模型[" + modelCode + "]已收到你的问题:" + lastQuestion;
|
for (char c : answer.toCharArray()) {
|
consumer.accept(new GatewayStreamEvent()
|
.setType("delta")
|
.setModelCode(modelCode)
|
.setContent(String.valueOf(c)));
|
Thread.sleep(20L);
|
}
|
consumer.accept(new GatewayStreamEvent()
|
.setType("done")
|
.setModelCode(modelCode));
|
}
|
|
private void openAiCompatibleStream(GatewayChatRequest request, AiGatewayProperties.ModelConfig modelConfig,
|
EventConsumer consumer) throws Exception {
|
HttpURLConnection connection = null;
|
try {
|
connection = (HttpURLConnection) new URL(modelConfig.getChatUrl()).openConnection();
|
connection.setConnectTimeout(aiGatewayProperties.getConnectTimeoutMillis());
|
connection.setReadTimeout(aiGatewayProperties.getReadTimeoutMillis());
|
connection.setRequestMethod("POST");
|
connection.setDoOutput(true);
|
connection.setRequestProperty("Content-Type", "application/json");
|
connection.setRequestProperty("Accept", "text/event-stream");
|
if (modelConfig.getApiKey() != null && !modelConfig.getApiKey().trim().isEmpty()) {
|
connection.setRequestProperty("Authorization", "Bearer " + modelConfig.getApiKey().trim());
|
}
|
|
Map<String, Object> body = new LinkedHashMap<>();
|
body.put("model", modelConfig.getModelName());
|
body.put("stream", true);
|
body.put("messages", buildMessages(request));
|
|
try (OutputStream outputStream = connection.getOutputStream()) {
|
outputStream.write(objectMapper.writeValueAsBytes(body));
|
outputStream.flush();
|
}
|
|
int statusCode = connection.getResponseCode();
|
InputStream inputStream = statusCode >= 400 ? connection.getErrorStream() : connection.getInputStream();
|
if (inputStream == null) {
|
consumer.accept(new GatewayStreamEvent()
|
.setType("error")
|
.setModelCode(modelConfig.getCode())
|
.setMessage("模型服务无响应"));
|
return;
|
}
|
if (statusCode >= 400) {
|
consumer.accept(new GatewayStreamEvent()
|
.setType("error")
|
.setModelCode(modelConfig.getCode())
|
.setMessage(readErrorMessage(inputStream, statusCode)));
|
return;
|
}
|
|
try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
|
String line;
|
while ((line = reader.readLine()) != null) {
|
if (line.trim().isEmpty() || !line.startsWith("data:")) {
|
continue;
|
}
|
String payload = line.substring(5).trim();
|
if ("[DONE]".equals(payload)) {
|
consumer.accept(new GatewayStreamEvent()
|
.setType("done")
|
.setModelCode(modelConfig.getCode()));
|
break;
|
}
|
JsonNode root = objectMapper.readTree(payload);
|
JsonNode choice = root.path("choices").path(0);
|
JsonNode delta = choice.path("delta");
|
JsonNode contentNode = delta.path("content");
|
if (!contentNode.isMissingNode() && !contentNode.isNull()) {
|
consumer.accept(new GatewayStreamEvent()
|
.setType("delta")
|
.setModelCode(modelConfig.getCode())
|
.setContent(contentNode.asText()));
|
}
|
JsonNode finishReason = choice.path("finish_reason");
|
if (!finishReason.isMissingNode() && !finishReason.isNull()) {
|
consumer.accept(new GatewayStreamEvent()
|
.setType("done")
|
.setModelCode(modelConfig.getCode()));
|
break;
|
}
|
}
|
}
|
} catch (Exception e) {
|
consumer.accept(new GatewayStreamEvent()
|
.setType("error")
|
.setModelCode(modelConfig.getCode())
|
.setMessage(e.getMessage()));
|
} finally {
|
if (connection != null) {
|
connection.disconnect();
|
}
|
}
|
}
|
|
private List<Map<String, String>> buildMessages(GatewayChatRequest request) {
|
List<Map<String, String>> output = new ArrayList<>();
|
if (request.getSystemPrompt() != null && !request.getSystemPrompt().trim().isEmpty()) {
|
Map<String, String> systemMessage = new LinkedHashMap<>();
|
systemMessage.put("role", "system");
|
systemMessage.put("content", request.getSystemPrompt());
|
output.add(systemMessage);
|
}
|
for (GatewayChatMessage message : request.getMessages()) {
|
Map<String, String> item = new LinkedHashMap<>();
|
item.put("role", message.getRole());
|
item.put("content", message.getContent());
|
output.add(item);
|
}
|
return output;
|
}
|
|
private String normalizeChatUrl(String chatUrl) {
|
if (chatUrl == null) {
|
return null;
|
}
|
String normalized = chatUrl.trim();
|
if (normalized.isEmpty()) {
|
return normalized;
|
}
|
if (normalized.endsWith("/chat/completions") || normalized.endsWith("/v1/chat/completions")) {
|
return normalized;
|
}
|
if (normalized.endsWith("/v1")) {
|
return normalized + "/chat/completions";
|
}
|
if (normalized.contains("/v1/")) {
|
return normalized;
|
}
|
if (normalized.endsWith("/")) {
|
return normalized + "v1/chat/completions";
|
}
|
return normalized + "/v1/chat/completions";
|
}
|
|
private String readErrorMessage(InputStream inputStream, int statusCode) {
|
try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
|
StringBuilder builder = new StringBuilder();
|
String line;
|
while ((line = reader.readLine()) != null) {
|
builder.append(line);
|
}
|
String body = builder.toString();
|
if (body.isEmpty()) {
|
return "模型服务调用失败,状态码:" + statusCode;
|
}
|
JsonNode root = objectMapper.readTree(body);
|
JsonNode errorNode = root.path("error");
|
if (!errorNode.isMissingNode() && !errorNode.isNull()) {
|
String message = errorNode.path("message").asText("");
|
if (!message.isEmpty()) {
|
return message;
|
}
|
}
|
if (root.path("message").isTextual()) {
|
return root.path("message").asText();
|
}
|
return body;
|
} catch (Exception ignore) {
|
return "模型服务调用失败,状态码:" + statusCode;
|
}
|
}
|
|
}
|