From caf3bdd9bbb629c8bc6f1a19b3ccdf441bf7650c Mon Sep 17 00:00:00 2001
From: Junjie <fallin.jie@qq.com>
Date: Sun, 15 Mar 2026 17:46:47 +0800
Subject: [PATCH] Record LLM token usage in call logs and expose it to streaming callers via onUsage
---
src/main/java/com/zy/ai/service/LlmChatService.java | 77 ++++++++++++++++++++++++++++++--------
1 files changed, 61 insertions(+), 16 deletions(-)
diff --git a/src/main/java/com/zy/ai/service/LlmChatService.java b/src/main/java/com/zy/ai/service/LlmChatService.java
index e2eddd6..a3835f5 100644
--- a/src/main/java/com/zy/ai/service/LlmChatService.java
+++ b/src/main/java/com/zy/ai/service/LlmChatService.java
@@ -22,6 +22,7 @@
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
@Slf4j
@@ -99,7 +100,7 @@
List<ResolvedRoute> routes = resolveRoutes();
if (routes.isEmpty()) {
log.error("璋冪敤 LLM 澶辫触: 鏈厤缃彲鐢� LLM 璺敱");
- recordCall(traceId, scene, false, 1, null, false, null, 0L, req, null, "none",
+ recordCall(traceId, scene, false, 1, null, false, null, 0L, req, null, null, "none",
new RuntimeException("鏈厤缃彲鐢� LLM 璺敱"), "no_route");
return null;
}
@@ -118,7 +119,7 @@
boolean canSwitch = shouldSwitch(route, false);
markFailure(route, ex, canSwitch);
recordCall(traceId, scene, false, i + 1, route, false, callResult.statusCode,
- System.currentTimeMillis() - start, routeReq, callResult.payload, "error", ex,
+ System.currentTimeMillis() - start, routeReq, resp, callResult.payload, "error", ex,
"invalid_completion");
if (hasNext && canSwitch) {
log.warn("LLM 鍒囨崲鍒颁笅涓�璺敱, current={}, reason={}", route.tag(), ex.getMessage());
@@ -130,7 +131,7 @@
}
markSuccess(route);
recordCall(traceId, scene, false, i + 1, route, true, callResult.statusCode,
- System.currentTimeMillis() - start, routeReq, buildResponseText(resp, callResult.payload),
+ System.currentTimeMillis() - start, routeReq, resp, buildResponseText(resp, callResult.payload),
"none", null, null);
return resp;
} catch (Throwable ex) {
@@ -139,7 +140,7 @@
boolean canSwitch = shouldSwitch(route, quota);
markFailure(route, ex, canSwitch);
recordCall(traceId, scene, false, i + 1, route, false, statusCodeOf(ex),
- System.currentTimeMillis() - start, routeReq, responseBodyOf(ex),
+ System.currentTimeMillis() - start, routeReq, null, responseBodyOf(ex),
quota ? "quota" : "error", ex, null);
if (hasNext && canSwitch) {
log.warn("LLM 鍒囨崲鍒颁笅涓�璺敱, current={}, reason={}", route.tag(), errorText(ex));
@@ -169,7 +170,7 @@
req.setMax_tokens(maxTokens != null ? maxTokens : 1024);
req.setStream(true);
- streamWithFailover(req, onChunk, onComplete, onError, "chat_stream");
+ streamWithFailover(req, onChunk, onComplete, onError, null, "chat_stream");
}
public void chatStreamWithTools(List<ChatCompletionRequest.Message> messages,
@@ -178,7 +179,8 @@
List<Object> tools,
Consumer<String> onChunk,
Runnable onComplete,
- Consumer<Throwable> onError) {
+ Consumer<Throwable> onError,
+ Consumer<ChatCompletionResponse.Usage> onUsage) {
ChatCompletionRequest req = new ChatCompletionRequest();
req.setMessages(messages);
req.setTemperature(temperature != null ? temperature : 0.3);
@@ -188,23 +190,24 @@
req.setTools(tools);
req.setTool_choice("auto");
}
- streamWithFailover(req, onChunk, onComplete, onError, tools != null && !tools.isEmpty() ? "chat_stream_tools" : "chat_stream");
+ streamWithFailover(req, onChunk, onComplete, onError, onUsage, tools != null && !tools.isEmpty() ? "chat_stream_tools" : "chat_stream");
}
private void streamWithFailover(ChatCompletionRequest req,
Consumer<String> onChunk,
Runnable onComplete,
Consumer<Throwable> onError,
+ Consumer<ChatCompletionResponse.Usage> onUsage,
String scene) {
String traceId = nextTraceId();
List<ResolvedRoute> routes = resolveRoutes();
if (routes.isEmpty()) {
- recordCall(traceId, scene, true, 1, null, false, null, 0L, req, null, "none",
+ recordCall(traceId, scene, true, 1, null, false, null, 0L, req, null, null, "none",
new RuntimeException("鏈厤缃彲鐢� LLM 璺敱"), "no_route");
if (onError != null) onError.accept(new RuntimeException("鏈厤缃彲鐢� LLM 璺敱"));
return;
}
- attemptStream(routes, 0, req, onChunk, onComplete, onError, traceId, scene);
+ attemptStream(routes, 0, req, onChunk, onComplete, onError, onUsage, traceId, scene);
}
private void attemptStream(List<ResolvedRoute> routes,
@@ -213,6 +216,7 @@
Consumer<String> onChunk,
Runnable onComplete,
Consumer<Throwable> onError,
+ Consumer<ChatCompletionResponse.Usage> onUsage,
String traceId,
String scene) {
if (index >= routes.size()) {
@@ -228,6 +232,7 @@
AtomicBoolean doneSeen = new AtomicBoolean(false);
AtomicBoolean errorSeen = new AtomicBoolean(false);
AtomicBoolean emitted = new AtomicBoolean(false);
+ AtomicReference<ChatCompletionResponse.Usage> usageRef = new AtomicReference<>();
LinkedBlockingQueue<String> queue = new LinkedBlockingQueue<>();
Thread drain = new Thread(() -> {
@@ -257,7 +262,7 @@
drain.setDaemon(true);
drain.start();
- Flux<String> streamSource = streamFluxWithSpringAi(route, routeReq);
+ Flux<String> streamSource = streamFluxWithSpringAi(route, routeReq, usageRef::set);
streamSource.subscribe(payload -> {
if (payload == null || payload.isEmpty()) return;
queue.offer(payload);
@@ -269,25 +274,33 @@
boolean canSwitch = shouldSwitch(route, quota);
markFailure(route, err, canSwitch);
recordCall(traceId, scene, true, index + 1, route, false, statusCodeOf(err),
- System.currentTimeMillis() - start, routeReq, outputBuffer.toString(),
+ System.currentTimeMillis() - start, routeReq, usageResponse(usageRef.get()), outputBuffer.toString(),
quota ? "quota" : "error", err, "emitted=" + emitted.get());
if (!emitted.get() && canSwitch && index < routes.size() - 1) {
log.warn("LLM 璺敱澶辫触锛岃嚜鍔ㄥ垏鎹紝current={}, reason={}", route.tag(), errorText(err));
- attemptStream(routes, index + 1, req, onChunk, onComplete, onError, traceId, scene);
+ attemptStream(routes, index + 1, req, onChunk, onComplete, onError, onUsage, traceId, scene);
return;
}
if (onError != null) onError.accept(err);
}, () -> {
markSuccess(route);
+ if (onUsage != null && usageRef.get() != null) {
+ try {
+ onUsage.accept(usageRef.get());
+ } catch (Exception ignore) {
+ }
+ }
recordCall(traceId, scene, true, index + 1, route, true, 200,
- System.currentTimeMillis() - start, routeReq, outputBuffer.toString(),
+ System.currentTimeMillis() - start, routeReq, usageResponse(usageRef.get()), outputBuffer.toString(),
"none", null, null);
doneSeen.set(true);
});
}
- private Flux<String> streamFluxWithSpringAi(ResolvedRoute route, ChatCompletionRequest req) {
- return llmSpringAiClientService.streamCompletion(route.baseUrl, route.apiKey, req)
+ private Flux<String> streamFluxWithSpringAi(ResolvedRoute route,
+ ChatCompletionRequest req,
+ Consumer<ChatCompletionResponse.Usage> usageConsumer) {
+ return llmSpringAiClientService.streamCompletion(route.baseUrl, route.apiKey, req, usageConsumer)
.doOnError(ex -> log.error("璋冪敤 Spring AI 娴佸紡澶辫触, route={}", route.tag(), ex));
}
@@ -491,6 +504,7 @@
Integer httpStatus,
long latencyMs,
ChatCompletionRequest req,
+ ChatCompletionResponse responseObj,
String response,
String switchMode,
Throwable err,
@@ -514,11 +528,42 @@
item.setResponseContent(cut(response, LOG_TEXT_LIMIT));
item.setErrorType(cut(safeName(err), 128));
item.setErrorMessage(err == null ? null : cut(errorText(err), 1024));
- item.setExtra(cut(extra, 512));
+ item.setExtra(cut(buildExtraPayload(responseObj == null ? null : responseObj.getUsage(), extra), 512));
item.setCreateTime(new Date());
llmCallLogService.saveIgnoreError(item);
}
+ private ChatCompletionResponse usageResponse(ChatCompletionResponse.Usage usage) {
+ if (usage == null) {
+ return null;
+ }
+ ChatCompletionResponse response = new ChatCompletionResponse();
+ response.setUsage(usage);
+ return response;
+ }
+
+ private String buildExtraPayload(ChatCompletionResponse.Usage usage, String extra) {
+ if (usage == null && isBlank(extra)) {
+ return null;
+ }
+ HashMap<String, Object> payload = new HashMap<>();
+ if (usage != null) {
+ if (usage.getPromptTokens() != null) {
+ payload.put("promptTokens", usage.getPromptTokens());
+ }
+ if (usage.getCompletionTokens() != null) {
+ payload.put("completionTokens", usage.getCompletionTokens());
+ }
+ if (usage.getTotalTokens() != null) {
+ payload.put("totalTokens", usage.getTotalTokens());
+ }
+ }
+ if (!isBlank(extra)) {
+ payload.put("note", extra);
+ }
+ return payload.isEmpty() ? null : JSON.toJSONString(payload);
+ }
+
private static class CompletionCallResult {
private final int statusCode;
private final String payload;
--
Gitblit v1.9.1