From 1b8a4677f362d234d834120deac4880d7ae89a50 Mon Sep 17 00:00:00 2001
From: Junjie <fallin.jie@qq.com>
Date: 星期四, 12 三月 2026 17:03:59 +0800
Subject: [PATCH] #

---
 src/main/java/com/zy/ai/service/LlmChatService.java |   77 ++++++++++++++++++++++++++++++--------
 1 files changed, 61 insertions(+), 16 deletions(-)

diff --git a/src/main/java/com/zy/ai/service/LlmChatService.java b/src/main/java/com/zy/ai/service/LlmChatService.java
index e2eddd6..a3835f5 100644
--- a/src/main/java/com/zy/ai/service/LlmChatService.java
+++ b/src/main/java/com/zy/ai/service/LlmChatService.java
@@ -22,6 +22,7 @@
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
 
 @Slf4j
@@ -99,7 +100,7 @@
         List<ResolvedRoute> routes = resolveRoutes();
         if (routes.isEmpty()) {
             log.error("璋冪敤 LLM 澶辫触: 鏈厤缃彲鐢� LLM 璺敱");
-            recordCall(traceId, scene, false, 1, null, false, null, 0L, req, null, "none",
+            recordCall(traceId, scene, false, 1, null, false, null, 0L, req, null, null, "none",
                     new RuntimeException("鏈厤缃彲鐢� LLM 璺敱"), "no_route");
             return null;
         }
@@ -118,7 +119,7 @@
                     boolean canSwitch = shouldSwitch(route, false);
                     markFailure(route, ex, canSwitch);
                     recordCall(traceId, scene, false, i + 1, route, false, callResult.statusCode,
-                            System.currentTimeMillis() - start, routeReq, callResult.payload, "error", ex,
+                            System.currentTimeMillis() - start, routeReq, resp, callResult.payload, "error", ex,
                             "invalid_completion");
                     if (hasNext && canSwitch) {
                         log.warn("LLM 鍒囨崲鍒颁笅涓�璺敱, current={}, reason={}", route.tag(), ex.getMessage());
@@ -130,7 +131,7 @@
                 }
                 markSuccess(route);
                 recordCall(traceId, scene, false, i + 1, route, true, callResult.statusCode,
-                        System.currentTimeMillis() - start, routeReq, buildResponseText(resp, callResult.payload),
+                        System.currentTimeMillis() - start, routeReq, resp, buildResponseText(resp, callResult.payload),
                         "none", null, null);
                 return resp;
             } catch (Throwable ex) {
@@ -139,7 +140,7 @@
                 boolean canSwitch = shouldSwitch(route, quota);
                 markFailure(route, ex, canSwitch);
                 recordCall(traceId, scene, false, i + 1, route, false, statusCodeOf(ex),
-                        System.currentTimeMillis() - start, routeReq, responseBodyOf(ex),
+                        System.currentTimeMillis() - start, routeReq, null, responseBodyOf(ex),
                         quota ? "quota" : "error", ex, null);
                 if (hasNext && canSwitch) {
                     log.warn("LLM 鍒囨崲鍒颁笅涓�璺敱, current={}, reason={}", route.tag(), errorText(ex));
@@ -169,7 +170,7 @@
         req.setMax_tokens(maxTokens != null ? maxTokens : 1024);
         req.setStream(true);
 
-        streamWithFailover(req, onChunk, onComplete, onError, "chat_stream");
+        streamWithFailover(req, onChunk, onComplete, onError, null, "chat_stream");
     }
 
     public void chatStreamWithTools(List<ChatCompletionRequest.Message> messages,
@@ -178,7 +179,8 @@
                                     List<Object> tools,
                                     Consumer<String> onChunk,
                                     Runnable onComplete,
-                                    Consumer<Throwable> onError) {
+                                    Consumer<Throwable> onError,
+                                    Consumer<ChatCompletionResponse.Usage> onUsage) {
         ChatCompletionRequest req = new ChatCompletionRequest();
         req.setMessages(messages);
         req.setTemperature(temperature != null ? temperature : 0.3);
@@ -188,23 +190,24 @@
             req.setTools(tools);
             req.setTool_choice("auto");
         }
-        streamWithFailover(req, onChunk, onComplete, onError, tools != null && !tools.isEmpty() ? "chat_stream_tools" : "chat_stream");
+        streamWithFailover(req, onChunk, onComplete, onError, onUsage, tools != null && !tools.isEmpty() ? "chat_stream_tools" : "chat_stream");
     }
 
     private void streamWithFailover(ChatCompletionRequest req,
                                     Consumer<String> onChunk,
                                     Runnable onComplete,
                                     Consumer<Throwable> onError,
+                                    Consumer<ChatCompletionResponse.Usage> onUsage,
                                     String scene) {
         String traceId = nextTraceId();
         List<ResolvedRoute> routes = resolveRoutes();
         if (routes.isEmpty()) {
-            recordCall(traceId, scene, true, 1, null, false, null, 0L, req, null, "none",
+            recordCall(traceId, scene, true, 1, null, false, null, 0L, req, null, null, "none",
                     new RuntimeException("鏈厤缃彲鐢� LLM 璺敱"), "no_route");
             if (onError != null) onError.accept(new RuntimeException("鏈厤缃彲鐢� LLM 璺敱"));
             return;
         }
-        attemptStream(routes, 0, req, onChunk, onComplete, onError, traceId, scene);
+        attemptStream(routes, 0, req, onChunk, onComplete, onError, onUsage, traceId, scene);
     }
 
     private void attemptStream(List<ResolvedRoute> routes,
@@ -213,6 +216,7 @@
                                Consumer<String> onChunk,
                                Runnable onComplete,
                                Consumer<Throwable> onError,
+                               Consumer<ChatCompletionResponse.Usage> onUsage,
                                String traceId,
                                String scene) {
         if (index >= routes.size()) {
@@ -228,6 +232,7 @@
         AtomicBoolean doneSeen = new AtomicBoolean(false);
         AtomicBoolean errorSeen = new AtomicBoolean(false);
         AtomicBoolean emitted = new AtomicBoolean(false);
+        AtomicReference<ChatCompletionResponse.Usage> usageRef = new AtomicReference<>();
         LinkedBlockingQueue<String> queue = new LinkedBlockingQueue<>();
 
         Thread drain = new Thread(() -> {
@@ -257,7 +262,7 @@
         drain.setDaemon(true);
         drain.start();
 
-        Flux<String> streamSource = streamFluxWithSpringAi(route, routeReq);
+        Flux<String> streamSource = streamFluxWithSpringAi(route, routeReq, usageRef::set);
         streamSource.subscribe(payload -> {
             if (payload == null || payload.isEmpty()) return;
             queue.offer(payload);
@@ -269,25 +274,33 @@
             boolean canSwitch = shouldSwitch(route, quota);
             markFailure(route, err, canSwitch);
             recordCall(traceId, scene, true, index + 1, route, false, statusCodeOf(err),
-                    System.currentTimeMillis() - start, routeReq, outputBuffer.toString(),
+                    System.currentTimeMillis() - start, routeReq, usageResponse(usageRef.get()), outputBuffer.toString(),
                     quota ? "quota" : "error", err, "emitted=" + emitted.get());
             if (!emitted.get() && canSwitch && index < routes.size() - 1) {
                 log.warn("LLM 璺敱澶辫触锛岃嚜鍔ㄥ垏鎹紝current={}, reason={}", route.tag(), errorText(err));
-                attemptStream(routes, index + 1, req, onChunk, onComplete, onError, traceId, scene);
+                attemptStream(routes, index + 1, req, onChunk, onComplete, onError, onUsage, traceId, scene);
                 return;
             }
             if (onError != null) onError.accept(err);
         }, () -> {
             markSuccess(route);
+            if (onUsage != null && usageRef.get() != null) {
+                try {
+                    onUsage.accept(usageRef.get());
+                } catch (Exception ignore) {
+                }
+            }
             recordCall(traceId, scene, true, index + 1, route, true, 200,
-                    System.currentTimeMillis() - start, routeReq, outputBuffer.toString(),
+                    System.currentTimeMillis() - start, routeReq, usageResponse(usageRef.get()), outputBuffer.toString(),
                     "none", null, null);
             doneSeen.set(true);
         });
     }
 
-    private Flux<String> streamFluxWithSpringAi(ResolvedRoute route, ChatCompletionRequest req) {
-        return llmSpringAiClientService.streamCompletion(route.baseUrl, route.apiKey, req)
+    private Flux<String> streamFluxWithSpringAi(ResolvedRoute route,
+                                                ChatCompletionRequest req,
+                                                Consumer<ChatCompletionResponse.Usage> usageConsumer) {
+        return llmSpringAiClientService.streamCompletion(route.baseUrl, route.apiKey, req, usageConsumer)
                 .doOnError(ex -> log.error("璋冪敤 Spring AI 娴佸紡澶辫触, route={}", route.tag(), ex));
     }
 
@@ -491,6 +504,7 @@
                             Integer httpStatus,
                             long latencyMs,
                             ChatCompletionRequest req,
+                            ChatCompletionResponse responseObj,
                             String response,
                             String switchMode,
                             Throwable err,
@@ -514,11 +528,42 @@
         item.setResponseContent(cut(response, LOG_TEXT_LIMIT));
         item.setErrorType(cut(safeName(err), 128));
         item.setErrorMessage(err == null ? null : cut(errorText(err), 1024));
-        item.setExtra(cut(extra, 512));
+        item.setExtra(cut(buildExtraPayload(responseObj == null ? null : responseObj.getUsage(), extra), 512));
         item.setCreateTime(new Date());
         llmCallLogService.saveIgnoreError(item);
     }
 
+    private ChatCompletionResponse usageResponse(ChatCompletionResponse.Usage usage) {
+        if (usage == null) {
+            return null;
+        }
+        ChatCompletionResponse response = new ChatCompletionResponse();
+        response.setUsage(usage);
+        return response;
+    }
+
+    private String buildExtraPayload(ChatCompletionResponse.Usage usage, String extra) {
+        if (usage == null && isBlank(extra)) {
+            return null;
+        }
+        HashMap<String, Object> payload = new HashMap<>();
+        if (usage != null) {
+            if (usage.getPromptTokens() != null) {
+                payload.put("promptTokens", usage.getPromptTokens());
+            }
+            if (usage.getCompletionTokens() != null) {
+                payload.put("completionTokens", usage.getCompletionTokens());
+            }
+            if (usage.getTotalTokens() != null) {
+                payload.put("totalTokens", usage.getTotalTokens());
+            }
+        }
+        if (!isBlank(extra)) {
+            payload.put("note", extra);
+        }
+        return payload.isEmpty() ? null : JSON.toJSONString(payload);
+    }
+
     private static class CompletionCallResult {
         private final int statusCode;
         private final String payload;

--
Gitblit v1.9.1