package com.zy.ai.service;
|
|
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.zy.ai.entity.ChatCompletionRequest;
import com.zy.ai.entity.ChatCompletionResponse;
import com.zy.ai.entity.WcsDiagnosisRequest;
import com.zy.ai.mcp.controller.McpController;
import com.zy.ai.utils.AiPromptUtils;
import com.zy.ai.utils.AiUtils;
import com.zy.common.utils.RedisUtil;
import com.zy.core.enums.RedisKeyType;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
|
|
@Service
|
@RequiredArgsConstructor
|
@Slf4j
|
public class WcsDiagnosisService {
|
|
private static final long CHAT_TTL_SECONDS = 7L * 24 * 3600;
|
|
@Value("${llm.platform}")
|
private String platform;
|
@Autowired
|
private LlmChatService llmChatService;
|
@Autowired
|
private RedisUtil redisUtil;
|
@Autowired
|
private AiPromptUtils aiPromptUtils;
|
@Autowired
|
private AiUtils aiUtils;
|
@Autowired(required = false)
|
private McpController mcpController;
|
@Autowired
|
private PythonService pythonService;
|
|
public void diagnoseStream(WcsDiagnosisRequest request, SseEmitter emitter) {
|
List<ChatCompletionRequest.Message> messages = new ArrayList<>();
|
|
ChatCompletionRequest.Message mcpSystem = new ChatCompletionRequest.Message();
|
mcpSystem.setRole("system");
|
mcpSystem.setContent(aiPromptUtils.getAiDiagnosePromptMcp());
|
|
ChatCompletionRequest.Message mcpUser = new ChatCompletionRequest.Message();
|
mcpUser.setRole("user");
|
mcpUser.setContent(aiUtils.buildDiagnosisUserContentMcp(request));
|
|
if (runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, 0.3, 2048, emitter, null)) {
|
return;
|
}
|
|
messages = new ArrayList<>();
|
ChatCompletionRequest.Message system = new ChatCompletionRequest.Message();
|
system.setRole("system");
|
system.setContent(aiPromptUtils.getAiDiagnosePrompt());
|
messages.add(system);
|
|
ChatCompletionRequest.Message user = new ChatCompletionRequest.Message();
|
user.setRole("user");
|
user.setContent(aiUtils.buildDiagnosisUserContent(request));
|
messages.add(user);
|
|
llmChatService.chatStream(messages, 0.3, 2048, s -> {
|
try {
|
String safe = s == null ? "" : s.replace("\r", "").replace("\n", "\\n");
|
if (!safe.isEmpty()) {
|
emitter.send(SseEmitter.event().data(safe));
|
}
|
} catch (Exception ignore) {}
|
}, () -> {
|
try {
|
log.info("AI diagnose stream stopped: normal end");
|
emitter.complete();
|
} catch (Exception ignore) {}
|
}, e -> {
|
try {
|
try { emitter.send(SseEmitter.event().data("【AI】运行已停止(异常)")); } catch (Exception ignore) {}
|
log.error("AI diagnose stream stopped: error", e);
|
emitter.completeWithError(e);
|
} catch (Exception ignore) {}
|
});
|
}
|
|
public void askStream(WcsDiagnosisRequest request,
|
String prompt,
|
String chatId,
|
boolean reset,
|
SseEmitter emitter) {
|
if (platform.equals("python")) {
|
pythonService.runPython(prompt, chatId, emitter);
|
return;
|
}
|
|
List<ChatCompletionRequest.Message> messages = new ArrayList<>();
|
|
List<ChatCompletionRequest.Message> history = null;
|
String historyKey = null;
|
String metaKey = null;
|
if (chatId != null && !chatId.isEmpty()) {
|
historyKey = RedisKeyType.AI_CHAT_HISTORY.key + chatId;
|
metaKey = RedisKeyType.AI_CHAT_META.key + chatId;
|
if (reset) {
|
redisUtil.del(historyKey, metaKey);
|
}
|
List<Object> stored = redisUtil.lGet(historyKey, 0, -1);
|
if (stored != null && !stored.isEmpty()) {
|
history = new ArrayList<>(stored.size());
|
for (Object o : stored) {
|
ChatCompletionRequest.Message m = convertToMessage(o);
|
if (m != null) history.add(m);
|
}
|
if (!history.isEmpty()) messages.addAll(history);
|
} else {
|
history = new ArrayList<>();
|
}
|
}
|
|
StringBuilder assistantBuffer = new StringBuilder();
|
final String finalChatId = chatId;
|
final String finalHistoryKey = historyKey;
|
final String finalMetaKey = metaKey;
|
final String finalPrompt = prompt;
|
|
ChatCompletionRequest.Message mcpSystem = new ChatCompletionRequest.Message();
|
mcpSystem.setRole("system");
|
mcpSystem.setContent(aiPromptUtils.getWcsSensorPromptMcp());
|
|
ChatCompletionRequest.Message mcpUser = new ChatCompletionRequest.Message();
|
mcpUser.setRole("user");
|
mcpUser.setContent("【用户提问】\n" + (prompt == null ? "" : prompt));
|
|
if (runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, 0.3, 2048, emitter, finalChatId)) {
|
return;
|
}
|
|
messages = new ArrayList<>();
|
ChatCompletionRequest.Message system = new ChatCompletionRequest.Message();
|
system.setRole("system");
|
system.setContent(aiPromptUtils.getWcsSensorPrompt());
|
messages.add(system);
|
|
ChatCompletionRequest.Message questionMsg = new ChatCompletionRequest.Message();
|
questionMsg.setRole("user");
|
questionMsg.setContent("【用户提问】\n" + (prompt == null ? "" : prompt));
|
messages.add(questionMsg);
|
|
llmChatService.chatStream(messages, 0.3, 2048, s -> {
|
try {
|
String safe = s == null ? "" : s.replace("\r", "").replace("\n", "\\n");
|
if (!safe.isEmpty()) {
|
emitter.send(SseEmitter.event().data(safe));
|
assistantBuffer.append(s);
|
}
|
} catch (Exception ignore) {}
|
}, () -> {
|
try {
|
if (finalChatId != null && !finalChatId.isEmpty()) {
|
ChatCompletionRequest.Message q = new ChatCompletionRequest.Message();
|
q.setRole("user");
|
q.setContent(finalPrompt == null ? "" : finalPrompt);
|
ChatCompletionRequest.Message a = new ChatCompletionRequest.Message();
|
a.setRole("assistant");
|
a.setContent(assistantBuffer.toString());
|
redisUtil.lSet(finalHistoryKey, q);
|
redisUtil.lSet(finalHistoryKey, a);
|
redisUtil.expire(finalHistoryKey, CHAT_TTL_SECONDS);
|
Map<Object, Object> old = redisUtil.hmget(finalMetaKey);
|
Long createdAt = old != null && old.get("createdAt") != null ?
|
(old.get("createdAt") instanceof Number ? ((Number) old.get("createdAt")).longValue() : Long.valueOf(String.valueOf(old.get("createdAt"))))
|
: System.currentTimeMillis();
|
Map<String, Object> meta = new java.util.HashMap<>();
|
meta.put("chatId", finalChatId);
|
meta.put("title", buildTitleFromPrompt(finalPrompt));
|
meta.put("createdAt", createdAt);
|
meta.put("updatedAt", System.currentTimeMillis());
|
redisUtil.hmset(finalMetaKey, meta, CHAT_TTL_SECONDS);
|
}
|
emitter.complete();
|
} catch (Exception ignore) {}
|
}, e -> {
|
try { emitter.completeWithError(e); } catch (Exception ignore) {}
|
});
|
}
|
|
public List<Map<String, Object>> listChats() {
|
java.util.Set<String> keys = redisUtil.scanKeys(RedisKeyType.AI_CHAT_META.key, 1000);
|
List<Map<String, Object>> resp = new ArrayList<>();
|
if (keys != null) {
|
for (String key : keys) {
|
Map<Object, Object> m = redisUtil.hmget(key);
|
if (m != null && !m.isEmpty()) {
|
java.util.HashMap<String, Object> item = new java.util.HashMap<>();
|
for (Map.Entry<Object, Object> e : m.entrySet()) {
|
item.put(String.valueOf(e.getKey()), e.getValue());
|
}
|
String chatId = String.valueOf(item.get("chatId"));
|
String historyKey = RedisKeyType.AI_CHAT_HISTORY.key + chatId;
|
item.put("size", redisUtil.lGetListSize(historyKey));
|
resp.add(item);
|
}
|
}
|
}
|
return resp;
|
}
|
|
public boolean deleteChat(String chatId) {
|
if (chatId == null || chatId.isEmpty()) return false;
|
String historyKey = RedisKeyType.AI_CHAT_HISTORY.key + chatId;
|
String metaKey = RedisKeyType.AI_CHAT_META.key + chatId;
|
redisUtil.del(historyKey, metaKey);
|
return true;
|
}
|
|
public List<ChatCompletionRequest.Message> getChatHistory(String chatId) {
|
if (chatId == null || chatId.isEmpty()) return java.util.Collections.emptyList();
|
String historyKey = RedisKeyType.AI_CHAT_HISTORY.key + chatId;
|
List<Object> stored = redisUtil.lGet(historyKey, 0, -1);
|
List<ChatCompletionRequest.Message> result = new ArrayList<>();
|
if (stored != null) {
|
for (Object o : stored) {
|
ChatCompletionRequest.Message m = convertToMessage(o);
|
if (m != null) result.add(m);
|
}
|
}
|
return result;
|
}
|
|
private ChatCompletionRequest.Message convertToMessage(Object o) {
|
if (o instanceof ChatCompletionRequest.Message) {
|
return (ChatCompletionRequest.Message) o;
|
}
|
if (o instanceof Map) {
|
Map<?, ?> map = (Map<?, ?>) o;
|
ChatCompletionRequest.Message m = new ChatCompletionRequest.Message();
|
Object role = map.get("role");
|
Object content = map.get("content");
|
m.setRole(role == null ? null : String.valueOf(role));
|
m.setContent(content == null ? null : String.valueOf(content));
|
return m;
|
}
|
return null;
|
}
|
|
private String buildTitleFromPrompt(String prompt) {
|
if (prompt == null || prompt.isEmpty()) return "未命名会话";
|
String p = prompt.replaceAll("\n", " ").trim();
|
return p.length() > 20 ? p.substring(0, 20) : p;
|
}
|
|
private boolean runMcpStreamingDiagnosis(List<ChatCompletionRequest.Message> baseMessages,
|
ChatCompletionRequest.Message systemPrompt,
|
ChatCompletionRequest.Message userQuestion,
|
Double temperature,
|
Integer maxTokens,
|
SseEmitter emitter,
|
String chatId) {
|
try {
|
if (mcpController == null) return false;
|
List<Object> tools = buildOpenAiTools();
|
if (tools.isEmpty()) return false;
|
|
baseMessages.add(systemPrompt);
|
baseMessages.add(userQuestion);
|
|
List<ChatCompletionRequest.Message> messages = new ArrayList<>(baseMessages.size() + 8);
|
messages.addAll(baseMessages);
|
|
sse(emitter, "<think>\\n正在初始化诊断与工具环境...\\n");
|
|
int maxRound = 10;
|
int i = 0;
|
while(true) {
|
sse(emitter, "\\n正在分析(第" + (i + 1) + "轮)...\\n");
|
ChatCompletionResponse resp = llmChatService.chatCompletion(messages, temperature, maxTokens, tools);
|
if (resp == null || resp.getChoices() == null || resp.getChoices().isEmpty() || resp.getChoices().get(0).getMessage() == null) {
|
sse(emitter, "\\n分析出错,正在回退...\\n");
|
return false;
|
}
|
|
ChatCompletionRequest.Message assistant = resp.getChoices().get(0).getMessage();
|
messages.add(assistant);
|
sse(emitter, assistant.getContent());
|
|
List<ChatCompletionRequest.ToolCall> toolCalls = assistant.getTool_calls();
|
if (toolCalls == null || toolCalls.isEmpty()) {
|
break;
|
}
|
|
for (ChatCompletionRequest.ToolCall tc : toolCalls) {
|
String toolName = tc != null && tc.getFunction() != null ? tc.getFunction().getName() : null;
|
if (toolName == null || toolName.trim().isEmpty()) continue;
|
sse(emitter, "\\n准备调用工具:" + toolName + "\\n");
|
JSONObject args = new JSONObject();
|
if (tc.getFunction() != null && tc.getFunction().getArguments() != null && !tc.getFunction().getArguments().trim().isEmpty()) {
|
try {
|
args = JSON.parseObject(tc.getFunction().getArguments());
|
} catch (Exception ignore) {
|
args = new JSONObject();
|
args.put("_raw", tc.getFunction().getArguments());
|
}
|
}
|
Object output;
|
try {
|
output = mcpController.callTool(toolName, args);
|
} catch (Exception e) {
|
java.util.LinkedHashMap<String, Object> err = new java.util.LinkedHashMap<String, Object>();
|
err.put("tool", toolName);
|
err.put("error", e.getMessage());
|
output = err;
|
}
|
sse(emitter, "\\n工具返回,正在继续推理...\\n");
|
ChatCompletionRequest.Message toolMsg = new ChatCompletionRequest.Message();
|
toolMsg.setRole("tool");
|
toolMsg.setTool_call_id(tc == null ? null : tc.getId());
|
toolMsg.setContent(JSON.toJSONString(output));
|
messages.add(toolMsg);
|
}
|
|
if(i++ >= maxRound) break;
|
}
|
|
sse(emitter, "\\n正在根据数据进行分析...\\n</think>\\n\\n");
|
|
ChatCompletionRequest.Message diagnosisMessage = new ChatCompletionRequest.Message();
|
diagnosisMessage.setRole("system");
|
diagnosisMessage.setContent("根据以上信息进行分析,并给出完整的诊断结论。");
|
messages.add(diagnosisMessage);
|
|
StringBuilder assistantBuffer = new StringBuilder();
|
llmChatService.chatStreamWithTools(messages, temperature, maxTokens, tools, s -> {
|
try {
|
String safe = s == null ? "" : s.replace("\r", "").replace("\n", "\\n");
|
if (!safe.isEmpty()) {
|
sse(emitter, safe);
|
assistantBuffer.append(safe);
|
}
|
} catch (Exception ignore) {}
|
}, () -> {
|
try {
|
sse(emitter, "\\n\\n【AI】运行已停止(正常结束)\\n\\n");
|
log.info("AI MCP diagnose stopped: final end");
|
emitter.complete();
|
|
if (chatId != null) {
|
String historyKey = RedisKeyType.AI_CHAT_HISTORY.key + chatId;
|
String metaKey = RedisKeyType.AI_CHAT_META.key + chatId;
|
|
ChatCompletionRequest.Message a = new ChatCompletionRequest.Message();
|
a.setRole("assistant");
|
a.setContent(assistantBuffer.toString());
|
redisUtil.lSet(historyKey, userQuestion);
|
redisUtil.lSet(historyKey, a);
|
redisUtil.expire(historyKey, CHAT_TTL_SECONDS);
|
Map<Object, Object> old = redisUtil.hmget(metaKey);
|
Long createdAt = old != null && old.get("createdAt") != null ?
|
(old.get("createdAt") instanceof Number ? ((Number) old.get("createdAt")).longValue() : Long.valueOf(String.valueOf(old.get("createdAt"))))
|
: System.currentTimeMillis();
|
Map<String, Object> meta = new java.util.HashMap<>();
|
meta.put("chatId", chatId);
|
meta.put("title", buildTitleFromPrompt(userQuestion.getContent()));
|
meta.put("createdAt", createdAt);
|
meta.put("updatedAt", System.currentTimeMillis());
|
redisUtil.hmset(metaKey, meta, CHAT_TTL_SECONDS);
|
}
|
} catch (Exception ignore) {}
|
}, e -> {
|
sse(emitter, "\\n\\n【AI】分析出错,正在回退...\\n\\n");
|
});
|
return true;
|
} catch (Exception e) {
|
try {
|
sse(emitter, "\\n\\n【AI】运行已停止(异常)\\n\\n");
|
log.error("AI MCP diagnose stopped: error", e);
|
emitter.completeWithError(e);
|
} catch (Exception ignore) {}
|
return true;
|
}
|
}
|
|
private void sse(SseEmitter emitter, String data) {
|
if (data == null) return;
|
try {
|
emitter.send(SseEmitter.event().data(data));
|
} catch (Exception e) {
|
log.warn("SSE send failed", e);
|
}
|
}
|
|
private List<Object> buildOpenAiTools() {
|
if (mcpController == null) return java.util.Collections.emptyList();
|
List<Map<String, Object>> mcpTools = mcpController.listTools();
|
if (mcpTools == null || mcpTools.isEmpty()) return java.util.Collections.emptyList();
|
|
List<Object> tools = new ArrayList<>();
|
for (Map<String, Object> t : mcpTools) {
|
if (t == null) continue;
|
Object name = t.get("name");
|
if (name == null) continue;
|
Object inputSchema = t.get("inputSchema");
|
java.util.LinkedHashMap<String, Object> function = new java.util.LinkedHashMap<String, Object>();
|
function.put("name", String.valueOf(name));
|
Object desc = t.get("description");
|
if (desc != null) function.put("description", String.valueOf(desc));
|
function.put("parameters", inputSchema == null ? new java.util.LinkedHashMap<String, Object>() : inputSchema);
|
|
java.util.LinkedHashMap<String, Object> tool = new java.util.LinkedHashMap<String, Object>();
|
tool.put("type", "function");
|
tool.put("function", function);
|
tools.add(tool);
|
}
|
return tools;
|
}
|
|
private void sendLargeText(SseEmitter emitter, String text) {
|
if (text == null) return;
|
String safe = text.replace("\r", "").replace("\n", "\\n");
|
int chunkSize = 256;
|
int i = 0;
|
while (i < safe.length()) {
|
int end = Math.min(i + chunkSize, safe.length());
|
String part = safe.substring(i, end);
|
if (!part.isEmpty()) {
|
try { emitter.send(SseEmitter.event().data(part)); } catch (Exception ignore) {}
|
}
|
i = end;
|
}
|
}
|
|
private static final java.util.regex.Pattern DSML_INVOKE_PATTERN =
|
java.util.regex.Pattern.compile("<\\uFF5CDSML\\uFF5Cinvoke\\s+name=\\\"([^\\\"]+)\\\"[^>]*>([\\s\\S]*?)</\\uFF5CDSML\\uFF5Cinvoke>", java.util.regex.Pattern.MULTILINE);
|
private static final java.util.regex.Pattern JSON_OBJECT_PATTERN =
|
java.util.regex.Pattern.compile("\\{[\\s\\S]*\\}");
|
private static final java.util.regex.Pattern DSML_PARAM_PATTERN =
|
java.util.regex.Pattern.compile("<\\uFF5CDSML\\uFF5Cparameter\\s+name=\\\"([^\\\"]+)\\\"\\s*([^>]*)>([\\s\\S]*?)</\\uFF5CDSML\\uFF5Cparameter>", java.util.regex.Pattern.MULTILINE);
|
|
private java.util.List<DsmlInvocation> parseDsmlInvocations(String content) {
|
java.util.List<DsmlInvocation> list = new java.util.ArrayList<>();
|
if (content == null || content.isEmpty()) return list;
|
java.util.regex.Matcher m = DSML_INVOKE_PATTERN.matcher(content);
|
while (m.find()) {
|
String name = m.group(1);
|
String inner = m.group(2);
|
com.alibaba.fastjson.JSONObject args = null;
|
if (inner != null) {
|
java.util.regex.Matcher jm = JSON_OBJECT_PATTERN.matcher(inner);
|
if (jm.find()) {
|
String json = jm.group();
|
try { args = com.alibaba.fastjson.JSON.parseObject(json); } catch (Exception ignore) {}
|
}
|
java.util.regex.Matcher pm = DSML_PARAM_PATTERN.matcher(inner);
|
while (pm.find()) {
|
if (args == null) args = new com.alibaba.fastjson.JSONObject();
|
String pName = pm.group(1);
|
String attr = pm.group(2);
|
String valText = pm.group(3);
|
boolean isString = attr != null && attr.toLowerCase().contains("string=\"true\"");
|
String t = valText == null ? "" : valText.trim();
|
if (isString) {
|
args.put(pName, t);
|
} else {
|
if ("true".equalsIgnoreCase(t) || "false".equalsIgnoreCase(t)) {
|
args.put(pName, Boolean.valueOf(t));
|
} else {
|
try {
|
if (t.contains(".")) {
|
args.put(pName, Double.valueOf(t));
|
} else {
|
args.put(pName, Long.valueOf(t));
|
}
|
} catch (Exception ex) {
|
args.put(pName, t);
|
}
|
}
|
}
|
}
|
}
|
DsmlInvocation inv = new DsmlInvocation();
|
inv.name = name;
|
inv.arguments = args;
|
list.add(inv);
|
}
|
return list;
|
}
|
|
private static class DsmlInvocation {
|
String name;
|
com.alibaba.fastjson.JSONObject arguments;
|
}
|
|
private List<DsmlInvocation> buildDefaultStatusInvocations() {
|
List<DsmlInvocation> list = new ArrayList<>();
|
DsmlInvocation crn = new DsmlInvocation();
|
crn.name = "device.get_crn_status";
|
com.alibaba.fastjson.JSONObject a1 = new com.alibaba.fastjson.JSONObject();
|
a1.put("limit", 20);
|
crn.arguments = a1;
|
list.add(crn);
|
|
DsmlInvocation st = new DsmlInvocation();
|
st.name = "device.get_station_status";
|
com.alibaba.fastjson.JSONObject a2 = new com.alibaba.fastjson.JSONObject();
|
a2.put("limit", 20);
|
st.arguments = a2;
|
list.add(st);
|
|
DsmlInvocation rgv = new DsmlInvocation();
|
rgv.name = "device.get_rgv_status";
|
com.alibaba.fastjson.JSONObject a3 = new com.alibaba.fastjson.JSONObject();
|
a3.put("limit", 20);
|
rgv.arguments = a3;
|
list.add(rgv);
|
|
return list;
|
}
|
|
private void ensureStatusCoverage(List<DsmlInvocation> invs) {
|
if (invs == null) return;
|
java.util.Set<String> names = new java.util.HashSet<String>();
|
for (DsmlInvocation d : invs) {
|
if (d != null && d.name != null) names.add(d.name);
|
}
|
if (!names.contains("device.get_crn_status") || !names.contains("device.get_station_status") || !names.contains("device.get_rgv_status")) {
|
List<DsmlInvocation> defaults = buildDefaultStatusInvocations();
|
for (DsmlInvocation d : defaults) {
|
if (!names.contains(d.name)) invs.add(d);
|
}
|
}
|
}
|
|
private boolean isConclusionText(String content) {
|
if (content == null) return false;
|
String c = content;
|
int len = c.length();
|
boolean longEnough = len >= 200;
|
boolean hasAllSections = c.contains("问题概述") && c.contains("可疑设备列表") && c.contains("可能原因") && c.contains("建议排查步骤") && c.contains("风险评估");
|
boolean hasExplicitConclusion = (c.contains("结论") || c.contains("诊断结果")) && longEnough;
|
return hasAllSections || hasExplicitConclusion;
|
}
|
}
|