Junjie
8 天以前 4898d942bd6e3c1119493cf0314b15f2bd54daf3
src/main/java/com/zy/ai/service/LlmChatService.java
@@ -12,6 +12,7 @@
import reactor.core.publisher.Mono;
import reactor.core.publisher.Flux;
import java.util.HashMap;
import java.util.List;
import java.util.function.Consumer;
import java.util.concurrent.LinkedBlockingQueue;
@@ -33,6 +34,9 @@
    // Injected from the "llm.model" property; set as the model on every ChatCompletionRequest built here.
    @Value("${llm.model}")
    private String model;
    // Injected from the "llm.pythonPlatformUrl" property; used as the request URI by chatStreamRunPython.
    @Value("${llm.pythonPlatformUrl}")
    private String pythonPlatformUrl;
    /**
     * 通用对话方法:传入 messages,返回大模型文本回复
@@ -74,6 +78,47 @@
        }
        return response.getChoices().get(0).getMessage().getContent();
    }
    public ChatCompletionResponse chatCompletion(List<ChatCompletionRequest.Message> messages,
                                                 Double temperature,
                                                 Integer maxTokens,
                                                 List<Object> tools) {
        ChatCompletionRequest req = new ChatCompletionRequest();
        req.setModel(model);
        req.setMessages(messages);
        req.setTemperature(temperature != null ? temperature : 0.3);
        req.setMax_tokens(maxTokens != null ? maxTokens : 1024);
        req.setStream(false);
        if (tools != null && !tools.isEmpty()) {
            req.setTools(tools);
            req.setTool_choice("auto");
        }
        return complete(req);
    }
    public ChatCompletionResponse complete(ChatCompletionRequest req) {
        try {
            return llmWebClient.post()
                    .uri("/chat/completions")
                    .header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey)
                    .contentType(MediaType.APPLICATION_JSON)
                    .accept(MediaType.APPLICATION_JSON, MediaType.TEXT_EVENT_STREAM)
                    .bodyValue(req)
                    .exchangeToMono(resp -> resp.bodyToFlux(String.class)
                            .collectList()
                            .map(list -> {
                                String payload = String.join("\n\n", list);
                                return parseCompletion(payload);
                            }))
                    .doOnError(ex -> log.error("调用 LLM 失败", ex))
                    .onErrorResume(ex -> Mono.empty())
                    .block();
        } catch (Exception e) {
            log.error("调用 LLM 失败", e);
            return null;
        }
    }
    public void chatStream(List<ChatCompletionRequest.Message> messages,
@@ -172,6 +217,207 @@
        });
    }
    public void chatStreamWithTools(List<ChatCompletionRequest.Message> messages,
                                    Double temperature,
                                    Integer maxTokens,
                                    List<Object> tools,
                                    Consumer<String> onChunk,
                                    Runnable onComplete,
                                    Consumer<Throwable> onError) {
        ChatCompletionRequest req = new ChatCompletionRequest();
        req.setModel(model);
        req.setMessages(messages);
        req.setTemperature(temperature != null ? temperature : 0.3);
        req.setMax_tokens(maxTokens != null ? maxTokens : 1024);
        req.setStream(true);
        if (tools != null && !tools.isEmpty()) {
            req.setTools(tools);
            req.setTool_choice("auto");
        }
        Flux<String> flux = llmWebClient.post()
                .uri("/chat/completions")
                .header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey)
                .contentType(MediaType.APPLICATION_JSON)
                .accept(MediaType.TEXT_EVENT_STREAM)
                .bodyValue(req)
                .retrieve()
                .bodyToFlux(String.class)
                .doOnError(ex -> log.error("调用 LLM 流式失败", ex));
        AtomicBoolean doneSeen = new AtomicBoolean(false);
        AtomicBoolean errorSeen = new AtomicBoolean(false);
        LinkedBlockingQueue<String> queue = new LinkedBlockingQueue<>();
        Thread drain = new Thread(() -> {
            try {
                while (true) {
                    String s = queue.poll(5, TimeUnit.SECONDS);
                    if (s != null) {
                        try { onChunk.accept(s); } catch (Exception ignore) {}
                    }
                    if (doneSeen.get() && queue.isEmpty()) {
                        if (!errorSeen.get()) {
                            try { if (onComplete != null) onComplete.run(); } catch (Exception ignore) {}
                        }
                        break;
                    }
                }
            } catch (InterruptedException ignore) {
                ignore.printStackTrace();
            }
        });
        drain.setDaemon(true);
        drain.start();
        flux.subscribe(payload -> {
            if (payload == null || payload.isEmpty()) return;
            String[] events = payload.split("\\r?\\n\\r?\\n");
            for (String part : events) {
                String s = part;
                if (s == null || s.isEmpty()) continue;
                if (s.startsWith("data:")) {
                    s = s.substring(5);
                    if (s.startsWith(" ")) s = s.substring(1);
                }
                if ("[DONE]".equals(s.trim())) {
                    doneSeen.set(true);
                    continue;
                }
                try {
                    JSONObject obj = JSON.parseObject(s);
                    JSONArray choices = obj.getJSONArray("choices");
                    if (choices != null && !choices.isEmpty()) {
                        JSONObject c0 = choices.getJSONObject(0);
                        JSONObject delta = c0.getJSONObject("delta");
                        if (delta != null) {
                            String content = delta.getString("content");
                            if (content != null) {
                                try { queue.offer(content); } catch (Exception ignore) {}
                            }
                        }
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }, err -> {
            errorSeen.set(true);
            doneSeen.set(true);
            if (onError != null) onError.accept(err);
        }, () -> {
            if (!doneSeen.get()) {
                errorSeen.set(true);
                doneSeen.set(true);
                if (onError != null) onError.accept(new RuntimeException("LLM 流意外完成"));
            } else {
                doneSeen.set(true);
            }
        });
    }
    public void chatStreamRunPython(String prompt, String chatId, Consumer<String> onChunk,
                                    Runnable onComplete,
                                    Consumer<Throwable> onError) {
        HashMap<String, Object> req = new HashMap<>();
        req.put("prompt", prompt);
        req.put("chatId", chatId);
        Flux<String> flux = llmWebClient.post()
                .uri(pythonPlatformUrl)
                .header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey)
                .contentType(MediaType.APPLICATION_JSON)
                .accept(MediaType.TEXT_EVENT_STREAM)
                .bodyValue(req)
                .retrieve()
                .bodyToFlux(String.class)
                .doOnError(ex -> log.error("调用 LLM 流式失败", ex));
        AtomicBoolean doneSeen = new AtomicBoolean(false);
        AtomicBoolean errorSeen = new AtomicBoolean(false);
        LinkedBlockingQueue<String> queue = new LinkedBlockingQueue<>();
        Thread drain = new Thread(() -> {
            try {
                while (true) {
                    String s = queue.poll(2, TimeUnit.SECONDS);
                    if (s != null) {
                        try {
                            onChunk.accept(s);
                        } catch (Exception ignore) {
                        }
                    }
                    if (doneSeen.get() && queue.isEmpty()) {
                        if (!errorSeen.get()) {
                            try {
                                if (onComplete != null) onComplete.run();
                            } catch (Exception ignore) {
                            }
                        }
                        break;
                    }
                }
            } catch (InterruptedException ignore) {
                ignore.printStackTrace();
            }
        });
        drain.setDaemon(true);
        drain.start();
        flux.subscribe(payload -> {
            if (payload == null || payload.isEmpty()) return;
            String[] events = payload.split("\\r?\\n\\r?\\n");
            for (String part : events) {
                String s = part;
                if (s == null || s.isEmpty()) continue;
                if (s.startsWith("data:")) {
                    s = s.substring(5);
                    if (s.startsWith(" ")) s = s.substring(1);
                }
                if ("[DONE]".equals(s.trim())) {
                    doneSeen.set(true);
                    continue;
                }
                if("<think>".equals(s.trim()) || "</think>".equals(s.trim())) {
                    queue.offer(s.trim());
                    continue;
                }
                try {
                    JSONObject obj = JSON.parseObject(s);
                    JSONArray choices = obj.getJSONArray("choices");
                    if (choices != null && !choices.isEmpty()) {
                        JSONObject c0 = choices.getJSONObject(0);
                        JSONObject delta = c0.getJSONObject("delta");
                        if (delta != null) {
                            String content = delta.getString("content");
                            if (content != null) {
                                try {
                                    queue.offer(content);
                                } catch (Exception ignore) {
                                }
                            }
                        }
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }, err -> {
            errorSeen.set(true);
            doneSeen.set(true);
            if (onError != null) onError.accept(err);
        }, () -> {
            if (!doneSeen.get()) {
                errorSeen.set(true);
                doneSeen.set(true);
                if (onError != null) onError.accept(new RuntimeException("LLM 流意外完成"));
            } else {
                doneSeen.set(true);
            }
        });
    }
    private ChatCompletionResponse mergeSseChunk(ChatCompletionResponse acc, String payload) {
        if (payload == null || payload.isEmpty()) return acc;
        String[] events = payload.split("\\r?\\n\\r?\\n");