package com.zy.ai.service;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.zy.ai.entity.ChatCompletionRequest;
import com.zy.ai.entity.ChatCompletionResponse;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Service;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.util.List;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
|
@Slf4j
|
@Service
|
@RequiredArgsConstructor
|
public class LlmChatService {
|
|
private final WebClient llmWebClient;
|
|
@Value("${llm.api-key}")
|
private String apiKey;
|
|
@Value("${llm.model}")
|
private String model;
|
|
/**
|
* 通用对话方法:传入 messages,返回大模型文本回复
|
*/
|
public String chat(List<ChatCompletionRequest.Message> messages,
|
Double temperature,
|
Integer maxTokens) {
|
|
ChatCompletionRequest req = new ChatCompletionRequest();
|
req.setModel(model);
|
req.setMessages(messages);
|
req.setTemperature(temperature != null ? temperature : 0.3);
|
req.setMax_tokens(maxTokens != null ? maxTokens : 1024);
|
req.setStream(false);
|
|
ChatCompletionResponse response = llmWebClient.post()
|
.uri("/chat/completions")
|
.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey)
|
.contentType(MediaType.APPLICATION_JSON)
|
.accept(MediaType.APPLICATION_JSON, MediaType.TEXT_EVENT_STREAM)
|
.bodyValue(req)
|
.exchangeToMono(resp -> resp.bodyToFlux(String.class)
|
.collectList()
|
.map(list -> {
|
String payload = String.join("\n\n", list);
|
return parseCompletion(payload);
|
}))
|
.doOnError(ex -> log.error("调用 LLM 失败", ex))
|
.onErrorResume(ex -> Mono.empty())
|
.block();
|
|
if (response == null ||
|
response.getChoices() == null ||
|
response.getChoices().isEmpty() ||
|
response.getChoices().get(0).getMessage() == null ||
|
response.getChoices().get(0).getMessage().getContent() == null ||
|
response.getChoices().get(0).getMessage().getContent().isEmpty()) {
|
return null;
|
}
|
|
return response.getChoices().get(0).getMessage().getContent();
|
}
|
|
public void chatStream(List<ChatCompletionRequest.Message> messages,
|
Double temperature,
|
Integer maxTokens,
|
Consumer<String> onChunk,
|
Runnable onComplete,
|
Consumer<Throwable> onError) {
|
|
ChatCompletionRequest req = new ChatCompletionRequest();
|
req.setModel(model);
|
req.setMessages(messages);
|
req.setTemperature(temperature != null ? temperature : 0.3);
|
req.setMax_tokens(maxTokens != null ? maxTokens : 1024);
|
req.setStream(true);
|
|
|
Flux<String> flux = llmWebClient.post()
|
.uri("/chat/completions")
|
.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey)
|
.contentType(MediaType.APPLICATION_JSON)
|
.accept(MediaType.TEXT_EVENT_STREAM)
|
.bodyValue(req)
|
.retrieve()
|
.bodyToFlux(String.class)
|
.doOnError(ex -> log.error("调用 LLM 流式失败", ex));
|
|
AtomicBoolean doneSeen = new AtomicBoolean(false);
|
AtomicBoolean errorSeen = new AtomicBoolean(false);
|
LinkedBlockingQueue<String> queue = new LinkedBlockingQueue<>();
|
|
Thread drain = new Thread(() -> {
|
try {
|
while (true) {
|
String s = queue.poll(2, TimeUnit.SECONDS);
|
if (s != null) {
|
try { onChunk.accept(s); } catch (Exception ignore) {}
|
}
|
if (doneSeen.get() && queue.isEmpty()) {
|
if (!errorSeen.get()) {
|
try { if (onComplete != null) onComplete.run(); } catch (Exception ignore) {}
|
}
|
break;
|
}
|
}
|
} catch (InterruptedException ignore) {
|
ignore.printStackTrace();
|
}
|
});
|
drain.setDaemon(true);
|
drain.start();
|
|
flux.subscribe(payload -> {
|
if (payload == null || payload.isEmpty()) return;
|
String[] events = payload.split("\\r?\\n\\r?\\n");
|
for (String part : events) {
|
String s = part;
|
if (s == null || s.isEmpty()) continue;
|
if (s.startsWith("data:")) {
|
s = s.substring(5);
|
if (s.startsWith(" ")) s = s.substring(1);
|
}
|
if ("[DONE]".equals(s.trim())) {
|
doneSeen.set(true);
|
continue;
|
}
|
try {
|
JSONObject obj = JSON.parseObject(s);
|
JSONArray choices = obj.getJSONArray("choices");
|
if (choices != null && !choices.isEmpty()) {
|
JSONObject c0 = choices.getJSONObject(0);
|
JSONObject delta = c0.getJSONObject("delta");
|
if (delta != null) {
|
String content = delta.getString("content");
|
if (content != null) {
|
try { queue.offer(content); } catch (Exception ignore) {}
|
}
|
}
|
}
|
} catch (Exception e) {
|
e.printStackTrace();
|
}
|
}
|
}, err -> {
|
errorSeen.set(true);
|
doneSeen.set(true);
|
if (onError != null) onError.accept(err);
|
}, () -> {
|
if (!doneSeen.get()) {
|
errorSeen.set(true);
|
doneSeen.set(true);
|
if (onError != null) onError.accept(new RuntimeException("LLM 流意外完成"));
|
} else {
|
doneSeen.set(true);
|
}
|
});
|
}
|
|
private ChatCompletionResponse mergeSseChunk(ChatCompletionResponse acc, String payload) {
|
if (payload == null || payload.isEmpty()) return acc;
|
String[] events = payload.split("\\r?\\n\\r?\\n");
|
for (String part : events) {
|
String s = part;
|
if (s == null || s.isEmpty()) continue;
|
if (s.startsWith("data:")) {
|
s = s.substring(5);
|
if (s.startsWith(" ")) s = s.substring(1);
|
}
|
if ("[DONE]".equals(s.trim())) {
|
continue;
|
}
|
try {
|
JSONObject obj = JSON.parseObject(s);
|
if (obj == null) continue;
|
JSONArray choices = obj.getJSONArray("choices");
|
if (choices != null && !choices.isEmpty()) {
|
JSONObject c0 = choices.getJSONObject(0);
|
if (acc.getChoices() == null || acc.getChoices().isEmpty()) {
|
ChatCompletionResponse.Choice choice = new ChatCompletionResponse.Choice();
|
ChatCompletionRequest.Message msg = new ChatCompletionRequest.Message();
|
choice.setMessage(msg);
|
java.util.ArrayList<ChatCompletionResponse.Choice> list = new java.util.ArrayList<>();
|
list.add(choice);
|
acc.setChoices(list);
|
}
|
ChatCompletionResponse.Choice choice = acc.getChoices().get(0);
|
ChatCompletionRequest.Message msg = choice.getMessage();
|
if (msg.getRole() == null || msg.getRole().isEmpty()) {
|
msg.setRole("assistant");
|
}
|
JSONObject delta = c0.getJSONObject("delta");
|
if (delta != null) {
|
String c = delta.getString("content");
|
if (c != null) {
|
String prev = msg.getContent();
|
msg.setContent(prev == null ? c : prev + c);
|
}
|
String role = delta.getString("role");
|
if (role != null && !role.isEmpty()) msg.setRole(role);
|
}
|
JSONObject message = c0.getJSONObject("message");
|
if (message != null) {
|
String c = message.getString("content");
|
if (c != null) {
|
String prev = msg.getContent();
|
msg.setContent(prev == null ? c : prev + c);
|
}
|
String role = message.getString("role");
|
if (role != null && !role.isEmpty()) msg.setRole(role);
|
}
|
String fr = c0.getString("finish_reason");
|
if (fr != null && !fr.isEmpty()) choice.setFinishReason(fr);
|
}
|
String id = obj.getString("id");
|
if (id != null && !id.isEmpty()) acc.setId(id);
|
Long created = obj.getLong("created");
|
if (created != null) acc.setCreated(created);
|
String object = obj.getString("object");
|
if (object != null && !object.isEmpty()) acc.setObjectName(object);
|
} catch (Exception ignore) {}
|
}
|
return acc;
|
}
|
|
private ChatCompletionResponse parseCompletion(String payload) {
|
if (payload == null) return null;
|
try {
|
ChatCompletionResponse r = JSON.parseObject(payload, ChatCompletionResponse.class);
|
if (r != null && r.getChoices() != null && !r.getChoices().isEmpty() && r.getChoices().get(0).getMessage() != null) {
|
return r;
|
}
|
} catch (Exception ignore) {}
|
ChatCompletionResponse sse = mergeSseChunk(new ChatCompletionResponse(), payload);
|
if (sse.getChoices() != null && !sse.getChoices().isEmpty() && sse.getChoices().get(0).getMessage() != null && sse.getChoices().get(0).getMessage().getContent() != null) {
|
return sse;
|
}
|
ChatCompletionResponse r = new ChatCompletionResponse();
|
ChatCompletionResponse.Choice choice = new ChatCompletionResponse.Choice();
|
ChatCompletionRequest.Message msg = new ChatCompletionRequest.Message();
|
msg.setRole("assistant");
|
msg.setContent(payload);
|
choice.setMessage(msg);
|
java.util.ArrayList<ChatCompletionResponse.Choice> list = new java.util.ArrayList<>();
|
list.add(choice);
|
r.setChoices(list);
|
return r;
|
}
|
}
|