package com.zy.ai.service;
|
|
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
import com.alibaba.fastjson.JSON;
|
import com.zy.ai.entity.ChatCompletionRequest;
|
import com.zy.ai.entity.LlmRouteConfig;
|
import lombok.RequiredArgsConstructor;
|
import lombok.extern.slf4j.Slf4j;
|
import org.springframework.http.HttpHeaders;
|
import org.springframework.http.MediaType;
|
import org.springframework.http.ResponseEntity;
|
import org.springframework.stereotype.Service;
|
import org.springframework.web.client.RestClient;
|
import org.springframework.web.client.RestClientResponseException;
|
|
import java.util.ArrayList;
|
import java.util.Comparator;
|
import java.util.Collections;
|
import java.util.Date;
|
import java.util.HashMap;
|
import java.util.LinkedHashMap;
|
import java.util.List;
|
import java.util.Map;
|
|
@Slf4j
|
@Service
|
@RequiredArgsConstructor
|
public class LlmRoutingService {
|
|
private static final long CACHE_TTL_MS = 3000L;
|
private static final String DEFAULT_PROVIDER_TYPE = "OPENAI_COMPATIBLE";
|
private static final String DEFAULT_PROTOCOL_TYPE = "OPENAI_CHAT_COMPLETIONS";
|
private static final String DEFAULT_AUTH_TYPE = "BEARER";
|
|
private final LlmRouteConfigService llmRouteConfigService;
|
private final LlmSpringAiClientService llmSpringAiClientService;
|
private final RestClient routeTestClient = RestClient.builder().build();
|
|
private volatile List<LlmRouteConfig> allRouteCache = Collections.emptyList();
|
private volatile long cacheExpireAt = 0L;
|
private static final Comparator<LlmRouteConfig> ROUTE_ORDER = (a, b) -> {
|
int pa = a == null || a.getPriority() == null ? Integer.MAX_VALUE : a.getPriority();
|
int pb = b == null || b.getPriority() == null ? Integer.MAX_VALUE : b.getPriority();
|
if (pa != pb) return Integer.compare(pa, pb);
|
long ia = a == null || a.getId() == null ? Long.MAX_VALUE : a.getId();
|
long ib = b == null || b.getId() == null ? Long.MAX_VALUE : b.getId();
|
return Long.compare(ia, ib);
|
};
|
|
public void evictCache() {
|
cacheExpireAt = 0L;
|
}
|
|
public List<LlmRouteConfig> listAllOrdered() {
|
return new ArrayList<>(loadAllRoutes());
|
}
|
|
public List<LlmRouteConfig> listAvailableRoutes() {
|
Date now = new Date();
|
List<LlmRouteConfig> result = new ArrayList<>();
|
List<LlmRouteConfig> coolingRoutes = new ArrayList<>();
|
int total = 0;
|
int disabled = 0;
|
int invalid = 0;
|
for (LlmRouteConfig c : loadAllRoutes()) {
|
total++;
|
if (!isEnabled(c)) {
|
disabled++;
|
continue;
|
}
|
if (isBlank(c.getBaseUrl()) || isBlank(c.getApiKey()) || isBlank(c.getModel())) {
|
invalid++;
|
continue;
|
}
|
if (isCooling(c, now)) {
|
coolingRoutes.add(c);
|
continue;
|
}
|
result.add(c);
|
}
|
if (result.isEmpty() && !coolingRoutes.isEmpty()) {
|
// 避免所有路由都处于冷却时系统完全不可用,降级允许使用冷却路由
|
coolingRoutes.sort(ROUTE_ORDER);
|
log.warn("LLM 路由均处于冷却,降级启用冷却路由。cooling={}, total={}", coolingRoutes.size(), total);
|
return coolingRoutes;
|
}
|
result.sort(ROUTE_ORDER);
|
if (result.isEmpty()) {
|
log.warn("未找到可用 LLM 路由。total={}, disabled={}, invalid={}", total, disabled, invalid);
|
}
|
return result;
|
}
|
|
public void markSuccess(Long routeId) {
|
if (routeId == null) return;
|
try {
|
LlmRouteConfig db = llmRouteConfigService.getById(routeId);
|
if (db == null) return;
|
db.setSuccessCount(nvl(db.getSuccessCount()) + 1);
|
db.setConsecutiveFailCount(0);
|
db.setLastUsedTime(new Date());
|
db.setUpdateTime(new Date());
|
llmRouteConfigService.updateById(db);
|
evictCache();
|
} catch (Exception e) {
|
log.warn("更新路由成功状态失败, routeId={}", routeId, e);
|
}
|
}
|
|
public void markFailure(Long routeId, String errorText, boolean enterCooldown, Integer cooldownSeconds) {
|
if (routeId == null) return;
|
try {
|
LlmRouteConfig db = llmRouteConfigService.getById(routeId);
|
if (db == null) return;
|
Date now = new Date();
|
db.setFailCount(nvl(db.getFailCount()) + 1);
|
db.setConsecutiveFailCount(nvl(db.getConsecutiveFailCount()) + 1);
|
db.setLastFailTime(now);
|
db.setLastError(trimError(errorText));
|
if (enterCooldown) {
|
int sec = cooldownSeconds != null && cooldownSeconds > 0
|
? cooldownSeconds
|
: defaultCooldown(db.getCooldownSeconds());
|
db.setCooldownUntil(new Date(now.getTime() + sec * 1000L));
|
}
|
db.setUpdateTime(now);
|
llmRouteConfigService.updateById(db);
|
evictCache();
|
} catch (Exception e) {
|
log.warn("更新路由失败状态失败, routeId={}", routeId, e);
|
}
|
}
|
|
private int defaultCooldown(Integer sec) {
|
return sec == null || sec <= 0 ? 300 : sec;
|
}
|
|
private String trimError(String err) {
|
if (err == null) return null;
|
String x = err.replace("\n", " ").replace("\r", " ");
|
return x.length() > 500 ? x.substring(0, 500) : x;
|
}
|
|
private Integer nvl(Integer x) {
|
return x == null ? 0 : x;
|
}
|
|
private boolean isEnabled(LlmRouteConfig c) {
|
return c != null && c.getStatus() != null && c.getStatus() == 1;
|
}
|
|
private boolean isCooling(LlmRouteConfig c, Date now) {
|
return c != null && c.getCooldownUntil() != null && c.getCooldownUntil().after(now);
|
}
|
|
private List<LlmRouteConfig> loadAllRoutes() {
|
long now = System.currentTimeMillis();
|
if (now < cacheExpireAt && allRouteCache != null) {
|
return allRouteCache;
|
}
|
synchronized (this) {
|
now = System.currentTimeMillis();
|
if (now < cacheExpireAt && allRouteCache != null) {
|
return allRouteCache;
|
}
|
QueryWrapper<LlmRouteConfig> wrapper = new QueryWrapper<>();
|
wrapper.orderBy(true, true, "priority").orderBy(true, true, "id");
|
List<LlmRouteConfig> list = llmRouteConfigService.list(wrapper);
|
if (list == null) {
|
allRouteCache = Collections.emptyList();
|
} else {
|
list.sort(ROUTE_ORDER);
|
allRouteCache = list;
|
}
|
cacheExpireAt = System.currentTimeMillis() + CACHE_TTL_MS;
|
return allRouteCache;
|
}
|
}
|
|
private String safe(String s) {
|
return s == null ? "" : s.trim();
|
}
|
|
private boolean isBlank(String s) {
|
return s == null || s.trim().isEmpty();
|
}
|
|
public LlmRouteConfig fillAndNormalize(LlmRouteConfig cfg, boolean isCreate) {
|
Date now = new Date();
|
trimStringFields(cfg);
|
normalizeProtocolFields(cfg);
|
if (isBlank(cfg.getName())) {
|
cfg.setName("LLM_ROUTE_" + now.getTime());
|
}
|
if (cfg.getThinking() == null) {
|
cfg.setThinking((short) 0);
|
}
|
if (cfg.getPriority() == null) {
|
cfg.setPriority(100);
|
}
|
if (cfg.getStatus() == null) {
|
cfg.setStatus((short) 1);
|
}
|
if (cfg.getSwitchOnQuota() == null) {
|
cfg.setSwitchOnQuota((short) 1);
|
}
|
if (cfg.getSwitchOnError() == null) {
|
cfg.setSwitchOnError((short) 1);
|
}
|
if (cfg.getCooldownSeconds() == null || cfg.getCooldownSeconds() < 0) {
|
cfg.setCooldownSeconds(300);
|
}
|
if (cfg.getFailCount() == null) {
|
cfg.setFailCount(0);
|
}
|
if (cfg.getSuccessCount() == null) {
|
cfg.setSuccessCount(0);
|
}
|
if (cfg.getConsecutiveFailCount() == null) {
|
cfg.setConsecutiveFailCount(0);
|
}
|
if (isCreate) {
|
cfg.setCreateTime(now);
|
}
|
cfg.setUpdateTime(now);
|
return cfg;
|
}
|
|
public Map<String, Object> testRoute(LlmRouteConfig cfg) {
|
HashMap<String, Object> result = new HashMap<>();
|
long start = System.currentTimeMillis();
|
try {
|
trimStringFields(cfg);
|
normalizeProtocolFields(cfg);
|
TestHttpResult raw = testJavaRoute(cfg);
|
fillTestResult(result, raw, start);
|
} catch (Exception e) {
|
result.put("ok", false);
|
result.put("statusCode", -1);
|
result.put("latencyMs", System.currentTimeMillis() - start);
|
result.put("message", "测试异常: " + safe(e.getMessage()));
|
result.put("responseSnippet", "");
|
}
|
fillProtocolResult(result, cfg);
|
return result;
|
}
|
|
private void trimStringFields(LlmRouteConfig cfg) {
|
if (cfg == null) return;
|
cfg.setName(trimToNull(cfg.getName()));
|
cfg.setProviderType(trimToNull(cfg.getProviderType()));
|
cfg.setProtocolType(trimToNull(cfg.getProtocolType()));
|
cfg.setBaseUrl(trimToNull(cfg.getBaseUrl()));
|
cfg.setEndpointPath(trimToNull(cfg.getEndpointPath()));
|
cfg.setApiKey(trimToNull(cfg.getApiKey()));
|
cfg.setAuthType(trimToNull(cfg.getAuthType()));
|
cfg.setAuthHeaderName(trimToNull(cfg.getAuthHeaderName()));
|
cfg.setModel(trimToNull(cfg.getModel()));
|
cfg.setCapabilities(trimToNull(cfg.getCapabilities()));
|
cfg.setRequestOptions(trimToNull(cfg.getRequestOptions()));
|
cfg.setMemo(trimToNull(cfg.getMemo()));
|
}
|
|
private void normalizeProtocolFields(LlmRouteConfig cfg) {
|
if (cfg == null) return;
|
if (isBlank(cfg.getProviderType())) {
|
cfg.setProviderType(DEFAULT_PROVIDER_TYPE);
|
}
|
if (isBlank(cfg.getProtocolType())) {
|
cfg.setProtocolType(DEFAULT_PROTOCOL_TYPE);
|
}
|
if (isBlank(cfg.getAuthType())) {
|
cfg.setAuthType(DEFAULT_AUTH_TYPE);
|
}
|
}
|
|
private void fillProtocolResult(HashMap<String, Object> result, LlmRouteConfig cfg) {
|
if (cfg == null) return;
|
result.put("providerType", cfg.getProviderType());
|
result.put("protocolType", cfg.getProtocolType());
|
result.put("endpointPath", cfg.getEndpointPath());
|
result.put("authType", cfg.getAuthType());
|
result.put("authHeaderName", cfg.getAuthHeaderName());
|
}
|
|
private String trimToNull(String s) {
|
if (s == null) return null;
|
String trimmed = s.trim();
|
return trimmed.isEmpty() ? null : trimmed;
|
}
|
|
private void fillTestResult(HashMap<String, Object> result, TestHttpResult raw, long start) {
|
boolean ok = raw.statusCode >= 200 && raw.statusCode < 300;
|
result.put("ok", ok);
|
result.put("statusCode", raw.statusCode);
|
result.put("latencyMs", System.currentTimeMillis() - start);
|
result.put("message", ok ? "测试成功" : "测试失败");
|
result.put("responseSnippet", trimBody(raw.body));
|
}
|
|
private TestHttpResult testJavaRoute(LlmRouteConfig cfg) {
|
if ("OPENAI_RESPONSES".equalsIgnoreCase(safe(cfg.getProtocolType()))) {
|
return testResponsesRoute(cfg);
|
}
|
|
ChatCompletionRequest req = new ChatCompletionRequest();
|
req.setModel(cfg.getModel());
|
List<ChatCompletionRequest.Message> messages = new ArrayList<>();
|
ChatCompletionRequest.Message msg = new ChatCompletionRequest.Message();
|
msg.setRole("user");
|
msg.setContent("ping");
|
messages.add(msg);
|
req.setMessages(messages);
|
req.setStream(false);
|
req.setMax_tokens(8);
|
req.setTemperature(0D);
|
if (cfg.getThinking() != null && cfg.getThinking() == 1) {
|
ChatCompletionRequest.Thinking thinking = new ChatCompletionRequest.Thinking();
|
thinking.setType("enable");
|
req.setThinking(thinking);
|
}
|
try {
|
LlmSpringAiClientService.CompletionCallResult result =
|
llmSpringAiClientService.callCompletion(cfg.getBaseUrl(), cfg.getApiKey(), req);
|
return new TestHttpResult(result.getStatusCode(), result.getPayload());
|
} catch (Throwable ex) {
|
Integer statusCode = llmSpringAiClientService.statusCodeOf(ex);
|
String body = llmSpringAiClientService.responseBodyOf(ex, 300);
|
return new TestHttpResult(statusCode == null ? -1 : statusCode, safe(body != null ? body : ex.getMessage()));
|
}
|
}
|
|
private TestHttpResult testResponsesRoute(LlmRouteConfig cfg) {
|
LinkedHashMap<String, Object> body = new LinkedHashMap<>();
|
body.put("model", cfg.getModel());
|
body.put("input", "ping");
|
body.put("max_output_tokens", 8);
|
try {
|
ResponseEntity<String> response = routeTestClient.post()
|
.uri(responsesEndpointUrl(cfg))
|
.contentType(MediaType.APPLICATION_JSON)
|
.headers(headers -> applyRouteAuth(headers, cfg))
|
.body(JSON.toJSONString(body))
|
.retrieve()
|
.toEntity(String.class);
|
return new TestHttpResult(response.getStatusCode().value(), response.getBody());
|
} catch (RestClientResponseException ex) {
|
return new TestHttpResult(ex.getStatusCode().value(), trimBody(ex.getResponseBodyAsString()));
|
} catch (Throwable ex) {
|
return new TestHttpResult(-1, safe(ex.getMessage()));
|
}
|
}
|
|
private String responsesEndpointUrl(LlmRouteConfig cfg) {
|
String baseUrl = safe(cfg.getBaseUrl());
|
String endpointPath = safe(cfg.getEndpointPath());
|
if (endpointPath.isEmpty()) {
|
endpointPath = "/responses";
|
}
|
if (baseUrl.endsWith("/") && endpointPath.startsWith("/")) {
|
return baseUrl + endpointPath.substring(1);
|
}
|
if (!baseUrl.endsWith("/") && !endpointPath.startsWith("/")) {
|
return baseUrl + "/" + endpointPath;
|
}
|
return baseUrl + endpointPath;
|
}
|
|
private void applyRouteAuth(HttpHeaders headers, LlmRouteConfig cfg) {
|
String authType = safe(cfg.getAuthType()).toUpperCase();
|
if ("NONE".equals(authType) || isBlank(cfg.getApiKey())) {
|
return;
|
}
|
if ("API_KEY".equals(authType)) {
|
String headerName = safe(cfg.getAuthHeaderName());
|
headers.set(headerName.isEmpty() ? "X-API-Key" : headerName, cfg.getApiKey());
|
return;
|
}
|
headers.setBearerAuth(cfg.getApiKey());
|
}
|
|
private String trimBody(String body) {
|
String x = safe(body).replace("\r", " ").replace("\n", " ");
|
return x.length() > 300 ? x.substring(0, 300) : x;
|
}
|
|
private static class TestHttpResult {
|
private final int statusCode;
|
private final String body;
|
|
private TestHttpResult(int statusCode, String body) {
|
this.statusCode = statusCode;
|
this.body = body;
|
}
|
}
|
}
|