| | |
| | | package com.zy.ai.service; |
| | | |
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.zy.ai.entity.ChatCompletionRequest;
import com.zy.ai.entity.ChatCompletionResponse;
import com.zy.ai.entity.WcsDiagnosisRequest;
import com.zy.ai.mcp.controller.McpController;
import com.zy.ai.utils.AiPromptUtils;
import com.zy.ai.utils.AiUtils;
import com.zy.common.utils.RedisUtil;
import com.zy.core.enums.RedisKeyType;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
| | |
| | | |
| | | @Service |
| | | @RequiredArgsConstructor |
| | | @Slf4j |
| | | public class WcsDiagnosisService { |
| | | |
| | | private static final long CHAT_TTL_SECONDS = 7L * 24 * 3600; |
| | | |
| | | @Value("${llm.platform}") |
| | | private String platform; |
| | | @Autowired |
| | | private LlmChatService llmChatService; |
| | | @Autowired |
| | |
| | | private AiPromptUtils aiPromptUtils; |
| | | @Autowired |
| | | private AiUtils aiUtils; |
| | | |
| | | /** |
| | | * 针对“系统不执行任务 / 不知道哪个设备没在运行”的通用 AI 诊断 |
| | | */ |
| | | public String diagnose(WcsDiagnosisRequest request) { |
| | | List<ChatCompletionRequest.Message> messages = new ArrayList<>(); |
| | | |
| | | // 1. system:定义专家身份 + 输出结构 |
| | | ChatCompletionRequest.Message system = new ChatCompletionRequest.Message(); |
| | | system.setRole("system"); |
| | | system.setContent(aiPromptUtils.getAiDiagnosePrompt()); |
| | | messages.add(system); |
| | | |
| | | ChatCompletionRequest.Message user = new ChatCompletionRequest.Message(); |
| | | user.setRole("user"); |
| | | user.setContent(aiUtils.buildDiagnosisUserContent(request)); |
| | | messages.add(user); |
| | | |
| | | // 调用大模型 |
| | | return llmChatService.chat(messages, 0.2, 2048); |
| | | } |
| | | @Autowired(required = false) |
| | | private McpController mcpController; |
| | | @Autowired |
| | | private PythonService pythonService; |
| | | |
| | | public void diagnoseStream(WcsDiagnosisRequest request, SseEmitter emitter) { |
| | | List<ChatCompletionRequest.Message> messages = new ArrayList<>(); |
| | | |
| | | ChatCompletionRequest.Message mcpSystem = new ChatCompletionRequest.Message(); |
| | | mcpSystem.setRole("system"); |
| | | mcpSystem.setContent(aiPromptUtils.getAiDiagnosePromptMcp()); |
| | | |
| | | ChatCompletionRequest.Message mcpUser = new ChatCompletionRequest.Message(); |
| | | mcpUser.setRole("user"); |
| | | mcpUser.setContent(aiUtils.buildDiagnosisUserContentMcp(request)); |
| | | |
| | | if (runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, 0.3, 2048, emitter, null)) { |
| | | return; |
| | | } |
| | | |
| | | messages = new ArrayList<>(); |
| | | ChatCompletionRequest.Message system = new ChatCompletionRequest.Message(); |
| | | system.setRole("system"); |
| | | system.setContent(aiPromptUtils.getAiDiagnosePrompt()); |
| | |
| | | user.setContent(aiUtils.buildDiagnosisUserContent(request)); |
| | | messages.add(user); |
| | | |
| | | llmChatService.chatStream(messages, 0.2, 2048, s -> { |
| | | llmChatService.chatStream(messages, 0.3, 2048, s -> { |
| | | try { |
| | | String safe = s == null ? "" : s.replace("\r", "").replace("\n", "\\n"); |
| | | if (!safe.isEmpty()) { |
| | |
| | | } |
| | | } catch (Exception ignore) {} |
| | | }, () -> { |
| | | try { emitter.complete(); } catch (Exception ignore) {} |
| | | try { |
| | | log.info("AI diagnose stream stopped: normal end"); |
| | | emitter.complete(); |
| | | } catch (Exception ignore) {} |
| | | }, e -> { |
| | | try { emitter.completeWithError(e); } catch (Exception ignore) {} |
| | | try { |
| | | try { emitter.send(SseEmitter.event().data("【AI】运行已停止(异常)")); } catch (Exception ignore) {} |
| | | log.error("AI diagnose stream stopped: error", e); |
| | | emitter.completeWithError(e); |
| | | } catch (Exception ignore) {} |
| | | }); |
| | | } |
| | | |
| | |
| | | String chatId, |
| | | boolean reset, |
| | | SseEmitter emitter) { |
| | | List<ChatCompletionRequest.Message> base = new ArrayList<>(); |
| | | if (platform.equals("python")) { |
| | | pythonService.runPython(prompt, chatId, emitter); |
| | | return; |
| | | } |
| | | |
| | | ChatCompletionRequest.Message system = new ChatCompletionRequest.Message(); |
| | | system.setRole("system"); |
| | | system.setContent(aiPromptUtils.getWcsSensorPrompt()); |
| | | base.add(system); |
| | | List<ChatCompletionRequest.Message> messages = new ArrayList<>(); |
| | | |
| | | List<ChatCompletionRequest.Message> history = null; |
| | | String historyKey = null; |
| | |
| | | ChatCompletionRequest.Message m = convertToMessage(o); |
| | | if (m != null) history.add(m); |
| | | } |
| | | if (!history.isEmpty()) base.addAll(history); |
| | | if (!history.isEmpty()) messages.addAll(history); |
| | | } else { |
| | | history = new ArrayList<>(); |
| | | } |
| | | } |
| | | |
| | | ChatCompletionRequest.Message contextMsg = new ChatCompletionRequest.Message(); |
| | | contextMsg.setRole("user"); |
| | | contextMsg.setContent(aiUtils.buildAskUserContent(request)); |
| | | base.add(contextMsg); |
| | | |
| | | ChatCompletionRequest.Message questionMsg = new ChatCompletionRequest.Message(); |
| | | questionMsg.setRole("user"); |
| | | questionMsg.setContent("【用户提问】\n" + (prompt == null ? "" : prompt)); |
| | | base.add(questionMsg); |
| | | |
| | | StringBuilder assistantBuffer = new StringBuilder(); |
| | | final String finalChatId = chatId; |
| | |
| | | final String finalMetaKey = metaKey; |
| | | final String finalPrompt = prompt; |
| | | |
| | | llmChatService.chatStream(base, 0.2, 2048, s -> { |
| | | ChatCompletionRequest.Message mcpSystem = new ChatCompletionRequest.Message(); |
| | | mcpSystem.setRole("system"); |
| | | mcpSystem.setContent(aiPromptUtils.getWcsSensorPromptMcp()); |
| | | |
| | | ChatCompletionRequest.Message mcpUser = new ChatCompletionRequest.Message(); |
| | | mcpUser.setRole("user"); |
| | | mcpUser.setContent("【用户提问】\n" + (prompt == null ? "" : prompt)); |
| | | |
| | | if (runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, 0.3, 2048, emitter, finalChatId)) { |
| | | return; |
| | | } |
| | | |
| | | messages = new ArrayList<>(); |
| | | ChatCompletionRequest.Message system = new ChatCompletionRequest.Message(); |
| | | system.setRole("system"); |
| | | system.setContent(aiPromptUtils.getWcsSensorPrompt()); |
| | | messages.add(system); |
| | | |
| | | ChatCompletionRequest.Message questionMsg = new ChatCompletionRequest.Message(); |
| | | questionMsg.setRole("user"); |
| | | questionMsg.setContent("【用户提问】\n" + (prompt == null ? "" : prompt)); |
| | | messages.add(questionMsg); |
| | | |
| | | llmChatService.chatStream(messages, 0.3, 2048, s -> { |
| | | try { |
| | | String safe = s == null ? "" : s.replace("\r", "").replace("\n", "\\n"); |
| | | if (!safe.isEmpty()) { |
| | |
| | | String p = prompt.replaceAll("\n", " ").trim(); |
| | | return p.length() > 20 ? p.substring(0, 20) : p; |
| | | } |
| | | |
| | | private boolean runMcpStreamingDiagnosis(List<ChatCompletionRequest.Message> baseMessages, |
| | | ChatCompletionRequest.Message systemPrompt, |
| | | ChatCompletionRequest.Message userQuestion, |
| | | Double temperature, |
| | | Integer maxTokens, |
| | | SseEmitter emitter, |
| | | String chatId) { |
| | | try { |
| | | if (mcpController == null) return false; |
| | | List<Object> tools = buildOpenAiTools(); |
| | | if (tools.isEmpty()) return false; |
| | | |
| | | baseMessages.add(systemPrompt); |
| | | baseMessages.add(userQuestion); |
| | | |
| | | List<ChatCompletionRequest.Message> messages = new ArrayList<>(baseMessages.size() + 8); |
| | | messages.addAll(baseMessages); |
| | | |
| | | sse(emitter, "<think>\\n正在初始化诊断与工具环境...\\n"); |
| | | |
| | | int maxRound = 10; |
| | | int i = 0; |
| | | while(true) { |
| | | sse(emitter, "\\n正在分析(第" + (i + 1) + "轮)...\\n"); |
| | | ChatCompletionResponse resp = llmChatService.chatCompletion(messages, temperature, maxTokens, tools); |
| | | if (resp == null || resp.getChoices() == null || resp.getChoices().isEmpty() || resp.getChoices().get(0).getMessage() == null) { |
| | | sse(emitter, "\\n分析出错,正在回退...\\n"); |
| | | return false; |
| | | } |
| | | |
| | | ChatCompletionRequest.Message assistant = resp.getChoices().get(0).getMessage(); |
| | | messages.add(assistant); |
| | | sse(emitter, assistant.getContent()); |
| | | |
| | | List<ChatCompletionRequest.ToolCall> toolCalls = assistant.getTool_calls(); |
| | | if (toolCalls == null || toolCalls.isEmpty()) { |
| | | break; |
| | | } |
| | | |
| | | for (ChatCompletionRequest.ToolCall tc : toolCalls) { |
| | | String toolName = tc != null && tc.getFunction() != null ? tc.getFunction().getName() : null; |
| | | if (toolName == null || toolName.trim().isEmpty()) continue; |
| | | sse(emitter, "\\n准备调用工具:" + toolName + "\\n"); |
| | | JSONObject args = new JSONObject(); |
| | | if (tc.getFunction() != null && tc.getFunction().getArguments() != null && !tc.getFunction().getArguments().trim().isEmpty()) { |
| | | try { |
| | | args = JSON.parseObject(tc.getFunction().getArguments()); |
| | | } catch (Exception ignore) { |
| | | args = new JSONObject(); |
| | | args.put("_raw", tc.getFunction().getArguments()); |
| | | } |
| | | } |
| | | Object output; |
| | | try { |
| | | output = mcpController.callTool(toolName, args); |
| | | } catch (Exception e) { |
| | | java.util.LinkedHashMap<String, Object> err = new java.util.LinkedHashMap<String, Object>(); |
| | | err.put("tool", toolName); |
| | | err.put("error", e.getMessage()); |
| | | output = err; |
| | | } |
| | | sse(emitter, "\\n工具返回,正在继续推理...\\n"); |
| | | ChatCompletionRequest.Message toolMsg = new ChatCompletionRequest.Message(); |
| | | toolMsg.setRole("tool"); |
| | | toolMsg.setTool_call_id(tc == null ? null : tc.getId()); |
| | | toolMsg.setContent(JSON.toJSONString(output)); |
| | | messages.add(toolMsg); |
| | | } |
| | | |
| | | if(i++ >= maxRound) break; |
| | | } |
| | | |
| | | sse(emitter, "\\n正在根据数据进行分析...\\n</think>\\n\\n"); |
| | | |
| | | ChatCompletionRequest.Message diagnosisMessage = new ChatCompletionRequest.Message(); |
| | | diagnosisMessage.setRole("system"); |
| | | diagnosisMessage.setContent("根据以上信息进行分析,并给出完整的诊断结论。"); |
| | | messages.add(diagnosisMessage); |
| | | |
| | | StringBuilder assistantBuffer = new StringBuilder(); |
| | | llmChatService.chatStreamWithTools(messages, temperature, maxTokens, tools, s -> { |
| | | try { |
| | | String safe = s == null ? "" : s.replace("\r", "").replace("\n", "\\n"); |
| | | if (!safe.isEmpty()) { |
| | | sse(emitter, safe); |
| | | assistantBuffer.append(safe); |
| | | } |
| | | } catch (Exception ignore) {} |
| | | }, () -> { |
| | | try { |
| | | sse(emitter, "\\n\\n【AI】运行已停止(正常结束)\\n\\n"); |
| | | log.info("AI MCP diagnose stopped: final end"); |
| | | emitter.complete(); |
| | | |
| | | if (chatId != null) { |
| | | String historyKey = RedisKeyType.AI_CHAT_HISTORY.key + chatId; |
| | | String metaKey = RedisKeyType.AI_CHAT_META.key + chatId; |
| | | |
| | | ChatCompletionRequest.Message a = new ChatCompletionRequest.Message(); |
| | | a.setRole("assistant"); |
| | | a.setContent(assistantBuffer.toString()); |
| | | redisUtil.lSet(historyKey, userQuestion); |
| | | redisUtil.lSet(historyKey, a); |
| | | redisUtil.expire(historyKey, CHAT_TTL_SECONDS); |
| | | Map<Object, Object> old = redisUtil.hmget(metaKey); |
| | | Long createdAt = old != null && old.get("createdAt") != null ? |
| | | (old.get("createdAt") instanceof Number ? ((Number) old.get("createdAt")).longValue() : Long.valueOf(String.valueOf(old.get("createdAt")))) |
| | | : System.currentTimeMillis(); |
| | | Map<String, Object> meta = new java.util.HashMap<>(); |
| | | meta.put("chatId", chatId); |
| | | meta.put("title", buildTitleFromPrompt(userQuestion.getContent())); |
| | | meta.put("createdAt", createdAt); |
| | | meta.put("updatedAt", System.currentTimeMillis()); |
| | | redisUtil.hmset(metaKey, meta, CHAT_TTL_SECONDS); |
| | | } |
| | | } catch (Exception ignore) {} |
| | | }, e -> { |
| | | sse(emitter, "\\n\\n【AI】分析出错,正在回退...\\n\\n"); |
| | | }); |
| | | return true; |
| | | } catch (Exception e) { |
| | | try { |
| | | sse(emitter, "\\n\\n【AI】运行已停止(异常)\\n\\n"); |
| | | log.error("AI MCP diagnose stopped: error", e); |
| | | emitter.completeWithError(e); |
| | | } catch (Exception ignore) {} |
| | | return true; |
| | | } |
| | | } |
| | | |
| | | private void sse(SseEmitter emitter, String data) { |
| | | if (data == null) return; |
| | | try { |
| | | emitter.send(SseEmitter.event().data(data)); |
| | | } catch (Exception e) { |
| | | log.warn("SSE send failed", e); |
| | | } |
| | | } |
| | | |
| | | private List<Object> buildOpenAiTools() { |
| | | if (mcpController == null) return java.util.Collections.emptyList(); |
| | | List<Map<String, Object>> mcpTools = mcpController.listTools(); |
| | | if (mcpTools == null || mcpTools.isEmpty()) return java.util.Collections.emptyList(); |
| | | |
| | | List<Object> tools = new ArrayList<>(); |
| | | for (Map<String, Object> t : mcpTools) { |
| | | if (t == null) continue; |
| | | Object name = t.get("name"); |
| | | if (name == null) continue; |
| | | Object inputSchema = t.get("inputSchema"); |
| | | java.util.LinkedHashMap<String, Object> function = new java.util.LinkedHashMap<String, Object>(); |
| | | function.put("name", String.valueOf(name)); |
| | | Object desc = t.get("description"); |
| | | if (desc != null) function.put("description", String.valueOf(desc)); |
| | | function.put("parameters", inputSchema == null ? new java.util.LinkedHashMap<String, Object>() : inputSchema); |
| | | |
| | | java.util.LinkedHashMap<String, Object> tool = new java.util.LinkedHashMap<String, Object>(); |
| | | tool.put("type", "function"); |
| | | tool.put("function", function); |
| | | tools.add(tool); |
| | | } |
| | | return tools; |
| | | } |
| | | |
| | | private void sendLargeText(SseEmitter emitter, String text) { |
| | | if (text == null) return; |
| | | String safe = text.replace("\r", "").replace("\n", "\\n"); |
| | | int chunkSize = 256; |
| | | int i = 0; |
| | | while (i < safe.length()) { |
| | | int end = Math.min(i + chunkSize, safe.length()); |
| | | String part = safe.substring(i, end); |
| | | if (!part.isEmpty()) { |
| | | try { emitter.send(SseEmitter.event().data(part)); } catch (Exception ignore) {} |
| | | } |
| | | i = end; |
| | | } |
| | | } |
| | | |
| | | private static final java.util.regex.Pattern DSML_INVOKE_PATTERN = |
| | | java.util.regex.Pattern.compile("<\\uFF5CDSML\\uFF5Cinvoke\\s+name=\\\"([^\\\"]+)\\\"[^>]*>([\\s\\S]*?)</\\uFF5CDSML\\uFF5Cinvoke>", java.util.regex.Pattern.MULTILINE); |
| | | private static final java.util.regex.Pattern JSON_OBJECT_PATTERN = |
| | | java.util.regex.Pattern.compile("\\{[\\s\\S]*\\}"); |
| | | private static final java.util.regex.Pattern DSML_PARAM_PATTERN = |
| | | java.util.regex.Pattern.compile("<\\uFF5CDSML\\uFF5Cparameter\\s+name=\\\"([^\\\"]+)\\\"\\s*([^>]*)>([\\s\\S]*?)</\\uFF5CDSML\\uFF5Cparameter>", java.util.regex.Pattern.MULTILINE); |
| | | |
| | | private java.util.List<DsmlInvocation> parseDsmlInvocations(String content) { |
| | | java.util.List<DsmlInvocation> list = new java.util.ArrayList<>(); |
| | | if (content == null || content.isEmpty()) return list; |
| | | java.util.regex.Matcher m = DSML_INVOKE_PATTERN.matcher(content); |
| | | while (m.find()) { |
| | | String name = m.group(1); |
| | | String inner = m.group(2); |
| | | com.alibaba.fastjson.JSONObject args = null; |
| | | if (inner != null) { |
| | | java.util.regex.Matcher jm = JSON_OBJECT_PATTERN.matcher(inner); |
| | | if (jm.find()) { |
| | | String json = jm.group(); |
| | | try { args = com.alibaba.fastjson.JSON.parseObject(json); } catch (Exception ignore) {} |
| | | } |
| | | java.util.regex.Matcher pm = DSML_PARAM_PATTERN.matcher(inner); |
| | | while (pm.find()) { |
| | | if (args == null) args = new com.alibaba.fastjson.JSONObject(); |
| | | String pName = pm.group(1); |
| | | String attr = pm.group(2); |
| | | String valText = pm.group(3); |
| | | boolean isString = attr != null && attr.toLowerCase().contains("string=\"true\""); |
| | | String t = valText == null ? "" : valText.trim(); |
| | | if (isString) { |
| | | args.put(pName, t); |
| | | } else { |
| | | if ("true".equalsIgnoreCase(t) || "false".equalsIgnoreCase(t)) { |
| | | args.put(pName, Boolean.valueOf(t)); |
| | | } else { |
| | | try { |
| | | if (t.contains(".")) { |
| | | args.put(pName, Double.valueOf(t)); |
| | | } else { |
| | | args.put(pName, Long.valueOf(t)); |
| | | } |
| | | } catch (Exception ex) { |
| | | args.put(pName, t); |
| | | } |
| | | } |
| | | } |
| | | } |
| | | } |
| | | DsmlInvocation inv = new DsmlInvocation(); |
| | | inv.name = name; |
| | | inv.arguments = args; |
| | | list.add(inv); |
| | | } |
| | | return list; |
| | | } |
| | | |
| | | private static class DsmlInvocation { |
| | | String name; |
| | | com.alibaba.fastjson.JSONObject arguments; |
| | | } |
| | | |
| | | private List<DsmlInvocation> buildDefaultStatusInvocations() { |
| | | List<DsmlInvocation> list = new ArrayList<>(); |
| | | DsmlInvocation crn = new DsmlInvocation(); |
| | | crn.name = "device.get_crn_status"; |
| | | com.alibaba.fastjson.JSONObject a1 = new com.alibaba.fastjson.JSONObject(); |
| | | a1.put("limit", 20); |
| | | crn.arguments = a1; |
| | | list.add(crn); |
| | | |
| | | DsmlInvocation st = new DsmlInvocation(); |
| | | st.name = "device.get_station_status"; |
| | | com.alibaba.fastjson.JSONObject a2 = new com.alibaba.fastjson.JSONObject(); |
| | | a2.put("limit", 20); |
| | | st.arguments = a2; |
| | | list.add(st); |
| | | |
| | | DsmlInvocation rgv = new DsmlInvocation(); |
| | | rgv.name = "device.get_rgv_status"; |
| | | com.alibaba.fastjson.JSONObject a3 = new com.alibaba.fastjson.JSONObject(); |
| | | a3.put("limit", 20); |
| | | rgv.arguments = a3; |
| | | list.add(rgv); |
| | | |
| | | return list; |
| | | } |
| | | |
| | | private void ensureStatusCoverage(List<DsmlInvocation> invs) { |
| | | if (invs == null) return; |
| | | java.util.Set<String> names = new java.util.HashSet<String>(); |
| | | for (DsmlInvocation d : invs) { |
| | | if (d != null && d.name != null) names.add(d.name); |
| | | } |
| | | if (!names.contains("device.get_crn_status") || !names.contains("device.get_station_status") || !names.contains("device.get_rgv_status")) { |
| | | List<DsmlInvocation> defaults = buildDefaultStatusInvocations(); |
| | | for (DsmlInvocation d : defaults) { |
| | | if (!names.contains(d.name)) invs.add(d); |
| | | } |
| | | } |
| | | } |
| | | |
| | | private boolean isConclusionText(String content) { |
| | | if (content == null) return false; |
| | | String c = content; |
| | | int len = c.length(); |
| | | boolean longEnough = len >= 200; |
| | | boolean hasAllSections = c.contains("问题概述") && c.contains("可疑设备列表") && c.contains("可能原因") && c.contains("建议排查步骤") && c.contains("风险评估"); |
| | | boolean hasExplicitConclusion = (c.contains("结论") || c.contains("诊断结果")) && longEnough; |
| | | return hasAllSections || hasExplicitConclusion; |
| | | } |
| | | } |