Junjie
Yesterday a4f07b2a0ddb6c210e05afbbb491feeb466203e7
src/main/java/com/zy/ai/service/LlmChatService.java
@@ -5,18 +5,28 @@
import com.alibaba.fastjson.JSONObject;
import com.zy.ai.entity.ChatCompletionRequest;
import com.zy.ai.entity.ChatCompletionResponse;
import com.zy.ai.entity.LlmCallLog;
import com.zy.ai.entity.LlmRouteConfig;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestClientResponseException;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.reactive.function.client.WebClientResponseException;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
@@ -27,7 +37,10 @@
@RequiredArgsConstructor
public class LlmChatService {
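    // Persisted request/response text is capped at this length to keep call-log rows bounded.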
    private static final int LOG_TEXT_LIMIT = 16000;
    private final LlmRoutingService llmRoutingService;
    private final LlmCallLogService llmCallLogService;
    @Value("${llm.base-url:}")
    private String fallbackBaseUrl;
@@ -54,7 +67,7 @@
        req.setMax_tokens(maxTokens != null ? maxTokens : 1024);
        req.setStream(false);
        ChatCompletionResponse response = complete(req, "chat");
        if (response == null ||
                response.getChoices() == null ||
@@ -81,13 +94,20 @@
            req.setTools(tools);
            req.setTool_choice("auto");
        }
        return complete(req, tools != null && !tools.isEmpty() ? "chat_completion_tools" : "chat_completion");
    }
    public ChatCompletionResponse complete(ChatCompletionRequest req) {
        return complete(req, "completion");
    }
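    // Core non-streaming path: walk the resolved routes in priority order, record every
    // attempt under a shared traceId, and fail over to the next route when policy allows.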
    private ChatCompletionResponse complete(ChatCompletionRequest req, String scene) {
        String traceId = nextTraceId();
        List<ResolvedRoute> routes = resolveRoutes();
        if (routes.isEmpty()) {
            log.error("调用 LLM 失败: 未配置可用 LLM 路由");
            recordCall(traceId, scene, false, 1, null, false, null, 0L, req, null, "none",
                    new RuntimeException("未配置可用 LLM 路由"), "no_route");
            return null;
        }
@@ -95,19 +115,39 @@
        for (int i = 0; i < routes.size(); i++) {
            ResolvedRoute route = routes.get(i);
            boolean hasNext = i < routes.size() - 1;
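            // Build the per-route request before the try block so the catch path can log it.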
            ChatCompletionRequest routeReq = applyRoute(cloneRequest(req), route, false);
            long start = System.currentTimeMillis();
            try {
                CompletionCallResult callResult = callCompletion(route, routeReq);
                ChatCompletionResponse resp = callResult.response;
                if (!isValidCompletion(resp)) {
                    RuntimeException ex = new RuntimeException("LLM 响应为空");
                    boolean canSwitch = shouldSwitch(route, false);
                    markFailure(route, ex, canSwitch);
                    recordCall(traceId, scene, false, i + 1, route, false, callResult.statusCode,
                            System.currentTimeMillis() - start, routeReq, callResult.payload, "error", ex,
                            "invalid_completion");
                    if (hasNext && canSwitch) {
                        log.warn("LLM 切换到下一路由, current={}, reason={}", route.tag(), ex.getMessage());
                        continue;
                    }
                    log.error("调用 LLM 失败, route={}", route.tag(), ex);
                    last = ex;
                    break;
                }
                markSuccess(route);
                recordCall(traceId, scene, false, i + 1, route, true, callResult.statusCode,
                        System.currentTimeMillis() - start, routeReq, buildResponseText(resp, callResult.payload),
                        "none", null, null);
                return resp;
            } catch (Throwable ex) {
                last = ex;
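                // Quota exhaustion (429 / quota keywords) and generic errors can carry
                // different per-route switch policies (switchOnQuota vs switchOnError).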
                boolean quota = isQuotaExhausted(ex);
                boolean canSwitch = shouldSwitch(route, quota);
                markFailure(route, ex, canSwitch);
                recordCall(traceId, scene, false, i + 1, route, false, statusCodeOf(ex),
                        System.currentTimeMillis() - start, routeReq, responseBodyOf(ex),
                        quota ? "quota" : "error", ex, null);
                if (hasNext && canSwitch) {
                    log.warn("LLM 切换到下一路由, current={}, reason={}", route.tag(), errorText(ex));
                    continue;
@@ -136,7 +176,7 @@
        req.setMax_tokens(maxTokens != null ? maxTokens : 1024);
        req.setStream(true);
        streamWithFailover(req, onChunk, onComplete, onError, "chat_stream");
    }
    public void chatStreamWithTools(List<ChatCompletionRequest.Message> messages,
@@ -155,19 +195,23 @@
            req.setTools(tools);
            req.setTool_choice("auto");
        }
        streamWithFailover(req, onChunk, onComplete, onError, tools != null && !tools.isEmpty() ? "chat_stream_tools" : "chat_stream");
    }
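    // Streaming analogue of complete(...): resolve routes once, then recurse through
    // attemptStream so a failed route can hand over to the next one.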
    private void streamWithFailover(ChatCompletionRequest req,
                                    Consumer<String> onChunk,
                                    Runnable onComplete,
                                    Consumer<Throwable> onError,
                                    String scene) {
        String traceId = nextTraceId();
        List<ResolvedRoute> routes = resolveRoutes();
        if (routes.isEmpty()) {
            recordCall(traceId, scene, true, 1, null, false, null, 0L, req, null, "none",
                    new RuntimeException("未配置可用 LLM 路由"), "no_route");
            if (onError != null) onError.accept(new RuntimeException("未配置可用 LLM 路由"));
            return;
        }
        attemptStream(routes, 0, req, onChunk, onComplete, onError, traceId, scene);
    }
    private void attemptStream(List<ResolvedRoute> routes,
@@ -175,7 +219,9 @@
                               ChatCompletionRequest req,
                               Consumer<String> onChunk,
                               Runnable onComplete,
                               Consumer<Throwable> onError,
                               String traceId,
                               String scene) {
        if (index >= routes.size()) {
            if (onError != null) onError.accept(new RuntimeException("LLM 路由全部失败"));
            return;
@@ -183,6 +229,8 @@
        ResolvedRoute route = routes.get(index);
        ChatCompletionRequest routeReq = applyRoute(cloneRequest(req), route, true);
        long start = System.currentTimeMillis();
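        // Buffer the emitted content (capped at LOG_TEXT_LIMIT) so the call log can
        // record what the stream actually produced.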
        StringBuilder outputBuffer = new StringBuilder();
        AtomicBoolean doneSeen = new AtomicBoolean(false);
        AtomicBoolean errorSeen = new AtomicBoolean(false);
@@ -216,8 +264,15 @@
        drain.setDaemon(true);
        drain.start();
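        // Requests with tools stay on the raw SSE/WebClient path; plain chat requests
        // can stream through Spring AI's typed client, which yields ready-to-use text chunks.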
        boolean springAiStreaming = canUseSpringAi(routeReq);
        Flux<String> streamSource = springAiStreaming ? streamFluxWithSpringAi(route, routeReq) : streamFlux(route, routeReq);
        streamSource.subscribe(payload -> {
            if (payload == null || payload.isEmpty()) return;
            if (springAiStreaming) {
                queue.offer(payload);
                appendLimited(outputBuffer, payload);
                return;
            }
            String[] events = payload.split("\\r?\\n\\r?\\n");
            for (String part : events) {
                String s = part;
@@ -240,6 +295,7 @@
                            String content = delta.getString("content");
                            if (content != null) {
                                queue.offer(content);
                                appendLimited(outputBuffer, content);
                            }
                        }
                    }
@@ -253,9 +309,12 @@
            boolean quota = isQuotaExhausted(err);
            boolean canSwitch = shouldSwitch(route, quota);
            markFailure(route, err, canSwitch);
            recordCall(traceId, scene, true, index + 1, route, false, statusCodeOf(err),
                    System.currentTimeMillis() - start, routeReq, outputBuffer.toString(),
                    quota ? "quota" : "error", err, "emitted=" + emitted.get());
            if (!emitted.get() && canSwitch && index < routes.size() - 1) {
                log.warn("LLM 路由失败,自动切换,current={}, reason={}", route.tag(), errorText(err));
                attemptStream(routes, index + 1, req, onChunk, onComplete, onError, traceId, scene);
                return;
            }
            if (onError != null) onError.accept(err);
@@ -266,14 +325,20 @@
                doneSeen.set(true);
                boolean canSwitch = shouldSwitch(route, false);
                markFailure(route, ex, canSwitch);
                recordCall(traceId, scene, true, index + 1, route, false, 200,
                        System.currentTimeMillis() - start, routeReq, outputBuffer.toString(),
                        "error", ex, "unexpected_stream_end");
                if (!emitted.get() && canSwitch && index < routes.size() - 1) {
                    log.warn("LLM 路由流异常完成,自动切换,current={}", route.tag());
                    attemptStream(routes, index + 1, req, onChunk, onComplete, onError, traceId, scene);
                } else {
                    if (onError != null) onError.accept(ex);
                }
            } else {
                markSuccess(route);
                recordCall(traceId, scene, true, index + 1, route, true, 200,
                        System.currentTimeMillis() - start, routeReq, outputBuffer.toString(),
                        "none", null, null);
                doneSeen.set(true);
            }
        });
@@ -299,7 +364,26 @@
                .doOnError(ex -> log.error("调用 LLM 流式失败, route={}", route.tag(), ex));
    }
    private Flux<String> streamFluxWithSpringAi(ResolvedRoute route, ChatCompletionRequest req) {
        OpenAiApi api = buildOpenAiApi(route);
        OpenAiApi.ChatCompletionRequest springReq = buildSpringAiRequest(route, req, true);
        return api.chatCompletionStream(springReq)
                .flatMapIterable(chunk -> chunk == null || chunk.choices() == null ? List.<OpenAiApi.ChatCompletionChunk.ChunkChoice>of() : chunk.choices())
                .map(OpenAiApi.ChatCompletionChunk.ChunkChoice::delta)
                .filter(Objects::nonNull)
                .map(this::extractSpringAiContent)
                .filter(text -> text != null && !text.isEmpty())
                .doOnError(ex -> log.error("调用 Spring AI 流式失败, route={}", route.tag(), ex));
    }
    private CompletionCallResult callCompletion(ResolvedRoute route, ChatCompletionRequest req) {
        if (canUseSpringAi(req)) {
            return callCompletionWithSpringAi(route, req);
        }
        return callCompletionWithWebClient(route, req);
    }
    private CompletionCallResult callCompletionWithWebClient(ResolvedRoute route, ChatCompletionRequest req) {
        WebClient client = WebClient.builder().baseUrl(route.baseUrl).build();
        RawCompletionResult raw = client.post()
                .uri("/chat/completions")
@@ -318,7 +402,17 @@
        if (raw.statusCode < 200 || raw.statusCode >= 300) {
            throw new LlmRouteException(raw.statusCode, raw.payload);
        }
        return new CompletionCallResult(raw.statusCode, raw.payload, parseCompletion(raw.payload));
    }
    private CompletionCallResult callCompletionWithSpringAi(ResolvedRoute route, ChatCompletionRequest req) {
        OpenAiApi api = buildOpenAiApi(route);
        OpenAiApi.ChatCompletionRequest springReq = buildSpringAiRequest(route, req, false);
        ResponseEntity<OpenAiApi.ChatCompletion> entity = api.chatCompletionEntity(springReq);
        OpenAiApi.ChatCompletion body = entity.getBody();
        return new CompletionCallResult(entity.getStatusCode().value(),
                body == null ? null : JSON.toJSONString(body),
                toLegacyResponse(body));
    }
    private ChatCompletionRequest applyRoute(ChatCompletionRequest req, ResolvedRoute route, boolean stream) {
@@ -365,6 +459,10 @@
        return quota ? route.switchOnQuota : route.switchOnError;
    }
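    // The Spring AI adapter below does not map tool definitions, so requests carrying
    // tools keep the legacy WebClient path.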
    private boolean canUseSpringAi(ChatCompletionRequest req) {
        return req != null && (req.getTools() == null || req.getTools().isEmpty());
    }
    private void markSuccess(ResolvedRoute route) {
        if (route.id != null) {
            llmRoutingService.markSuccess(route.id);
@@ -387,14 +485,32 @@
            }
            return "status=" + e.statusCode + ", body=" + body;
        }
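        // Cover both Spring HTTP stacks: RestClientResponseException from blocking calls
        // and WebClientResponseException from the reactive WebClient.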
        if (ex instanceof RestClientResponseException) {
            RestClientResponseException e = (RestClientResponseException) ex;
            String body = e.getResponseBodyAsString();
            if (body != null && body.length() > 240) {
                body = body.substring(0, 240);
            }
            return "status=" + e.getStatusCode().value() + ", body=" + body;
        }
        if (ex instanceof WebClientResponseException) {
            WebClientResponseException e = (WebClientResponseException) ex;
            String body = e.getResponseBodyAsString();
            if (body != null && body.length() > 240) {
                body = body.substring(0, 240);
            }
            return "status=" + e.getStatusCode().value() + ", body=" + body;
        }
        return ex.getMessage() == null ? ex.toString() : ex.getMessage();
    }
    private boolean isQuotaExhausted(Throwable ex) {
        Integer status = statusCodeOf(ex);
        if (status != null && status == 429) {
            return true;
        }
        String text = responseBodyOf(ex);
        text = text == null ? "" : text.toLowerCase(Locale.ROOT);
        return text.contains("insufficient_quota")
                || text.contains("quota")
                || text.contains("余额")
@@ -517,6 +633,276 @@
        return r;
    }
    private String nextTraceId() {
        return UUID.randomUUID().toString().replace("-", "");
    }
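    // Append text to a log buffer, stopping once LOG_TEXT_LIMIT characters are captured.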
    private void appendLimited(StringBuilder sb, String text) {
        if (sb == null || text == null || text.isEmpty()) {
            return;
        }
        int remain = LOG_TEXT_LIMIT - sb.length();
        if (remain <= 0) {
            return;
        }
        if (text.length() <= remain) {
            sb.append(text);
        } else {
            sb.append(text, 0, remain);
        }
    }
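    // Best-effort extraction of the HTTP status from the exception types the clients throw.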
    private Integer statusCodeOf(Throwable ex) {
        if (ex instanceof LlmRouteException) {
            return ((LlmRouteException) ex).statusCode;
        }
        if (ex instanceof RestClientResponseException) {
            return ((RestClientResponseException) ex).getStatusCode().value();
        }
        if (ex instanceof WebClientResponseException) {
            return ((WebClientResponseException) ex).getStatusCode().value();
        }
        return null;
    }
    private String responseBodyOf(Throwable ex) {
        if (ex instanceof LlmRouteException) {
            return cut(((LlmRouteException) ex).body, LOG_TEXT_LIMIT);
        }
        if (ex instanceof RestClientResponseException) {
            return cut(((RestClientResponseException) ex).getResponseBodyAsString(), LOG_TEXT_LIMIT);
        }
        if (ex instanceof WebClientResponseException) {
            return cut(((WebClientResponseException) ex).getResponseBodyAsString(), LOG_TEXT_LIMIT);
        }
        return null;
    }
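    // Prefer the assistant message content for logging; fall back to serialized
    // tool calls, then to the raw response payload.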
    private String buildResponseText(ChatCompletionResponse resp, String fallbackPayload) {
        if (resp != null && resp.getChoices() != null && !resp.getChoices().isEmpty()
                && resp.getChoices().get(0) != null && resp.getChoices().get(0).getMessage() != null) {
            ChatCompletionRequest.Message m = resp.getChoices().get(0).getMessage();
            if (!isBlank(m.getContent())) {
                return cut(m.getContent(), LOG_TEXT_LIMIT);
            }
            if (m.getTool_calls() != null && !m.getTool_calls().isEmpty()) {
                return cut(JSON.toJSONString(m), LOG_TEXT_LIMIT);
            }
        }
        return cut(fallbackPayload, LOG_TEXT_LIMIT);
    }
    private String safeName(Throwable ex) {
        return ex == null ? null : ex.getClass().getSimpleName();
    }
    private OpenAiApi buildOpenAiApi(ResolvedRoute route) {
        return OpenAiApi.builder()
                .baseUrl(route.baseUrl)
                .apiKey(route.apiKey)
                .build();
    }
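    // Map the legacy request onto Spring AI's positional ChatCompletionRequest record;
    // optional fields the service does not use are passed as null.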
    private OpenAiApi.ChatCompletionRequest buildSpringAiRequest(ResolvedRoute route,
                                                                 ChatCompletionRequest req,
                                                                 boolean stream) {
        HashMap<String, Object> extraBody = new HashMap<>();
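        // Forward the vendor-specific "thinking" switch through the extra request body
        // when the route or the request enables it.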
        if (route.thinkingEnabled || req.getThinking() != null) {
            HashMap<String, Object> thinking = new HashMap<>();
            thinking.put("type", req.getThinking() != null && req.getThinking().getType() != null
                    ? req.getThinking().getType()
                    : "enable");
            extraBody.put("thinking", thinking);
        }
        return new OpenAiApi.ChatCompletionRequest(
                toSpringAiMessages(req.getMessages()),
                route.model,
                null,
                null,
                null,
                null,
                null,
                null,
                req.getMax_tokens(),
                null,
                1,
                null,
                null,
                null,
                null,
                null,
                null,
                null,
                stream,
                stream ? OpenAiApi.ChatCompletionRequest.StreamOptions.INCLUDE_USAGE : null,
                req.getTemperature(),
                null,
                null,
                null,
                null,
                null,
                null,
                null,
                null,
                null,
                null,
                extraBody.isEmpty() ? null : extraBody
        );
    }
    private List<OpenAiApi.ChatCompletionMessage> toSpringAiMessages(List<ChatCompletionRequest.Message> messages) {
        ArrayList<OpenAiApi.ChatCompletionMessage> result = new ArrayList<>();
        if (messages == null) {
            return result;
        }
        for (ChatCompletionRequest.Message message : messages) {
            if (message == null) {
                continue;
            }
            result.add(new OpenAiApi.ChatCompletionMessage(
                    message.getContent(),
                    toSpringAiRole(message.getRole())
            ));
        }
        return result;
    }
    private OpenAiApi.ChatCompletionMessage.Role toSpringAiRole(String role) {
        if (role == null) {
            return OpenAiApi.ChatCompletionMessage.Role.USER;
        }
        switch (role.trim().toLowerCase(Locale.ROOT)) {
            case "system":
                return OpenAiApi.ChatCompletionMessage.Role.SYSTEM;
            case "assistant":
                return OpenAiApi.ChatCompletionMessage.Role.ASSISTANT;
            case "tool":
                return OpenAiApi.ChatCompletionMessage.Role.TOOL;
            default:
                return OpenAiApi.ChatCompletionMessage.Role.USER;
        }
    }
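    // Convert Spring AI's immutable ChatCompletion record back into the project's
    // legacy response entity so callers are unaffected by which client was used.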
    private ChatCompletionResponse toLegacyResponse(OpenAiApi.ChatCompletion completion) {
        if (completion == null) {
            return null;
        }
        ChatCompletionResponse response = new ChatCompletionResponse();
        response.setId(completion.id());
        response.setCreated(completion.created());
        response.setObjectName(completion.object());
        if (completion.usage() != null) {
            ChatCompletionResponse.Usage usage = new ChatCompletionResponse.Usage();
            usage.setPromptTokens(completion.usage().promptTokens());
            usage.setCompletionTokens(completion.usage().completionTokens());
            usage.setTotalTokens(completion.usage().totalTokens());
            response.setUsage(usage);
        }
        if (completion.choices() != null) {
            ArrayList<ChatCompletionResponse.Choice> choices = new ArrayList<>();
            for (OpenAiApi.ChatCompletion.Choice choice : completion.choices()) {
                ChatCompletionResponse.Choice item = new ChatCompletionResponse.Choice();
                item.setIndex(choice.index());
                if (choice.finishReason() != null) {
                    item.setFinishReason(choice.finishReason().name().toLowerCase(Locale.ROOT));
                }
                item.setMessage(toLegacyMessage(choice.message()));
                choices.add(item);
            }
            response.setChoices(choices);
        }
        return response;
    }
    private ChatCompletionRequest.Message toLegacyMessage(OpenAiApi.ChatCompletionMessage message) {
        if (message == null) {
            return null;
        }
        ChatCompletionRequest.Message result = new ChatCompletionRequest.Message();
        result.setContent(extractSpringAiContent(message));
        if (message.role() != null) {
            result.setRole(message.role().name().toLowerCase(Locale.ROOT));
        }
        result.setName(message.name());
        result.setTool_call_id(message.toolCallId());
        return result;
    }
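    // rawContent may be a plain String or a list of MediaContent parts; extract text either way.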
    private String extractSpringAiContent(OpenAiApi.ChatCompletionMessage message) {
        if (message == null || message.rawContent() == null) {
            return null;
        }
        Object content = message.rawContent();
        if (content instanceof String) {
            return (String) content;
        }
        if (content instanceof List) {
            try {
                @SuppressWarnings("unchecked")
                List<OpenAiApi.ChatCompletionMessage.MediaContent> media =
                        (List<OpenAiApi.ChatCompletionMessage.MediaContent>) content;
                return OpenAiApi.getTextContent(media);
            } catch (ClassCastException ignore) {
            }
        }
        return String.valueOf(content);
    }
    private String cut(String text, int maxLen) {
        if (text == null) return null;
        String clean = text.replace("\r", " ");
        return clean.length() > maxLen ? clean.substring(0, maxLen) : clean;
    }
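    // One row per attempt: values are trimmed to column limits, and persistence uses
    // saveIgnoreError so logging failures never break the call path.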
    private void recordCall(String traceId,
                            String scene,
                            boolean stream,
                            int attemptNo,
                            ResolvedRoute route,
                            boolean success,
                            Integer httpStatus,
                            long latencyMs,
                            ChatCompletionRequest req,
                            String response,
                            String switchMode,
                            Throwable err,
                            String extra) {
        LlmCallLog item = new LlmCallLog();
        item.setTraceId(cut(traceId, 64));
        item.setScene(cut(scene, 64));
        item.setStream((short) (stream ? 1 : 0));
        item.setAttemptNo(attemptNo);
        if (route != null) {
            item.setRouteId(route.id);
            item.setRouteName(cut(route.name, 128));
            item.setBaseUrl(cut(route.baseUrl, 255));
            item.setModel(cut(route.model, 128));
        }
        item.setSuccess((short) (success ? 1 : 0));
        item.setHttpStatus(httpStatus);
        item.setLatencyMs(latencyMs < 0 ? 0 : latencyMs);
        item.setSwitchMode(cut(switchMode, 32));
        item.setRequestContent(cut(JSON.toJSONString(req), LOG_TEXT_LIMIT));
        item.setResponseContent(cut(response, LOG_TEXT_LIMIT));
        item.setErrorType(cut(safeName(err), 128));
        item.setErrorMessage(err == null ? null : cut(errorText(err), 1024));
        item.setExtra(cut(extra, 512));
        item.setCreateTime(new Date());
        llmCallLogService.saveIgnoreError(item);
    }
    private static class CompletionCallResult {
        private final int statusCode;
        private final String payload;
        private final ChatCompletionResponse response;
        private CompletionCallResult(int statusCode, String payload, ChatCompletionResponse response) {
            this.statusCode = statusCode;
            this.payload = payload;
            this.response = response;
        }
    }
    private static class RawCompletionResult {
        private final int statusCode;
        private final String payload;