package com.zy.ai.service;

import com.alibaba.fastjson.JSON;
import com.zy.ai.entity.ChatCompletionRequest;
import com.zy.ai.entity.ChatCompletionResponse;
import com.zy.ai.entity.LlmCallLog;
import com.zy.ai.entity.LlmRouteConfig;
import com.zy.ai.gateway.AiGatewayService;
import com.zy.ai.gateway.adapter.openai.OpenAiChatCompletionsMapper;
import com.zy.ai.gateway.model.AiRequest;
import com.zy.ai.gateway.model.AiResponse;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestClientResponseException;
import org.springframework.web.reactive.function.client.WebClientResponseException;
import reactor.core.publisher.Flux;

import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.UUID;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

/**
 * Chat entry point for LLM calls (blocking and streaming) with multi-route failover.
 *
 * <p>Routes are resolved from the database via {@link LlmRoutingService}, falling back to
 * yml-configured values ({@code llm.base-url} / {@code llm.api-key} / {@code llm.model}) when
 * no DB route is available. Every attempt is recorded through {@link LlmCallLogService}, and
 * successful calls accumulate token usage into {@link AiTokenUsageService}.</p>
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class LlmChatService {

    /** Maximum characters persisted for request/response text in call logs. */
    private static final int LOG_TEXT_LIMIT = 16000;

    private final LlmRoutingService llmRoutingService;
    private final LlmCallLogService llmCallLogService;
    private final LlmSpringAiClientService llmSpringAiClientService;
    private final AiGatewayService aiGatewayService;
    private final OpenAiChatCompletionsMapper openAiChatCompletionsMapper;
    private final AiTokenUsageService aiTokenUsageService;

    // yml fallback route, used only when the DB has no available routes.
    @Value("${llm.base-url:}")
    private String fallbackBaseUrl;
    @Value("${llm.api-key:}")
    private String fallbackApiKey;
    @Value("${llm.model:}")
    private String fallbackModel;
    @Value("${llm.thinking:false}")
    private String fallbackThinking;

    /**
     * Generic chat call: sends the given messages and returns the model's text reply.
     *
     * @param temperature sampling temperature; defaults to 0.3 when null
     * @param maxTokens   max completion tokens; defaults to 1024 when null
     * @return the first choice's message content, or {@code null} on failure / empty reply
     */
    public String chat(List messages, Double temperature, Integer maxTokens) {
        ChatCompletionRequest req = new ChatCompletionRequest();
        req.setMessages(messages);
        req.setTemperature(temperature != null ? temperature : 0.3);
        req.setMax_tokens(maxTokens != null ? maxTokens : 1024);
        req.setStream(false);
        ChatCompletionResponse response = complete(req, "chat");
        if (response == null
                || response.getChoices() == null
                || response.getChoices().isEmpty()
                || response.getChoices().get(0).getMessage() == null
                || response.getChoices().get(0).getMessage().getContent() == null
                || response.getChoices().get(0).getMessage().getContent().isEmpty()) {
            return null;
        }
        return response.getChoices().get(0).getMessage().getContent();
    }

    /**
     * Non-streaming chat completion, optionally with tool definitions.
     * Errors are swallowed and logged; returns {@code null} on failure.
     */
    public ChatCompletionResponse chatCompletion(List messages, Double temperature,
                                                 Integer maxTokens, List tools) {
        ChatCompletionRequest req = new ChatCompletionRequest();
        req.setMessages(messages);
        req.setTemperature(temperature != null ? temperature : 0.3);
        req.setMax_tokens(maxTokens != null ? maxTokens : 1024);
        req.setStream(false);
        if (tools != null && !tools.isEmpty()) {
            req.setTools(tools);
            req.setTool_choice("auto");
        }
        return complete(req, tools != null && !tools.isEmpty() ? "chat_completion_tools" : "chat_completion");
    }

    /**
     * Same as {@link #chatCompletion} but propagates the failure to the caller
     * instead of returning {@code null}.
     */
    public ChatCompletionResponse chatCompletionOrThrow(List messages, Double temperature,
                                                        Integer maxTokens, List tools) {
        ChatCompletionRequest req = new ChatCompletionRequest();
        req.setMessages(messages);
        req.setTemperature(temperature != null ? temperature : 0.3);
        req.setMax_tokens(maxTokens != null ? maxTokens : 1024);
        req.setStream(false);
        if (tools != null && !tools.isEmpty()) {
            req.setTools(tools);
            req.setTool_choice("auto");
        }
        return completeOrThrow(req, tools != null && !tools.isEmpty() ? "chat_completion_tools" : "chat_completion");
    }

    /** Non-streaming completion with the default scene tag {@code "completion"}. */
    public ChatCompletionResponse complete(ChatCompletionRequest req) {
        return complete(req, "completion");
    }

    /**
     * Non-streaming completion; any throwable is logged and converted to a {@code null} result.
     */
    public ChatCompletionResponse complete(ChatCompletionRequest req, String scene) {
        try {
            return completeOrThrow(req, scene);
        } catch (Throwable ex) {
            log.error("调用 LLM 失败, scene={}", scene, ex);
            return null;
        }
    }

    /**
     * Non-streaming completion routed through the AI gateway; throws on failure.
     * Routing/failover for the blocking path is delegated to {@link AiGatewayService}.
     */
    public ChatCompletionResponse completeOrThrow(ChatCompletionRequest req, String scene) {
        AiRequest aiRequest = openAiChatCompletionsMapper.toAiRequest(req);
        aiRequest.setScene(scene);
        AiResponse response = aiGatewayService.generate(aiRequest);
        return openAiChatCompletionsMapper.toChatResponse(response);
    }

    /**
     * Streaming chat: each text chunk is delivered to {@code onChunk}; exactly one of
     * {@code onComplete} / {@code onError} fires at the end.
     */
    public void chatStream(List messages, Double temperature, Integer maxTokens,
                           Consumer onChunk, Runnable onComplete, Consumer onError) {
        ChatCompletionRequest req = new ChatCompletionRequest();
        req.setMessages(messages);
        req.setTemperature(temperature != null ? temperature : 0.3);
        req.setMax_tokens(maxTokens != null ? maxTokens : 1024);
        req.setStream(true);
        streamWithFailover(req, onChunk, onComplete, onError, null, "chat_stream");
    }

    /**
     * Streaming chat with tool definitions; {@code onUsage} receives token usage when the
     * upstream reports it on successful completion.
     */
    public void chatStreamWithTools(List messages, Double temperature, Integer maxTokens, List tools,
                                    Consumer onChunk, Runnable onComplete, Consumer onError, Consumer onUsage) {
        ChatCompletionRequest req = new ChatCompletionRequest();
        req.setMessages(messages);
        req.setTemperature(temperature != null ? temperature : 0.3);
        req.setMax_tokens(maxTokens != null ? maxTokens : 1024);
        req.setStream(true);
        if (tools != null && !tools.isEmpty()) {
            req.setTools(tools);
            req.setTool_choice("auto");
        }
        streamWithFailover(req, onChunk, onComplete, onError, onUsage,
                tools != null && !tools.isEmpty() ? "chat_stream_tools" : "chat_stream");
    }

    /**
     * Streaming entry point: resolves routes and starts the first attempt.
     * When no route is configured, logs a synthetic failure and reports it via {@code onError}.
     */
    private void streamWithFailover(ChatCompletionRequest req, Consumer<String> onChunk, Runnable onComplete,
                                    Consumer<Throwable> onError, Consumer<ChatCompletionResponse.Usage> onUsage,
                                    String scene) {
        String traceId = nextTraceId();
        List<ResolvedRoute> routes = resolveRoutes();
        if (routes.isEmpty()) {
            recordCall(traceId, scene, true, 1, null, false, null, 0L, req, null, null, "none",
                    new RuntimeException("未配置可用 LLM 路由"), "no_route");
            if (onError != null) onError.accept(new RuntimeException("未配置可用 LLM 路由"));
            return;
        }
        attemptStream(routes, 0, req, onChunk, onComplete, onError, onUsage, traceId, scene);
    }

    /**
     * Runs one streaming attempt against {@code routes.get(index)}, recursing to the next route
     * on failure when failover is allowed and nothing has been emitted to the caller yet.
     *
     * <p>Chunks are handed off to a daemon "drain" thread through a queue so that the caller's
     * {@code onChunk} runs off the reactive pipeline thread. Failover is only attempted while
     * {@code emitted} is false, so the caller never sees output from two different routes.</p>
     */
    private void attemptStream(List<ResolvedRoute> routes, int index, ChatCompletionRequest req,
                               Consumer<String> onChunk, Runnable onComplete, Consumer<Throwable> onError,
                               Consumer<ChatCompletionResponse.Usage> onUsage, String traceId, String scene) {
        if (index >= routes.size()) {
            if (onError != null) onError.accept(new RuntimeException("LLM 路由全部失败"));
            return;
        }
        ResolvedRoute route = routes.get(index);
        ChatCompletionRequest routeReq = applyRoute(cloneRequest(req), route, true);
        long start = System.currentTimeMillis();
        StringBuilder outputBuffer = new StringBuilder();
        AtomicBoolean doneSeen = new AtomicBoolean(false);
        AtomicBoolean errorSeen = new AtomicBoolean(false);
        AtomicBoolean emitted = new AtomicBoolean(false);
        AtomicReference<ChatCompletionResponse.Usage> usageRef = new AtomicReference<>();
        LinkedBlockingQueue<String> queue = new LinkedBlockingQueue<>();
        Thread drain = new Thread(() -> {
            try {
                while (true) {
                    // Short poll timeout so the loop re-checks doneSeen periodically even
                    // when the producer goes quiet.
                    String s = queue.poll(2, TimeUnit.SECONDS);
                    if (s != null) {
                        emitted.set(true);
                        // Guard against a null onChunk, consistent with the other callbacks.
                        try {
                            if (onChunk != null) onChunk.accept(s);
                        } catch (Exception ignore) {
                        }
                    }
                    if (doneSeen.get() && queue.isEmpty()) {
                        // Only signal completion when the stream ended without error.
                        if (!errorSeen.get()) {
                            try {
                                if (onComplete != null) onComplete.run();
                            } catch (Exception ignore) {
                            }
                        }
                        break;
                    }
                }
            } catch (InterruptedException ignore) {
                // Restore interrupt status before the drain thread exits.
                Thread.currentThread().interrupt();
            }
        });
        drain.setDaemon(true);
        drain.start();
        Flux<String> streamSource = streamFluxWithSpringAi(route, routeReq, usageRef::set);
        streamSource.subscribe(payload -> {
            if (payload == null || payload.isEmpty()) return;
            queue.offer(payload);
            appendLimited(outputBuffer, payload);
        }, err -> {
            errorSeen.set(true);
            doneSeen.set(true);
            boolean quota = isQuotaExhausted(err);
            boolean canSwitch = shouldSwitch(route, quota);
            markFailure(route, err, canSwitch);
            recordCall(traceId, scene, true, index + 1, route, false, statusCodeOf(err),
                    System.currentTimeMillis() - start, routeReq, usageResponse(usageRef.get()),
                    outputBuffer.toString(), quota ? "quota" : "error", err, "emitted=" + emitted.get());
            // Failover only if nothing reached the caller yet — otherwise output would mix routes.
            if (!emitted.get() && canSwitch && index < routes.size() - 1) {
                log.warn("LLM 路由失败,自动切换,current={}, reason={}", route.tag(), errorText(err));
                attemptStream(routes, index + 1, req, onChunk, onComplete, onError, onUsage, traceId, scene);
                return;
            }
            if (onError != null) onError.accept(err);
        }, () -> {
            markSuccess(route);
            if (onUsage != null && usageRef.get() != null) {
                try {
                    onUsage.accept(usageRef.get());
                } catch (Exception ignore) {
                }
            }
            recordCall(traceId, scene, true, index + 1, route, true, 200,
                    System.currentTimeMillis() - start, routeReq, usageResponse(usageRef.get()),
                    outputBuffer.toString(), "none", null, null);
            // Set last so the drain thread flushes all queued chunks before firing onComplete.
            doneSeen.set(true);
        });
    }

    /** Opens the Spring-AI-backed streaming call for one route; errors are logged then propagated. */
    private Flux<String> streamFluxWithSpringAi(ResolvedRoute route, ChatCompletionRequest req,
                                                Consumer<ChatCompletionResponse.Usage> usageConsumer) {
        return llmSpringAiClientService.streamCompletion(route.baseUrl, route.apiKey, req, usageConsumer)
                .doOnError(ex -> log.error("调用 Spring AI 流式失败, route={}", route.tag(), ex));
    }

    /** Stamps route-specific settings (model, stream flag, thinking mode) onto the request. */
    private ChatCompletionRequest applyRoute(ChatCompletionRequest req, ResolvedRoute route, boolean stream) {
        req.setModel(route.model);
        req.setStream(stream);
        if (route.thinkingEnabled) {
            ChatCompletionRequest.Thinking t = new ChatCompletionRequest.Thinking();
            t.setType("enable");
            req.setThinking(t);
        } else {
            req.setThinking(null);
        }
        return req;
    }

    /**
     * Shallow copy of the request so per-route mutation doesn't leak across failover attempts.
     * NOTE(review): messages/tools lists are shared, not copied — assumed not mutated per route.
     */
    private ChatCompletionRequest cloneRequest(ChatCompletionRequest src) {
        ChatCompletionRequest req = new ChatCompletionRequest();
        req.setModel(src.getModel());
        req.setMessages(src.getMessages());
        req.setTemperature(src.getTemperature());
        req.setMax_tokens(src.getMax_tokens());
        req.setStream(src.getStream());
        req.setTools(src.getTools());
        req.setTool_choice(src.getTool_choice());
        req.setThinking(src.getThinking());
        return req;
    }

    /** Whether this route allows failover for the given failure class (quota vs generic error). */
    private boolean shouldSwitch(ResolvedRoute route, boolean quota) {
        return quota ? route.switchOnQuota : route.switchOnError;
    }

    /** Reports success to the routing service; yml fallback routes (id == null) are not tracked. */
    private void markSuccess(ResolvedRoute route) {
        if (route.id != null) {
            llmRoutingService.markSuccess(route.id);
        }
    }

    /** Reports failure (and optional cooldown entry) to the routing service for DB routes. */
    private void markFailure(ResolvedRoute route, Throwable ex, boolean enterCooldown) {
        if (route.id != null) {
            llmRoutingService.markFailure(route.id, errorText(ex), enterCooldown, route.cooldownSeconds);
        }
    }

    /**
     * Short human-readable error description: HTTP status plus up to 240 chars of response body
     * for known client exception types, otherwise the exception message.
     */
    private String errorText(Throwable ex) {
        if (ex == null) return "unknown";
        if (ex instanceof RestClientResponseException) {
            RestClientResponseException e = (RestClientResponseException) ex;
            String body = e.getResponseBodyAsString();
            if (body != null && body.length() > 240) {
                body = body.substring(0, 240);
            }
            return "status=" + e.getStatusCode().value() + ", body=" + body;
        }
        if (ex instanceof WebClientResponseException) {
            WebClientResponseException e = (WebClientResponseException) ex;
            String body = e.getResponseBodyAsString();
            if (body != null && body.length() > 240) {
                body = body.substring(0, 240);
            }
            return "status=" + e.getStatusCode().value() + ", body=" + body;
        }
        Integer springAiStatus = llmSpringAiClientService.statusCodeOf(ex);
        if (springAiStatus != null) {
            return "status=" + springAiStatus + ", body=" + llmSpringAiClientService.responseBodyOf(ex, 240);
        }
        return ex.getMessage() == null ? ex.toString() : ex.getMessage();
    }

    /**
     * Heuristic: HTTP 429, or a response body mentioning quota/balance/rate-limit keywords,
     * is treated as quota exhaustion (drives the quota-specific failover flag).
     */
    private boolean isQuotaExhausted(Throwable ex) {
        Integer status = statusCodeOf(ex);
        if (status != null && status == 429) {
            return true;
        }
        String text = responseBodyOf(ex);
        text = text == null ? "" : text.toLowerCase(Locale.ROOT);
        return text.contains("insufficient_quota")
                || text.contains("quota")
                || text.contains("余额")
                || text.contains("用量")
                || text.contains("超限")
                || text.contains("rate limit");
    }

    /** DB routes first; falls back to the single yml-configured route when the DB has none. */
    private List<ResolvedRoute> resolveRoutes() {
        List<ResolvedRoute> routes = new ArrayList<>();
        List<LlmRouteConfig> dbRoutes = llmRoutingService.listAvailableRoutes();
        for (LlmRouteConfig c : dbRoutes) {
            routes.add(ResolvedRoute.fromDb(c));
        }
        // Compatibility: when the database has no routes, fall back to yml configuration.
        if (routes.isEmpty() && !isBlank(fallbackBaseUrl) && !isBlank(fallbackApiKey) && !isBlank(fallbackModel)) {
            routes.add(ResolvedRoute.fromFallback(fallbackBaseUrl, fallbackApiKey, fallbackModel,
                    isFallbackThinkingEnabled()));
        }
        return routes;
    }

    /** Accepts "true"/"1"/"enable" (case-insensitive, locale-independent) as enabled. */
    private boolean isFallbackThinkingEnabled() {
        // Locale.ROOT avoids locale-dependent casing (e.g. Turkish dotless-i), matching
        // the convention used in isQuotaExhausted().
        String x = fallbackThinking == null ? "" : fallbackThinking.trim().toLowerCase(Locale.ROOT);
        return "true".equals(x) || "1".equals(x) || "enable".equals(x);
    }

    private boolean isBlank(String s) {
        return s == null || s.trim().isEmpty();
    }

    /** 32-hex-char trace id (UUID without dashes). */
    private String nextTraceId() {
        return UUID.randomUUID().toString().replace("-", "");
    }

    /** Appends text to the buffer, truncating so the buffer never exceeds LOG_TEXT_LIMIT. */
    private void appendLimited(StringBuilder sb, String text) {
        if (sb == null || text == null || text.isEmpty()) {
            return;
        }
        int remain = LOG_TEXT_LIMIT - sb.length();
        if (remain <= 0) {
            return;
        }
        if (text.length() <= remain) {
            sb.append(text);
        } else {
            sb.append(text, 0, remain);
        }
    }

    /** HTTP status from known client exception types, else delegates to the Spring AI service. */
    private Integer statusCodeOf(Throwable ex) {
        if (ex instanceof RestClientResponseException) {
            return ((RestClientResponseException) ex).getStatusCode().value();
        }
        if (ex instanceof WebClientResponseException) {
            return ((WebClientResponseException) ex).getStatusCode().value();
        }
        return llmSpringAiClientService.statusCodeOf(ex);
    }

    /** Truncated response body from known client exception types, else via the Spring AI service. */
    private String responseBodyOf(Throwable ex) {
        if (ex instanceof RestClientResponseException) {
            return cut(((RestClientResponseException) ex).getResponseBodyAsString(), LOG_TEXT_LIMIT);
        }
        if (ex instanceof WebClientResponseException) {
            return cut(((WebClientResponseException) ex).getResponseBodyAsString(), LOG_TEXT_LIMIT);
        }
        return cut(llmSpringAiClientService.responseBodyOf(ex, LOG_TEXT_LIMIT), LOG_TEXT_LIMIT);
    }

    private String safeName(Throwable ex) {
        return ex == null ? null : ex.getClass().getSimpleName();
    }

    /** Strips carriage returns and truncates to {@code maxLen}; null-safe. */
    private String cut(String text, int maxLen) {
        if (text == null) return null;
        String clean = text.replace("\r", " ");
        return clean.length() > maxLen ? clean.substring(0, maxLen) : clean;
    }

    /**
     * Persists one call-log row (best-effort) and, on success, accumulates token usage.
     * Field lengths are truncated to match the log table's column sizes.
     */
    private void recordCall(String traceId, String scene, boolean stream, int attemptNo, ResolvedRoute route,
                            boolean success, Integer httpStatus, long latencyMs, ChatCompletionRequest req,
                            ChatCompletionResponse responseObj, String response, String switchMode,
                            Throwable err, String extra) {
        LlmCallLog item = new LlmCallLog();
        item.setTraceId(cut(traceId, 64));
        item.setScene(cut(scene, 64));
        item.setStream((short) (stream ? 1 : 0));
        item.setAttemptNo(attemptNo);
        if (route != null) {
            item.setRouteId(route.id);
            item.setRouteName(cut(route.name, 128));
            item.setBaseUrl(cut(route.baseUrl, 255));
            item.setModel(cut(route.model, 128));
        }
        item.setSuccess((short) (success ? 1 : 0));
        item.setHttpStatus(httpStatus);
        item.setLatencyMs(latencyMs < 0 ? 0 : latencyMs);
        item.setSwitchMode(cut(switchMode, 32));
        item.setRequestContent(cut(JSON.toJSONString(req), LOG_TEXT_LIMIT));
        item.setResponseContent(cut(response, LOG_TEXT_LIMIT));
        item.setErrorType(cut(safeName(err), 128));
        item.setErrorMessage(err == null ? null : cut(errorText(err), 1024));
        item.setExtra(cut(buildExtraPayload(responseObj == null ? null : responseObj.getUsage(), extra), 512));
        item.setCreateTime(new Date());
        llmCallLogService.saveIgnoreError(item);
        // Accumulate tokens into the dedicated usage store.
        if (success && responseObj != null && responseObj.getUsage() != null) {
            ChatCompletionResponse.Usage usage = responseObj.getUsage();
            aiTokenUsageService.incrementTokens(
                    usage.getPromptTokens() == null ? 0 : usage.getPromptTokens(),
                    usage.getCompletionTokens() == null ? 0 : usage.getCompletionTokens(),
                    usage.getTotalTokens() == null ? 0 : usage.getTotalTokens(),
                    1);
        }
    }

    /** Wraps a Usage object in a minimal response so recordCall can extract it uniformly. */
    private ChatCompletionResponse usageResponse(ChatCompletionResponse.Usage usage) {
        if (usage == null) {
            return null;
        }
        ChatCompletionResponse response = new ChatCompletionResponse();
        response.setUsage(usage);
        return response;
    }

    /** JSON blob of token counts plus an optional free-text note; null when there is nothing to say. */
    private String buildExtraPayload(ChatCompletionResponse.Usage usage, String extra) {
        if (usage == null && isBlank(extra)) {
            return null;
        }
        HashMap<String, Object> payload = new HashMap<>();
        if (usage != null) {
            if (usage.getPromptTokens() != null) {
                payload.put("promptTokens", usage.getPromptTokens());
            }
            if (usage.getCompletionTokens() != null) {
                payload.put("completionTokens", usage.getCompletionTokens());
            }
            if (usage.getTotalTokens() != null) {
                payload.put("totalTokens", usage.getTotalTokens());
            }
        }
        if (!isBlank(extra)) {
            payload.put("note", extra);
        }
        return payload.isEmpty() ? null : JSON.toJSONString(payload);
    }

    /**
     * Flattened view of a route: either a DB-configured {@link LlmRouteConfig} (with id and
     * per-route failover/cooldown settings) or the yml fallback (id == null, permissive defaults).
     */
    private static class ResolvedRoute {
        private Long id;
        private String name;
        private String baseUrl;
        private String apiKey;
        private String model;
        private boolean thinkingEnabled;
        private boolean switchOnQuota;
        private boolean switchOnError;
        private Integer cooldownSeconds;

        private static ResolvedRoute fromDb(LlmRouteConfig c) {
            ResolvedRoute r = new ResolvedRoute();
            r.id = c.getId();
            r.name = c.getName();
            r.baseUrl = c.getBaseUrl();
            r.apiKey = c.getApiKey();
            r.model = c.getModel();
            r.thinkingEnabled = c.getThinking() != null && c.getThinking() == 1;
            // Unset switch flags default to "allowed" (null == 1 semantics).
            r.switchOnQuota = c.getSwitchOnQuota() == null || c.getSwitchOnQuota() == 1;
            r.switchOnError = c.getSwitchOnError() == null || c.getSwitchOnError() == 1;
            r.cooldownSeconds = c.getCooldownSeconds();
            return r;
        }

        private static ResolvedRoute fromFallback(String baseUrl, String apiKey, String model,
                                                  boolean thinkingEnabled) {
            ResolvedRoute r = new ResolvedRoute();
            r.name = "fallback-yml";
            r.baseUrl = baseUrl;
            r.apiKey = apiKey;
            r.model = model;
            r.thinkingEnabled = thinkingEnabled;
            r.switchOnQuota = true;
            r.switchOnError = true;
            r.cooldownSeconds = 300;
            return r;
        }

        /** Short label for log lines: "name model=..." with null-safe fallbacks. */
        private String tag() {
            String showName = name == null ? "unnamed" : name;
            String showModel = model == null ? "" : (" model=" + model);
            return showName + showModel;
        }
    }
}