Junjie
3 天以前 63b01db83d9aad8a15276b4236a9a22e4aeef065
src/main/java/com/zy/ai/service/LlmRoutingService.java
@@ -1,21 +1,24 @@
package com.zy.ai.service;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.alibaba.fastjson.JSON;
import com.zy.ai.entity.ChatCompletionRequest;
import com.zy.ai.entity.LlmRouteConfig;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;
import org.springframework.web.client.RestClient;
import org.springframework.web.client.RestClientResponseException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@@ -25,8 +28,13 @@
public class LlmRoutingService {
    // Time-to-live in milliseconds. Presumably bounds how long allRouteCache is reused;
    // the code that reads it is outside this view — confirm.
    private static final long CACHE_TTL_MS = 3000L;
    // Provider type written by normalizeProtocolFields when a route leaves it blank.
    private static final String DEFAULT_PROVIDER_TYPE = "OPENAI_COMPATIBLE";
    // Protocol type written by normalizeProtocolFields when a route leaves it blank.
    private static final String DEFAULT_PROTOCOL_TYPE = "OPENAI_CHAT_COMPLETIONS";
    // Auth type written by normalizeProtocolFields when a route leaves it blank.
    private static final String DEFAULT_AUTH_TYPE = "BEARER";
    // Route configuration service; its call sites are outside this view.
    private final LlmRouteConfigService llmRouteConfigService;
    // Client used by testJavaRoute to issue the chat-completion probe and to extract
    // status code / response body from a failed call.
    private final LlmSpringAiClientService llmSpringAiClientService;
    // HTTP client used by testResponsesRoute to probe routes directly.
    private final RestClient routeTestClient = RestClient.builder().build();
    // NOTE(review): looks like a cached snapshot of all routes paired with cacheExpireAt;
    // the refresh logic is outside this view — confirm.
    private volatile List<LlmRouteConfig> allRouteCache = Collections.emptyList();
    // NOTE(review): presumably the epoch-millis instant after which allRouteCache is stale — confirm.
    private volatile long cacheExpireAt = 0L;
@@ -179,6 +187,8 @@
    public LlmRouteConfig fillAndNormalize(LlmRouteConfig cfg, boolean isCreate) {
        Date now = new Date();
        trimStringFields(cfg);
        normalizeProtocolFields(cfg);
        if (isBlank(cfg.getName())) {
            cfg.setName("LLM_ROUTE_" + now.getTime());
        }
@@ -220,6 +230,8 @@
        HashMap<String, Object> result = new HashMap<>();
        long start = System.currentTimeMillis();
        try {
            trimStringFields(cfg);
            normalizeProtocolFields(cfg);
            TestHttpResult raw = testJavaRoute(cfg);
            fillTestResult(result, raw, start);
        } catch (Exception e) {
@@ -229,7 +241,52 @@
            result.put("message", "测试异常: " + safe(e.getMessage()));
            result.put("responseSnippet", "");
        }
        fillProtocolResult(result, cfg);
        return result;
    }
    private void trimStringFields(LlmRouteConfig cfg) {
        if (cfg == null) return;
        cfg.setName(trimToNull(cfg.getName()));
        cfg.setProviderType(trimToNull(cfg.getProviderType()));
        cfg.setProtocolType(trimToNull(cfg.getProtocolType()));
        cfg.setBaseUrl(trimToNull(cfg.getBaseUrl()));
        cfg.setEndpointPath(trimToNull(cfg.getEndpointPath()));
        cfg.setApiKey(trimToNull(cfg.getApiKey()));
        cfg.setAuthType(trimToNull(cfg.getAuthType()));
        cfg.setAuthHeaderName(trimToNull(cfg.getAuthHeaderName()));
        cfg.setModel(trimToNull(cfg.getModel()));
        cfg.setCapabilities(trimToNull(cfg.getCapabilities()));
        cfg.setRequestOptions(trimToNull(cfg.getRequestOptions()));
        cfg.setMemo(trimToNull(cfg.getMemo()));
    }
    private void normalizeProtocolFields(LlmRouteConfig cfg) {
        if (cfg == null) return;
        if (isBlank(cfg.getProviderType())) {
            cfg.setProviderType(DEFAULT_PROVIDER_TYPE);
        }
        if (isBlank(cfg.getProtocolType())) {
            cfg.setProtocolType(DEFAULT_PROTOCOL_TYPE);
        }
        if (isBlank(cfg.getAuthType())) {
            cfg.setAuthType(DEFAULT_AUTH_TYPE);
        }
    }
    private void fillProtocolResult(HashMap<String, Object> result, LlmRouteConfig cfg) {
        if (cfg == null) return;
        result.put("providerType", cfg.getProviderType());
        result.put("protocolType", cfg.getProtocolType());
        result.put("endpointPath", cfg.getEndpointPath());
        result.put("authType", cfg.getAuthType());
        result.put("authHeaderName", cfg.getAuthHeaderName());
    }
    private String trimToNull(String s) {
        if (s == null) return null;
        String trimmed = s.trim();
        return trimmed.isEmpty() ? null : trimmed;
    }
    private void fillTestResult(HashMap<String, Object> result, TestHttpResult raw, long start) {
@@ -242,31 +299,84 @@
    }
    private TestHttpResult testJavaRoute(LlmRouteConfig cfg) {
        HashMap<String, Object> req = new HashMap<>();
        req.put("model", cfg.getModel());
        List<Map<String, String>> messages = new ArrayList<>();
        HashMap<String, String> msg = new HashMap<>();
        msg.put("role", "user");
        msg.put("content", "ping");
        messages.add(msg);
        req.put("messages", messages);
        req.put("stream", false);
        req.put("max_tokens", 8);
        req.put("temperature", 0);
        if ("OPENAI_RESPONSES".equalsIgnoreCase(safe(cfg.getProtocolType()))) {
            return testResponsesRoute(cfg);
        }
        WebClient client = WebClient.builder().baseUrl(cfg.getBaseUrl()).build();
        return client.post()
                .uri("/chat/completions")
                .header(HttpHeaders.AUTHORIZATION, "Bearer " + cfg.getApiKey())
                .contentType(MediaType.APPLICATION_JSON)
                .accept(MediaType.APPLICATION_JSON, MediaType.TEXT_EVENT_STREAM)
                .bodyValue(req)
                .exchangeToMono(resp -> resp.bodyToMono(String.class)
                        .defaultIfEmpty("")
                        .map(body -> new TestHttpResult(resp.rawStatusCode(), body)))
                .timeout(Duration.ofSeconds(12))
                .onErrorResume(ex -> Mono.just(new TestHttpResult(-1, safe(ex.getMessage()))))
                .block();
        ChatCompletionRequest req = new ChatCompletionRequest();
        req.setModel(cfg.getModel());
        List<ChatCompletionRequest.Message> messages = new ArrayList<>();
        ChatCompletionRequest.Message msg = new ChatCompletionRequest.Message();
        msg.setRole("user");
        msg.setContent("ping");
        messages.add(msg);
        req.setMessages(messages);
        req.setStream(false);
        req.setMax_tokens(8);
        req.setTemperature(0D);
        if (cfg.getThinking() != null && cfg.getThinking() == 1) {
            ChatCompletionRequest.Thinking thinking = new ChatCompletionRequest.Thinking();
            thinking.setType("enable");
            req.setThinking(thinking);
        }
        try {
            LlmSpringAiClientService.CompletionCallResult result =
                    llmSpringAiClientService.callCompletion(cfg.getBaseUrl(), cfg.getApiKey(), req);
            return new TestHttpResult(result.getStatusCode(), result.getPayload());
        } catch (Throwable ex) {
            Integer statusCode = llmSpringAiClientService.statusCodeOf(ex);
            String body = llmSpringAiClientService.responseBodyOf(ex, 300);
            return new TestHttpResult(statusCode == null ? -1 : statusCode, safe(body != null ? body : ex.getMessage()));
        }
    }
    private TestHttpResult testResponsesRoute(LlmRouteConfig cfg) {
        LinkedHashMap<String, Object> body = new LinkedHashMap<>();
        body.put("model", cfg.getModel());
        body.put("input", "ping");
        body.put("max_output_tokens", 8);
        try {
            ResponseEntity<String> response = routeTestClient.post()
                    .uri(responsesEndpointUrl(cfg))
                    .contentType(MediaType.APPLICATION_JSON)
                    .headers(headers -> applyRouteAuth(headers, cfg))
                    .body(JSON.toJSONString(body))
                    .retrieve()
                    .toEntity(String.class);
            return new TestHttpResult(response.getStatusCode().value(), response.getBody());
        } catch (RestClientResponseException ex) {
            return new TestHttpResult(ex.getStatusCode().value(), trimBody(ex.getResponseBodyAsString()));
        } catch (Throwable ex) {
            return new TestHttpResult(-1, safe(ex.getMessage()));
        }
    }
    private String responsesEndpointUrl(LlmRouteConfig cfg) {
        String baseUrl = safe(cfg.getBaseUrl());
        String endpointPath = safe(cfg.getEndpointPath());
        if (endpointPath.isEmpty()) {
            endpointPath = "/responses";
        }
        if (baseUrl.endsWith("/") && endpointPath.startsWith("/")) {
            return baseUrl + endpointPath.substring(1);
        }
        if (!baseUrl.endsWith("/") && !endpointPath.startsWith("/")) {
            return baseUrl + "/" + endpointPath;
        }
        return baseUrl + endpointPath;
    }
    private void applyRouteAuth(HttpHeaders headers, LlmRouteConfig cfg) {
        String authType = safe(cfg.getAuthType()).toUpperCase();
        if ("NONE".equals(authType) || isBlank(cfg.getApiKey())) {
            return;
        }
        if ("API_KEY".equals(authType)) {
            String headerName = safe(cfg.getAuthHeaderName());
            headers.set(headerName.isEmpty() ? "X-API-Key" : headerName, cfg.getApiKey());
            return;
        }
        headers.setBearerAuth(cfg.getApiKey());
    }
    private String trimBody(String body) {