package com.zy.ai.service; import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import com.zy.ai.entity.ChatCompletionRequest; import com.zy.ai.entity.ChatCompletionResponse; import com.zy.ai.entity.WcsDiagnosisRequest; import com.zy.ai.mcp.controller.McpController; import com.zy.ai.utils.AiPromptUtils; import com.zy.ai.utils.AiUtils; import com.zy.common.utils.RedisUtil; import com.zy.core.enums.RedisKeyType; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; import java.util.ArrayList; import java.util.List; import java.util.Map; @Service @RequiredArgsConstructor @Slf4j public class WcsDiagnosisService { private static final long CHAT_TTL_SECONDS = 7L * 24 * 3600; @Value("${llm.platform}") private String platform; @Autowired private LlmChatService llmChatService; @Autowired private RedisUtil redisUtil; @Autowired private AiPromptUtils aiPromptUtils; @Autowired private AiUtils aiUtils; @Autowired(required = false) private McpController mcpController; @Autowired private PythonService pythonService; public void diagnoseStream(WcsDiagnosisRequest request, SseEmitter emitter) { List messages = new ArrayList<>(); ChatCompletionRequest.Message mcpSystem = new ChatCompletionRequest.Message(); mcpSystem.setRole("system"); mcpSystem.setContent(aiPromptUtils.getAiDiagnosePromptMcp()); ChatCompletionRequest.Message mcpUser = new ChatCompletionRequest.Message(); mcpUser.setRole("user"); mcpUser.setContent(aiUtils.buildDiagnosisUserContentMcp(request)); if (runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, 0.3, 2048, emitter, null)) { return; } messages = new ArrayList<>(); ChatCompletionRequest.Message system = new ChatCompletionRequest.Message(); system.setRole("system"); system.setContent(aiPromptUtils.getAiDiagnosePrompt()); messages.add(system); ChatCompletionRequest.Message user = new ChatCompletionRequest.Message(); user.setRole("user"); user.setContent(aiUtils.buildDiagnosisUserContent(request)); messages.add(user); llmChatService.chatStream(messages, 0.3, 2048, s -> { try { String safe = s == null ? "" : s.replace("\r", "").replace("\n", "\\n"); if (!safe.isEmpty()) { emitter.send(SseEmitter.event().data(safe)); } } catch (Exception ignore) {} }, () -> { try { log.info("AI diagnose stream stopped: normal end"); emitter.complete(); } catch (Exception ignore) {} }, e -> { try { try { emitter.send(SseEmitter.event().data("【AI】运行已停止(异常)")); } catch (Exception ignore) {} log.error("AI diagnose stream stopped: error", e); emitter.completeWithError(e); } catch (Exception ignore) {} }); } public void askStream(WcsDiagnosisRequest request, String prompt, String chatId, boolean reset, SseEmitter emitter) { if (platform.equals("python")) { pythonService.runPython(prompt, chatId, emitter); return; } List messages = new ArrayList<>(); List history = null; String historyKey = null; String metaKey = null; if (chatId != null && !chatId.isEmpty()) { historyKey = RedisKeyType.AI_CHAT_HISTORY.key + chatId; metaKey = RedisKeyType.AI_CHAT_META.key + chatId; if (reset) { redisUtil.del(historyKey, metaKey); } List stored = redisUtil.lGet(historyKey, 0, -1); if (stored != null && !stored.isEmpty()) { history = new ArrayList<>(stored.size()); for (Object o : stored) { ChatCompletionRequest.Message m = convertToMessage(o); if (m != null) history.add(m); } if (!history.isEmpty()) messages.addAll(history); } else { history = new ArrayList<>(); } } StringBuilder assistantBuffer = new StringBuilder(); final String finalChatId = chatId; final String finalHistoryKey = historyKey; final String finalMetaKey = metaKey; final String finalPrompt = prompt; ChatCompletionRequest.Message mcpSystem = new ChatCompletionRequest.Message(); mcpSystem.setRole("system"); mcpSystem.setContent(aiPromptUtils.getWcsSensorPromptMcp()); ChatCompletionRequest.Message mcpUser = new ChatCompletionRequest.Message(); mcpUser.setRole("user"); mcpUser.setContent("【用户提问】\n" + (prompt == null ? "" : prompt)); if (runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, 0.3, 2048, emitter, finalChatId)) { return; } messages = new ArrayList<>(); ChatCompletionRequest.Message system = new ChatCompletionRequest.Message(); system.setRole("system"); system.setContent(aiPromptUtils.getWcsSensorPrompt()); messages.add(system); ChatCompletionRequest.Message questionMsg = new ChatCompletionRequest.Message(); questionMsg.setRole("user"); questionMsg.setContent("【用户提问】\n" + (prompt == null ? "" : prompt)); messages.add(questionMsg); llmChatService.chatStream(messages, 0.3, 2048, s -> { try { String safe = s == null ? "" : s.replace("\r", "").replace("\n", "\\n"); if (!safe.isEmpty()) { emitter.send(SseEmitter.event().data(safe)); assistantBuffer.append(s); } } catch (Exception ignore) {} }, () -> { try { if (finalChatId != null && !finalChatId.isEmpty()) { ChatCompletionRequest.Message q = new ChatCompletionRequest.Message(); q.setRole("user"); q.setContent(finalPrompt == null ? "" : finalPrompt); ChatCompletionRequest.Message a = new ChatCompletionRequest.Message(); a.setRole("assistant"); a.setContent(assistantBuffer.toString()); redisUtil.lSet(finalHistoryKey, q); redisUtil.lSet(finalHistoryKey, a); redisUtil.expire(finalHistoryKey, CHAT_TTL_SECONDS); Map old = redisUtil.hmget(finalMetaKey); Long createdAt = old != null && old.get("createdAt") != null ? (old.get("createdAt") instanceof Number ? ((Number) old.get("createdAt")).longValue() : Long.valueOf(String.valueOf(old.get("createdAt")))) : System.currentTimeMillis(); Map meta = new java.util.HashMap<>(); meta.put("chatId", finalChatId); meta.put("title", buildTitleFromPrompt(finalPrompt)); meta.put("createdAt", createdAt); meta.put("updatedAt", System.currentTimeMillis()); redisUtil.hmset(finalMetaKey, meta, CHAT_TTL_SECONDS); } emitter.complete(); } catch (Exception ignore) {} }, e -> { try { emitter.completeWithError(e); } catch (Exception ignore) {} }); } public List> listChats() { java.util.Set keys = redisUtil.scanKeys(RedisKeyType.AI_CHAT_META.key, 1000); List> resp = new ArrayList<>(); if (keys != null) { for (String key : keys) { Map m = redisUtil.hmget(key); if (m != null && !m.isEmpty()) { java.util.HashMap item = new java.util.HashMap<>(); for (Map.Entry e : m.entrySet()) { item.put(String.valueOf(e.getKey()), e.getValue()); } String chatId = String.valueOf(item.get("chatId")); String historyKey = RedisKeyType.AI_CHAT_HISTORY.key + chatId; item.put("size", redisUtil.lGetListSize(historyKey)); resp.add(item); } } } return resp; } public boolean deleteChat(String chatId) { if (chatId == null || chatId.isEmpty()) return false; String historyKey = RedisKeyType.AI_CHAT_HISTORY.key + chatId; String metaKey = RedisKeyType.AI_CHAT_META.key + chatId; redisUtil.del(historyKey, metaKey); return true; } public List getChatHistory(String chatId) { if (chatId == null || chatId.isEmpty()) return java.util.Collections.emptyList(); String historyKey = RedisKeyType.AI_CHAT_HISTORY.key + chatId; List stored = redisUtil.lGet(historyKey, 0, -1); List result = new ArrayList<>(); if (stored != null) { for (Object o : stored) { ChatCompletionRequest.Message m = convertToMessage(o); if (m != null) result.add(m); } } return result; } private ChatCompletionRequest.Message convertToMessage(Object o) { if (o instanceof ChatCompletionRequest.Message) { return (ChatCompletionRequest.Message) o; } if (o instanceof Map) { Map map = (Map) o; ChatCompletionRequest.Message m = new ChatCompletionRequest.Message(); Object role = map.get("role"); Object content = map.get("content"); m.setRole(role == null ? null : String.valueOf(role)); m.setContent(content == null ? null : String.valueOf(content)); return m; } return null; } private String buildTitleFromPrompt(String prompt) { if (prompt == null || prompt.isEmpty()) return "未命名会话"; String p = prompt.replaceAll("\n", " ").trim(); return p.length() > 20 ? p.substring(0, 20) : p; } private boolean runMcpStreamingDiagnosis(List baseMessages, ChatCompletionRequest.Message systemPrompt, ChatCompletionRequest.Message userQuestion, Double temperature, Integer maxTokens, SseEmitter emitter, String chatId) { try { if (mcpController == null) return false; List tools = buildOpenAiTools(); if (tools.isEmpty()) return false; baseMessages.add(systemPrompt); baseMessages.add(userQuestion); List messages = new ArrayList<>(baseMessages.size() + 8); messages.addAll(baseMessages); sse(emitter, "\\n正在初始化诊断与工具环境...\\n"); int maxRound = 10; int i = 0; while(true) { sse(emitter, "\\n正在分析(第" + (i + 1) + "轮)...\\n"); ChatCompletionResponse resp = llmChatService.chatCompletion(messages, temperature, maxTokens, tools); if (resp == null || resp.getChoices() == null || resp.getChoices().isEmpty() || resp.getChoices().get(0).getMessage() == null) { sse(emitter, "\\n分析出错,正在回退...\\n"); return false; } ChatCompletionRequest.Message assistant = resp.getChoices().get(0).getMessage(); messages.add(assistant); sse(emitter, assistant.getContent()); List toolCalls = assistant.getTool_calls(); if (toolCalls == null || toolCalls.isEmpty()) { break; } for (ChatCompletionRequest.ToolCall tc : toolCalls) { String toolName = tc != null && tc.getFunction() != null ? tc.getFunction().getName() : null; if (toolName == null || toolName.trim().isEmpty()) continue; sse(emitter, "\\n准备调用工具:" + toolName + "\\n"); JSONObject args = new JSONObject(); if (tc.getFunction() != null && tc.getFunction().getArguments() != null && !tc.getFunction().getArguments().trim().isEmpty()) { try { args = JSON.parseObject(tc.getFunction().getArguments()); } catch (Exception ignore) { args = new JSONObject(); args.put("_raw", tc.getFunction().getArguments()); } } Object output; try { output = mcpController.callTool(toolName, args); } catch (Exception e) { java.util.LinkedHashMap err = new java.util.LinkedHashMap(); err.put("tool", toolName); err.put("error", e.getMessage()); output = err; } sse(emitter, "\\n工具返回,正在继续推理...\\n"); ChatCompletionRequest.Message toolMsg = new ChatCompletionRequest.Message(); toolMsg.setRole("tool"); toolMsg.setTool_call_id(tc == null ? null : tc.getId()); toolMsg.setContent(JSON.toJSONString(output)); messages.add(toolMsg); } if(i++ >= maxRound) break; } sse(emitter, "\\n正在根据数据进行分析...\\n\\n\\n"); ChatCompletionRequest.Message diagnosisMessage = new ChatCompletionRequest.Message(); diagnosisMessage.setRole("system"); diagnosisMessage.setContent("根据以上信息进行分析,并给出完整的诊断结论。"); messages.add(diagnosisMessage); StringBuilder assistantBuffer = new StringBuilder(); llmChatService.chatStreamWithTools(messages, temperature, maxTokens, tools, s -> { try { String safe = s == null ? "" : s.replace("\r", "").replace("\n", "\\n"); if (!safe.isEmpty()) { sse(emitter, safe); assistantBuffer.append(safe); } } catch (Exception ignore) {} }, () -> { try { sse(emitter, "\\n\\n【AI】运行已停止(正常结束)\\n\\n"); log.info("AI MCP diagnose stopped: final end"); emitter.complete(); if (chatId != null) { String historyKey = RedisKeyType.AI_CHAT_HISTORY.key + chatId; String metaKey = RedisKeyType.AI_CHAT_META.key + chatId; ChatCompletionRequest.Message a = new ChatCompletionRequest.Message(); a.setRole("assistant"); a.setContent(assistantBuffer.toString()); redisUtil.lSet(historyKey, userQuestion); redisUtil.lSet(historyKey, a); redisUtil.expire(historyKey, CHAT_TTL_SECONDS); Map old = redisUtil.hmget(metaKey); Long createdAt = old != null && old.get("createdAt") != null ? (old.get("createdAt") instanceof Number ? ((Number) old.get("createdAt")).longValue() : Long.valueOf(String.valueOf(old.get("createdAt")))) : System.currentTimeMillis(); Map meta = new java.util.HashMap<>(); meta.put("chatId", chatId); meta.put("title", buildTitleFromPrompt(userQuestion.getContent())); meta.put("createdAt", createdAt); meta.put("updatedAt", System.currentTimeMillis()); redisUtil.hmset(metaKey, meta, CHAT_TTL_SECONDS); } } catch (Exception ignore) {} }, e -> { sse(emitter, "\\n\\n【AI】分析出错,正在回退...\\n\\n"); }); return true; } catch (Exception e) { try { sse(emitter, "\\n\\n【AI】运行已停止(异常)\\n\\n"); log.error("AI MCP diagnose stopped: error", e); emitter.completeWithError(e); } catch (Exception ignore) {} return true; } } private void sse(SseEmitter emitter, String data) { if (data == null) return; try { emitter.send(SseEmitter.event().data(data)); } catch (Exception e) { log.warn("SSE send failed", e); } } private List buildOpenAiTools() { if (mcpController == null) return java.util.Collections.emptyList(); List> mcpTools = mcpController.listTools(); if (mcpTools == null || mcpTools.isEmpty()) return java.util.Collections.emptyList(); List tools = new ArrayList<>(); for (Map t : mcpTools) { if (t == null) continue; Object name = t.get("name"); if (name == null) continue; Object inputSchema = t.get("inputSchema"); java.util.LinkedHashMap function = new java.util.LinkedHashMap(); function.put("name", String.valueOf(name)); Object desc = t.get("description"); if (desc != null) function.put("description", String.valueOf(desc)); function.put("parameters", inputSchema == null ? new java.util.LinkedHashMap() : inputSchema); java.util.LinkedHashMap tool = new java.util.LinkedHashMap(); tool.put("type", "function"); tool.put("function", function); tools.add(tool); } return tools; } private void sendLargeText(SseEmitter emitter, String text) { if (text == null) return; String safe = text.replace("\r", "").replace("\n", "\\n"); int chunkSize = 256; int i = 0; while (i < safe.length()) { int end = Math.min(i + chunkSize, safe.length()); String part = safe.substring(i, end); if (!part.isEmpty()) { try { emitter.send(SseEmitter.event().data(part)); } catch (Exception ignore) {} } i = end; } } private static final java.util.regex.Pattern DSML_INVOKE_PATTERN = java.util.regex.Pattern.compile("<\\uFF5CDSML\\uFF5Cinvoke\\s+name=\\\"([^\\\"]+)\\\"[^>]*>([\\s\\S]*?)", java.util.regex.Pattern.MULTILINE); private static final java.util.regex.Pattern JSON_OBJECT_PATTERN = java.util.regex.Pattern.compile("\\{[\\s\\S]*\\}"); private static final java.util.regex.Pattern DSML_PARAM_PATTERN = java.util.regex.Pattern.compile("<\\uFF5CDSML\\uFF5Cparameter\\s+name=\\\"([^\\\"]+)\\\"\\s*([^>]*)>([\\s\\S]*?)", java.util.regex.Pattern.MULTILINE); private java.util.List parseDsmlInvocations(String content) { java.util.List list = new java.util.ArrayList<>(); if (content == null || content.isEmpty()) return list; java.util.regex.Matcher m = DSML_INVOKE_PATTERN.matcher(content); while (m.find()) { String name = m.group(1); String inner = m.group(2); com.alibaba.fastjson.JSONObject args = null; if (inner != null) { java.util.regex.Matcher jm = JSON_OBJECT_PATTERN.matcher(inner); if (jm.find()) { String json = jm.group(); try { args = com.alibaba.fastjson.JSON.parseObject(json); } catch (Exception ignore) {} } java.util.regex.Matcher pm = DSML_PARAM_PATTERN.matcher(inner); while (pm.find()) { if (args == null) args = new com.alibaba.fastjson.JSONObject(); String pName = pm.group(1); String attr = pm.group(2); String valText = pm.group(3); boolean isString = attr != null && attr.toLowerCase().contains("string=\"true\""); String t = valText == null ? "" : valText.trim(); if (isString) { args.put(pName, t); } else { if ("true".equalsIgnoreCase(t) || "false".equalsIgnoreCase(t)) { args.put(pName, Boolean.valueOf(t)); } else { try { if (t.contains(".")) { args.put(pName, Double.valueOf(t)); } else { args.put(pName, Long.valueOf(t)); } } catch (Exception ex) { args.put(pName, t); } } } } } DsmlInvocation inv = new DsmlInvocation(); inv.name = name; inv.arguments = args; list.add(inv); } return list; } private static class DsmlInvocation { String name; com.alibaba.fastjson.JSONObject arguments; } private List buildDefaultStatusInvocations() { List list = new ArrayList<>(); DsmlInvocation crn = new DsmlInvocation(); crn.name = "device.get_crn_status"; com.alibaba.fastjson.JSONObject a1 = new com.alibaba.fastjson.JSONObject(); a1.put("limit", 20); crn.arguments = a1; list.add(crn); DsmlInvocation st = new DsmlInvocation(); st.name = "device.get_station_status"; com.alibaba.fastjson.JSONObject a2 = new com.alibaba.fastjson.JSONObject(); a2.put("limit", 20); st.arguments = a2; list.add(st); DsmlInvocation rgv = new DsmlInvocation(); rgv.name = "device.get_rgv_status"; com.alibaba.fastjson.JSONObject a3 = new com.alibaba.fastjson.JSONObject(); a3.put("limit", 20); rgv.arguments = a3; list.add(rgv); return list; } private void ensureStatusCoverage(List invs) { if (invs == null) return; java.util.Set names = new java.util.HashSet(); for (DsmlInvocation d : invs) { if (d != null && d.name != null) names.add(d.name); } if (!names.contains("device.get_crn_status") || !names.contains("device.get_station_status") || !names.contains("device.get_rgv_status")) { List defaults = buildDefaultStatusInvocations(); for (DsmlInvocation d : defaults) { if (!names.contains(d.name)) invs.add(d); } } } private boolean isConclusionText(String content) { if (content == null) return false; String c = content; int len = c.length(); boolean longEnough = len >= 200; boolean hasAllSections = c.contains("问题概述") && c.contains("可疑设备列表") && c.contains("可能原因") && c.contains("建议排查步骤") && c.contains("风险评估"); boolean hasExplicitConclusion = (c.contains("结论") || c.contains("诊断结果")) && longEnough; return hasAllSections || hasExplicitConclusion; } }