| | |
| | | package com.zy.ai.service; |
| | | |
| | | import com.baomidou.mybatisplus.mapper.EntityWrapper; |
| | | import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; |
| | | import com.alibaba.fastjson.JSON; |
| | | import com.zy.ai.entity.ChatCompletionRequest; |
| | | import com.zy.ai.entity.LlmRouteConfig; |
| | | import lombok.RequiredArgsConstructor; |
| | | import lombok.extern.slf4j.Slf4j; |
| | | import org.springframework.http.HttpHeaders; |
| | | import org.springframework.http.MediaType; |
| | | import org.springframework.http.ResponseEntity; |
| | | import org.springframework.stereotype.Service; |
| | | import org.springframework.web.reactive.function.client.WebClient; |
| | | import reactor.core.publisher.Mono; |
| | | import org.springframework.web.client.RestClient; |
| | | import org.springframework.web.client.RestClientResponseException; |
| | | |
| | | import java.time.Duration; |
| | | import java.util.ArrayList; |
| | | import java.util.Comparator; |
| | | import java.util.Collections; |
| | | import java.util.Date; |
| | | import java.util.HashMap; |
| | | import java.util.LinkedHashMap; |
| | | import java.util.List; |
| | | import java.util.Map; |
| | | |
| | |
| | | public class LlmRoutingService { |
| | | |
| | | private static final long CACHE_TTL_MS = 3000L; |
| | | private static final String DEFAULT_PROVIDER_TYPE = "OPENAI_COMPATIBLE"; |
| | | private static final String DEFAULT_PROTOCOL_TYPE = "OPENAI_CHAT_COMPLETIONS"; |
| | | private static final String DEFAULT_AUTH_TYPE = "BEARER"; |
| | | |
| | | private final LlmRouteConfigService llmRouteConfigService; |
| | | private final LlmSpringAiClientService llmSpringAiClientService; |
| | | private final RestClient routeTestClient = RestClient.builder().build(); |
| | | |
| | | private volatile List<LlmRouteConfig> allRouteCache = Collections.emptyList(); |
| | | private volatile long cacheExpireAt = 0L; |
| | |
| | | public void markSuccess(Long routeId) { |
| | | if (routeId == null) return; |
| | | try { |
| | | LlmRouteConfig db = llmRouteConfigService.selectById(routeId); |
| | | LlmRouteConfig db = llmRouteConfigService.getById(routeId); |
| | | if (db == null) return; |
| | | db.setSuccessCount(nvl(db.getSuccessCount()) + 1); |
| | | db.setConsecutiveFailCount(0); |
| | |
| | | public void markFailure(Long routeId, String errorText, boolean enterCooldown, Integer cooldownSeconds) { |
| | | if (routeId == null) return; |
| | | try { |
| | | LlmRouteConfig db = llmRouteConfigService.selectById(routeId); |
| | | LlmRouteConfig db = llmRouteConfigService.getById(routeId); |
| | | if (db == null) return; |
| | | Date now = new Date(); |
| | | db.setFailCount(nvl(db.getFailCount()) + 1); |
| | |
| | | if (now < cacheExpireAt && allRouteCache != null) { |
| | | return allRouteCache; |
| | | } |
| | | EntityWrapper<LlmRouteConfig> wrapper = new EntityWrapper<>(); |
| | | wrapper.orderBy("priority", true).orderBy("id", true); |
| | | List<LlmRouteConfig> list = llmRouteConfigService.selectList(wrapper); |
| | | QueryWrapper<LlmRouteConfig> wrapper = new QueryWrapper<>(); |
| | | wrapper.orderBy(true, true, "priority").orderBy(true, true, "id"); |
| | | List<LlmRouteConfig> list = llmRouteConfigService.list(wrapper); |
| | | if (list == null) { |
| | | allRouteCache = Collections.emptyList(); |
| | | } else { |
| | |
| | | |
| | | public LlmRouteConfig fillAndNormalize(LlmRouteConfig cfg, boolean isCreate) { |
| | | Date now = new Date(); |
| | | trimStringFields(cfg); |
| | | normalizeProtocolFields(cfg); |
| | | if (isBlank(cfg.getName())) { |
| | | cfg.setName("LLM_ROUTE_" + now.getTime()); |
| | | } |
| | |
| | | HashMap<String, Object> result = new HashMap<>(); |
| | | long start = System.currentTimeMillis(); |
| | | try { |
| | | trimStringFields(cfg); |
| | | normalizeProtocolFields(cfg); |
| | | TestHttpResult raw = testJavaRoute(cfg); |
| | | fillTestResult(result, raw, start); |
| | | } catch (Exception e) { |
| | |
| | | result.put("message", "测试异常: " + safe(e.getMessage())); |
| | | result.put("responseSnippet", ""); |
| | | } |
| | | fillProtocolResult(result, cfg); |
| | | return result; |
| | | } |
| | | |
| | | private void trimStringFields(LlmRouteConfig cfg) { |
| | | if (cfg == null) return; |
| | | cfg.setName(trimToNull(cfg.getName())); |
| | | cfg.setProviderType(trimToNull(cfg.getProviderType())); |
| | | cfg.setProtocolType(trimToNull(cfg.getProtocolType())); |
| | | cfg.setBaseUrl(trimToNull(cfg.getBaseUrl())); |
| | | cfg.setEndpointPath(trimToNull(cfg.getEndpointPath())); |
| | | cfg.setApiKey(trimToNull(cfg.getApiKey())); |
| | | cfg.setAuthType(trimToNull(cfg.getAuthType())); |
| | | cfg.setAuthHeaderName(trimToNull(cfg.getAuthHeaderName())); |
| | | cfg.setModel(trimToNull(cfg.getModel())); |
| | | cfg.setCapabilities(trimToNull(cfg.getCapabilities())); |
| | | cfg.setRequestOptions(trimToNull(cfg.getRequestOptions())); |
| | | cfg.setMemo(trimToNull(cfg.getMemo())); |
| | | } |
| | | |
| | | private void normalizeProtocolFields(LlmRouteConfig cfg) { |
| | | if (cfg == null) return; |
| | | if (isBlank(cfg.getProviderType())) { |
| | | cfg.setProviderType(DEFAULT_PROVIDER_TYPE); |
| | | } |
| | | if (isBlank(cfg.getProtocolType())) { |
| | | cfg.setProtocolType(DEFAULT_PROTOCOL_TYPE); |
| | | } |
| | | if (isBlank(cfg.getAuthType())) { |
| | | cfg.setAuthType(DEFAULT_AUTH_TYPE); |
| | | } |
| | | } |
| | | |
| | | private void fillProtocolResult(HashMap<String, Object> result, LlmRouteConfig cfg) { |
| | | if (cfg == null) return; |
| | | result.put("providerType", cfg.getProviderType()); |
| | | result.put("protocolType", cfg.getProtocolType()); |
| | | result.put("endpointPath", cfg.getEndpointPath()); |
| | | result.put("authType", cfg.getAuthType()); |
| | | result.put("authHeaderName", cfg.getAuthHeaderName()); |
| | | } |
| | | |
| | | private String trimToNull(String s) { |
| | | if (s == null) return null; |
| | | String trimmed = s.trim(); |
| | | return trimmed.isEmpty() ? null : trimmed; |
| | | } |
| | | |
| | | private void fillTestResult(HashMap<String, Object> result, TestHttpResult raw, long start) { |
| | |
| | | } |
| | | |
| | | private TestHttpResult testJavaRoute(LlmRouteConfig cfg) { |
| | | HashMap<String, Object> req = new HashMap<>(); |
| | | req.put("model", cfg.getModel()); |
| | | List<Map<String, String>> messages = new ArrayList<>(); |
| | | HashMap<String, String> msg = new HashMap<>(); |
| | | msg.put("role", "user"); |
| | | msg.put("content", "ping"); |
| | | messages.add(msg); |
| | | req.put("messages", messages); |
| | | req.put("stream", false); |
| | | req.put("max_tokens", 8); |
| | | req.put("temperature", 0); |
| | | if ("OPENAI_RESPONSES".equalsIgnoreCase(safe(cfg.getProtocolType()))) { |
| | | return testResponsesRoute(cfg); |
| | | } |
| | | |
| | | WebClient client = WebClient.builder().baseUrl(cfg.getBaseUrl()).build(); |
| | | return client.post() |
| | | .uri("/chat/completions") |
| | | .header(HttpHeaders.AUTHORIZATION, "Bearer " + cfg.getApiKey()) |
| | | .contentType(MediaType.APPLICATION_JSON) |
| | | .accept(MediaType.APPLICATION_JSON, MediaType.TEXT_EVENT_STREAM) |
| | | .bodyValue(req) |
| | | .exchangeToMono(resp -> resp.bodyToMono(String.class) |
| | | .defaultIfEmpty("") |
| | | .map(body -> new TestHttpResult(resp.rawStatusCode(), body))) |
| | | .timeout(Duration.ofSeconds(12)) |
| | | .onErrorResume(ex -> Mono.just(new TestHttpResult(-1, safe(ex.getMessage())))) |
| | | .block(); |
| | | ChatCompletionRequest req = new ChatCompletionRequest(); |
| | | req.setModel(cfg.getModel()); |
| | | List<ChatCompletionRequest.Message> messages = new ArrayList<>(); |
| | | ChatCompletionRequest.Message msg = new ChatCompletionRequest.Message(); |
| | | msg.setRole("user"); |
| | | msg.setContent("ping"); |
| | | messages.add(msg); |
| | | req.setMessages(messages); |
| | | req.setStream(false); |
| | | req.setMax_tokens(8); |
| | | req.setTemperature(0D); |
| | | if (cfg.getThinking() != null && cfg.getThinking() == 1) { |
| | | ChatCompletionRequest.Thinking thinking = new ChatCompletionRequest.Thinking(); |
| | | thinking.setType("enable"); |
| | | req.setThinking(thinking); |
| | | } |
| | | try { |
| | | LlmSpringAiClientService.CompletionCallResult result = |
| | | llmSpringAiClientService.callCompletion(cfg.getBaseUrl(), cfg.getApiKey(), req); |
| | | return new TestHttpResult(result.getStatusCode(), result.getPayload()); |
| | | } catch (Throwable ex) { |
| | | Integer statusCode = llmSpringAiClientService.statusCodeOf(ex); |
| | | String body = llmSpringAiClientService.responseBodyOf(ex, 300); |
| | | return new TestHttpResult(statusCode == null ? -1 : statusCode, safe(body != null ? body : ex.getMessage())); |
| | | } |
| | | } |
| | | |
| | | private TestHttpResult testResponsesRoute(LlmRouteConfig cfg) { |
| | | LinkedHashMap<String, Object> body = new LinkedHashMap<>(); |
| | | body.put("model", cfg.getModel()); |
| | | body.put("input", "ping"); |
| | | body.put("max_output_tokens", 8); |
| | | try { |
| | | ResponseEntity<String> response = routeTestClient.post() |
| | | .uri(responsesEndpointUrl(cfg)) |
| | | .contentType(MediaType.APPLICATION_JSON) |
| | | .headers(headers -> applyRouteAuth(headers, cfg)) |
| | | .body(JSON.toJSONString(body)) |
| | | .retrieve() |
| | | .toEntity(String.class); |
| | | return new TestHttpResult(response.getStatusCode().value(), response.getBody()); |
| | | } catch (RestClientResponseException ex) { |
| | | return new TestHttpResult(ex.getStatusCode().value(), trimBody(ex.getResponseBodyAsString())); |
| | | } catch (Throwable ex) { |
| | | return new TestHttpResult(-1, safe(ex.getMessage())); |
| | | } |
| | | } |
| | | |
| | | private String responsesEndpointUrl(LlmRouteConfig cfg) { |
| | | String baseUrl = safe(cfg.getBaseUrl()); |
| | | String endpointPath = safe(cfg.getEndpointPath()); |
| | | if (endpointPath.isEmpty()) { |
| | | endpointPath = "/responses"; |
| | | } |
| | | if (baseUrl.endsWith("/") && endpointPath.startsWith("/")) { |
| | | return baseUrl + endpointPath.substring(1); |
| | | } |
| | | if (!baseUrl.endsWith("/") && !endpointPath.startsWith("/")) { |
| | | return baseUrl + "/" + endpointPath; |
| | | } |
| | | return baseUrl + endpointPath; |
| | | } |
| | | |
| | | private void applyRouteAuth(HttpHeaders headers, LlmRouteConfig cfg) { |
| | | String authType = safe(cfg.getAuthType()).toUpperCase(); |
| | | if ("NONE".equals(authType) || isBlank(cfg.getApiKey())) { |
| | | return; |
| | | } |
| | | if ("API_KEY".equals(authType)) { |
| | | String headerName = safe(cfg.getAuthHeaderName()); |
| | | headers.set(headerName.isEmpty() ? "X-API-Key" : headerName, cfg.getApiKey()); |
| | | return; |
| | | } |
| | | headers.setBearerAuth(cfg.getApiKey()); |
| | | } |
| | | |
| | | private String trimBody(String body) { |