From 4898d942bd6e3c1119493cf0314b15f2bd54daf3 Mon Sep 17 00:00:00 2001
From: Junjie <fallin.jie@qq.com>
Date: 星期六, 03 一月 2026 22:06:22 +0800
Subject: [PATCH] #mcp

---
 src/main/java/com/zy/ai/service/WcsDiagnosisService.java |  395 ++++++++++++++++++++++++++++++++++++++++++++++++++-----
 1 files changed, 354 insertions(+), 41 deletions(-)

diff --git a/src/main/java/com/zy/ai/service/WcsDiagnosisService.java b/src/main/java/com/zy/ai/service/WcsDiagnosisService.java
index 1272f6a..189e57d 100644
--- a/src/main/java/com/zy/ai/service/WcsDiagnosisService.java
+++ b/src/main/java/com/zy/ai/service/WcsDiagnosisService.java
@@ -1,15 +1,20 @@
 package com.zy.ai.service;
 
 import com.alibaba.fastjson.JSON;
+import com.alibaba.fastjson.JSONObject;
 import com.zy.ai.entity.ChatCompletionRequest;
+import com.zy.ai.entity.ChatCompletionResponse;
 import com.zy.ai.entity.WcsDiagnosisRequest;
+import com.zy.ai.mcp.controller.McpController;
 import com.zy.ai.utils.AiPromptUtils;
 import com.zy.ai.utils.AiUtils;
 import com.zy.common.utils.RedisUtil;
 import com.zy.core.enums.RedisKeyType;
 import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.beans.factory.annotation.Value;
 import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
 import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
 import org.springframework.stereotype.Service;
 
 import java.util.ArrayList;
@@ -18,10 +23,13 @@
 
 @Service
 @RequiredArgsConstructor
+@Slf4j
 public class WcsDiagnosisService {
 
     private static final long CHAT_TTL_SECONDS = 7L * 24 * 3600;
 
+    @Value("${llm.platform}")
+    private String platform;
     @Autowired
     private LlmChatService llmChatService;
     @Autowired
@@ -30,31 +38,27 @@
     private AiPromptUtils aiPromptUtils;
     @Autowired
     private AiUtils aiUtils;
-
-    /**
-     * 閽堝鈥滅郴缁熶笉鎵ц浠诲姟 / 涓嶇煡閬撳摢涓澶囨病鍦ㄨ繍琛屸�濈殑閫氱敤 AI 璇婃柇
-     */
-    public String diagnose(WcsDiagnosisRequest request) {
-        List<ChatCompletionRequest.Message> messages = new ArrayList<>();
-
-        // 1. system锛氬畾涔変笓瀹惰韩浠� + 杈撳嚭缁撴瀯
-        ChatCompletionRequest.Message system = new ChatCompletionRequest.Message();
-        system.setRole("system");
-        system.setContent(aiPromptUtils.getAiDiagnosePrompt());
-        messages.add(system);
-
-        ChatCompletionRequest.Message user = new ChatCompletionRequest.Message();
-        user.setRole("user");
-        user.setContent(aiUtils.buildDiagnosisUserContent(request));
-        messages.add(user);
-
-        // 璋冪敤澶фā鍨�
-        return llmChatService.chat(messages, 0.2, 2048);
-    }
+    @Autowired(required = false)
+    private McpController mcpController;
+    @Autowired
+    private PythonService pythonService;
 
     public void diagnoseStream(WcsDiagnosisRequest request, SseEmitter emitter) {
         List<ChatCompletionRequest.Message> messages = new ArrayList<>();
 
+        ChatCompletionRequest.Message mcpSystem = new ChatCompletionRequest.Message();
+        mcpSystem.setRole("system");
+        mcpSystem.setContent(aiPromptUtils.getAiDiagnosePromptMcp());
+
+        ChatCompletionRequest.Message mcpUser = new ChatCompletionRequest.Message();
+        mcpUser.setRole("user");
+        mcpUser.setContent(aiUtils.buildDiagnosisUserContentMcp(request));
+
+        if (runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, 0.3, 2048, emitter, null)) {
+            return;
+        }
+
+        messages = new ArrayList<>();
         ChatCompletionRequest.Message system = new ChatCompletionRequest.Message();
         system.setRole("system");
         system.setContent(aiPromptUtils.getAiDiagnosePrompt());
@@ -65,7 +69,7 @@
         user.setContent(aiUtils.buildDiagnosisUserContent(request));
         messages.add(user);
 
-        llmChatService.chatStream(messages, 0.2, 2048, s -> {
+        llmChatService.chatStream(messages, 0.3, 2048, s -> {
             try {
                 String safe = s == null ? "" : s.replace("\r", "").replace("\n", "\\n");
                 if (!safe.isEmpty()) {
@@ -73,9 +77,16 @@
                 }
             } catch (Exception ignore) {}
         }, () -> {
-            try { emitter.complete(); } catch (Exception ignore) {}
+            try {
+                log.info("AI diagnose stream stopped: normal end");
+                emitter.complete();
+            } catch (Exception ignore) {}
         }, e -> {
-            try { emitter.completeWithError(e); } catch (Exception ignore) {}
+            try {
+                try { emitter.send(SseEmitter.event().data("銆怉I銆戣繍琛屽凡鍋滄锛堝紓甯革級")); } catch (Exception ignore) {}
+                log.error("AI diagnose stream stopped: error", e);
+                emitter.completeWithError(e);
+            } catch (Exception ignore) {}
         });
     }
 
@@ -84,12 +95,12 @@
                           String chatId,
                           boolean reset,
                           SseEmitter emitter) {
-        List<ChatCompletionRequest.Message> base = new ArrayList<>();
+        if (platform.equals("python")) {
+            pythonService.runPython(prompt, chatId, emitter);
+            return;
+        }
 
-        ChatCompletionRequest.Message system = new ChatCompletionRequest.Message();
-        system.setRole("system");
-        system.setContent(aiPromptUtils.getWcsSensorPrompt());
-        base.add(system);
+        List<ChatCompletionRequest.Message> messages = new ArrayList<>();
 
         List<ChatCompletionRequest.Message> history = null;
         String historyKey = null;
@@ -107,21 +118,11 @@
                     ChatCompletionRequest.Message m = convertToMessage(o);
                     if (m != null) history.add(m);
                 }
-                if (!history.isEmpty()) base.addAll(history);
+                if (!history.isEmpty()) messages.addAll(history);
             } else {
                 history = new ArrayList<>();
             }
         }
-
-        ChatCompletionRequest.Message contextMsg = new ChatCompletionRequest.Message();
-        contextMsg.setRole("user");
-        contextMsg.setContent(aiUtils.buildAskUserContent(request));
-        base.add(contextMsg);
-
-        ChatCompletionRequest.Message questionMsg = new ChatCompletionRequest.Message();
-        questionMsg.setRole("user");
-        questionMsg.setContent("銆愮敤鎴锋彁闂�慭n" + (prompt == null ? "" : prompt));
-        base.add(questionMsg);
 
         StringBuilder assistantBuffer = new StringBuilder();
         final String finalChatId = chatId;
@@ -129,7 +130,30 @@
         final String finalMetaKey = metaKey;
         final String finalPrompt = prompt;
 
-        llmChatService.chatStream(base, 0.2, 2048, s -> {
+        ChatCompletionRequest.Message mcpSystem = new ChatCompletionRequest.Message();
+        mcpSystem.setRole("system");
+        mcpSystem.setContent(aiPromptUtils.getWcsSensorPromptMcp());
+
+        ChatCompletionRequest.Message mcpUser = new ChatCompletionRequest.Message();
+        mcpUser.setRole("user");
+        mcpUser.setContent("銆愮敤鎴锋彁闂�慭n" + (prompt == null ? "" : prompt));
+
+        if (runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, 0.3, 2048, emitter, finalChatId)) {
+            return;
+        }
+
+        messages = new ArrayList<>();
+        ChatCompletionRequest.Message system = new ChatCompletionRequest.Message();
+        system.setRole("system");
+        system.setContent(aiPromptUtils.getWcsSensorPrompt());
+        messages.add(system);
+
+        ChatCompletionRequest.Message questionMsg = new ChatCompletionRequest.Message();
+        questionMsg.setRole("user");
+        questionMsg.setContent("銆愮敤鎴锋彁闂�慭n" + (prompt == null ? "" : prompt));
+        messages.add(questionMsg);
+
+        llmChatService.chatStream(messages, 0.3, 2048, s -> {
             try {
                 String safe = s == null ? "" : s.replace("\r", "").replace("\n", "\\n");
                 if (!safe.isEmpty()) {
@@ -231,5 +255,294 @@
         String p = prompt.replaceAll("\n", " ").trim();
         return p.length() > 20 ? p.substring(0, 20) : p;
     }
+ 
+    private boolean runMcpStreamingDiagnosis(List<ChatCompletionRequest.Message> baseMessages,
+                                             ChatCompletionRequest.Message systemPrompt,
+                                             ChatCompletionRequest.Message userQuestion,
+                                             Double temperature,
+                                             Integer maxTokens,
+                                             SseEmitter emitter,
+                                             String chatId) {
+        try {
+            if (mcpController == null) return false;
+            List<Object> tools = buildOpenAiTools();
+            if (tools.isEmpty()) return false;
 
+            baseMessages.add(systemPrompt);
+            baseMessages.add(userQuestion);
+ 
+            List<ChatCompletionRequest.Message> messages = new ArrayList<>(baseMessages.size() + 8);
+            messages.addAll(baseMessages);
+ 
+            sse(emitter, "<think>\\n姝e湪鍒濆鍖栬瘖鏂笌宸ュ叿鐜...\\n");
+
+            int maxRound = 10;
+            int i = 0;
+            while(true) {
+                sse(emitter, "\\n姝e湪鍒嗘瀽锛堢" + (i + 1) + "杞級...\\n");
+                ChatCompletionResponse resp = llmChatService.chatCompletion(messages, temperature, maxTokens, tools);
+                if (resp == null || resp.getChoices() == null || resp.getChoices().isEmpty() || resp.getChoices().get(0).getMessage() == null) {
+                    sse(emitter, "\\n鍒嗘瀽鍑洪敊锛屾鍦ㄥ洖閫�...\\n");
+                    return false;
+                }
+
+                ChatCompletionRequest.Message assistant = resp.getChoices().get(0).getMessage();
+                messages.add(assistant);
+                sse(emitter, assistant.getContent());
+
+                List<ChatCompletionRequest.ToolCall> toolCalls = assistant.getTool_calls();
+                if (toolCalls == null || toolCalls.isEmpty()) {
+                    break;
+                }
+
+                for (ChatCompletionRequest.ToolCall tc : toolCalls) {
+                    String toolName = tc != null && tc.getFunction() != null ? tc.getFunction().getName() : null;
+                    if (toolName == null || toolName.trim().isEmpty()) continue;
+                    sse(emitter, "\\n鍑嗗璋冪敤宸ュ叿锛�" + toolName + "\\n");
+                    JSONObject args = new JSONObject();
+                    if (tc.getFunction() != null && tc.getFunction().getArguments() != null && !tc.getFunction().getArguments().trim().isEmpty()) {
+                        try {
+                            args = JSON.parseObject(tc.getFunction().getArguments());
+                        } catch (Exception ignore) {
+                            args = new JSONObject();
+                            args.put("_raw", tc.getFunction().getArguments());
+                        }
+                    }
+                    Object output;
+                    try {
+                        output = mcpController.callTool(toolName, args);
+                    } catch (Exception e) {
+                        java.util.LinkedHashMap<String, Object> err = new java.util.LinkedHashMap<String, Object>();
+                        err.put("tool", toolName);
+                        err.put("error", e.getMessage());
+                        output = err;
+                    }
+                    sse(emitter, "\\n宸ュ叿杩斿洖锛屾鍦ㄧ户缁帹鐞�...\\n");
+                    ChatCompletionRequest.Message toolMsg = new ChatCompletionRequest.Message();
+                    toolMsg.setRole("tool");
+                    toolMsg.setTool_call_id(tc == null ? null : tc.getId());
+                    toolMsg.setContent(JSON.toJSONString(output));
+                    messages.add(toolMsg);
+                }
+
+                if(i++ >= maxRound) break;
+            }
+
+            sse(emitter, "\\n姝e湪鏍规嵁鏁版嵁杩涜鍒嗘瀽...\\n</think>\\n\\n");
+
+            ChatCompletionRequest.Message diagnosisMessage = new ChatCompletionRequest.Message();
+            diagnosisMessage.setRole("system");
+            diagnosisMessage.setContent("鏍规嵁浠ヤ笂淇℃伅杩涜鍒嗘瀽锛屽苟缁欏嚭瀹屾暣鐨勮瘖鏂粨璁恒��");
+            messages.add(diagnosisMessage);
+
+            StringBuilder assistantBuffer = new StringBuilder();
+            llmChatService.chatStreamWithTools(messages, temperature, maxTokens, tools, s -> {
+                try {
+                    String safe = s == null ? "" : s.replace("\r", "").replace("\n", "\\n");
+                    if (!safe.isEmpty()) {
+                        sse(emitter, safe);
+                        assistantBuffer.append(safe);
+                    }
+                } catch (Exception ignore) {}
+            }, () -> {
+                try {
+                    sse(emitter, "\\n\\n銆怉I銆戣繍琛屽凡鍋滄锛堟甯哥粨鏉燂級\\n\\n");
+                    log.info("AI MCP diagnose stopped: final end");
+                    emitter.complete();
+
+                    if (chatId != null) {
+                        String historyKey = RedisKeyType.AI_CHAT_HISTORY.key + chatId;
+                        String metaKey = RedisKeyType.AI_CHAT_META.key + chatId;
+
+                        ChatCompletionRequest.Message a = new ChatCompletionRequest.Message();
+                        a.setRole("assistant");
+                        a.setContent(assistantBuffer.toString());
+                        redisUtil.lSet(historyKey, userQuestion);
+                        redisUtil.lSet(historyKey, a);
+                        redisUtil.expire(historyKey, CHAT_TTL_SECONDS);
+                        Map<Object, Object> old = redisUtil.hmget(metaKey);
+                        Long createdAt = old != null && old.get("createdAt") != null ?
+                                (old.get("createdAt") instanceof Number ? ((Number) old.get("createdAt")).longValue() : Long.valueOf(String.valueOf(old.get("createdAt"))))
+                                : System.currentTimeMillis();
+                        Map<String, Object> meta = new java.util.HashMap<>();
+                        meta.put("chatId", chatId);
+                        meta.put("title", buildTitleFromPrompt(userQuestion.getContent()));
+                        meta.put("createdAt", createdAt);
+                        meta.put("updatedAt", System.currentTimeMillis());
+                        redisUtil.hmset(metaKey, meta, CHAT_TTL_SECONDS);
+                    }
+                } catch (Exception ignore) {}
+            }, e -> {
+                sse(emitter, "\\n\\n銆怉I銆戝垎鏋愬嚭閿欙紝姝e湪鍥為��...\\n\\n");
+            });
+            return true;
+        } catch (Exception e) {
+            try {
+                sse(emitter, "\\n\\n銆怉I銆戣繍琛屽凡鍋滄锛堝紓甯革級\\n\\n");
+                log.error("AI MCP diagnose stopped: error", e);
+                emitter.completeWithError(e);
+            } catch (Exception ignore) {}
+            return true;
+        }
+    }
+
+    private void sse(SseEmitter emitter, String data) {
+        if (data == null) return;
+        try {
+            emitter.send(SseEmitter.event().data(data));
+        } catch (Exception e) {
+            log.warn("SSE send failed", e);
+        }
+    }
+
+    private List<Object> buildOpenAiTools() {
+        if (mcpController == null) return java.util.Collections.emptyList();
+        List<Map<String, Object>> mcpTools = mcpController.listTools();
+        if (mcpTools == null || mcpTools.isEmpty()) return java.util.Collections.emptyList();
+
+        List<Object> tools = new ArrayList<>();
+        for (Map<String, Object> t : mcpTools) {
+            if (t == null) continue;
+            Object name = t.get("name");
+            if (name == null) continue;
+            Object inputSchema = t.get("inputSchema");
+            java.util.LinkedHashMap<String, Object> function = new java.util.LinkedHashMap<String, Object>();
+            function.put("name", String.valueOf(name));
+            Object desc = t.get("description");
+            if (desc != null) function.put("description", String.valueOf(desc));
+            function.put("parameters", inputSchema == null ? new java.util.LinkedHashMap<String, Object>() : inputSchema);
+
+            java.util.LinkedHashMap<String, Object> tool = new java.util.LinkedHashMap<String, Object>();
+            tool.put("type", "function");
+            tool.put("function", function);
+            tools.add(tool);
+        }
+        return tools;
+    }
+
+    private void sendLargeText(SseEmitter emitter, String text) {
+        if (text == null) return;
+        String safe = text.replace("\r", "").replace("\n", "\\n");
+        int chunkSize = 256;
+        int i = 0;
+        while (i < safe.length()) {
+            int end = Math.min(i + chunkSize, safe.length());
+            String part = safe.substring(i, end);
+            if (!part.isEmpty()) {
+                try { emitter.send(SseEmitter.event().data(part)); } catch (Exception ignore) {}
+            }
+            i = end;
+        }
+    }
+
+    private static final java.util.regex.Pattern DSML_INVOKE_PATTERN =
+            java.util.regex.Pattern.compile("<\\uFF5CDSML\\uFF5Cinvoke\\s+name=\\\"([^\\\"]+)\\\"[^>]*>([\\s\\S]*?)</\\uFF5CDSML\\uFF5Cinvoke>", java.util.regex.Pattern.MULTILINE);
+    private static final java.util.regex.Pattern JSON_OBJECT_PATTERN =
+            java.util.regex.Pattern.compile("\\{[\\s\\S]*\\}");
+    private static final java.util.regex.Pattern DSML_PARAM_PATTERN =
+            java.util.regex.Pattern.compile("<\\uFF5CDSML\\uFF5Cparameter\\s+name=\\\"([^\\\"]+)\\\"\\s*([^>]*)>([\\s\\S]*?)</\\uFF5CDSML\\uFF5Cparameter>", java.util.regex.Pattern.MULTILINE);
+
+    private java.util.List<DsmlInvocation> parseDsmlInvocations(String content) {
+        java.util.List<DsmlInvocation> list = new java.util.ArrayList<>();
+        if (content == null || content.isEmpty()) return list;
+        java.util.regex.Matcher m = DSML_INVOKE_PATTERN.matcher(content);
+        while (m.find()) {
+            String name = m.group(1);
+            String inner = m.group(2);
+            com.alibaba.fastjson.JSONObject args = null;
+            if (inner != null) {
+                java.util.regex.Matcher jm = JSON_OBJECT_PATTERN.matcher(inner);
+                if (jm.find()) {
+                    String json = jm.group();
+                    try { args = com.alibaba.fastjson.JSON.parseObject(json); } catch (Exception ignore) {}
+                }
+                java.util.regex.Matcher pm = DSML_PARAM_PATTERN.matcher(inner);
+                while (pm.find()) {
+                    if (args == null) args = new com.alibaba.fastjson.JSONObject();
+                    String pName = pm.group(1);
+                    String attr = pm.group(2);
+                    String valText = pm.group(3);
+                    boolean isString = attr != null && attr.toLowerCase().contains("string=\"true\"");
+                    String t = valText == null ? "" : valText.trim();
+                    if (isString) {
+                        args.put(pName, t);
+                    } else {
+                        if ("true".equalsIgnoreCase(t) || "false".equalsIgnoreCase(t)) {
+                            args.put(pName, Boolean.valueOf(t));
+                        } else {
+                            try {
+                                if (t.contains(".")) {
+                                    args.put(pName, Double.valueOf(t));
+                                } else {
+                                    args.put(pName, Long.valueOf(t));
+                                }
+                            } catch (Exception ex) {
+                                args.put(pName, t);
+                            }
+                        }
+                    }
+                }
+            }
+            DsmlInvocation inv = new DsmlInvocation();
+            inv.name = name;
+            inv.arguments = args;
+            list.add(inv);
+        }
+        return list;
+    }
+
+    private static class DsmlInvocation {
+        String name;
+        com.alibaba.fastjson.JSONObject arguments;
+    }
+
+    private List<DsmlInvocation> buildDefaultStatusInvocations() {
+        List<DsmlInvocation> list = new ArrayList<>();
+        DsmlInvocation crn = new DsmlInvocation();
+        crn.name = "device.get_crn_status";
+        com.alibaba.fastjson.JSONObject a1 = new com.alibaba.fastjson.JSONObject();
+        a1.put("limit", 20);
+        crn.arguments = a1;
+        list.add(crn);
+
+        DsmlInvocation st = new DsmlInvocation();
+        st.name = "device.get_station_status";
+        com.alibaba.fastjson.JSONObject a2 = new com.alibaba.fastjson.JSONObject();
+        a2.put("limit", 20);
+        st.arguments = a2;
+        list.add(st);
+
+        DsmlInvocation rgv = new DsmlInvocation();
+        rgv.name = "device.get_rgv_status";
+        com.alibaba.fastjson.JSONObject a3 = new com.alibaba.fastjson.JSONObject();
+        a3.put("limit", 20);
+        rgv.arguments = a3;
+        list.add(rgv);
+
+        return list;
+    }
+
+    private void ensureStatusCoverage(List<DsmlInvocation> invs) {
+        if (invs == null) return;
+        java.util.Set<String> names = new java.util.HashSet<String>();
+        for (DsmlInvocation d : invs) {
+            if (d != null && d.name != null) names.add(d.name);
+        }
+        if (!names.contains("device.get_crn_status") || !names.contains("device.get_station_status") || !names.contains("device.get_rgv_status")) {
+            List<DsmlInvocation> defaults = buildDefaultStatusInvocations();
+            for (DsmlInvocation d : defaults) {
+                if (!names.contains(d.name)) invs.add(d);
+            }
+        }
+    }
+
+    private boolean isConclusionText(String content) {
+        if (content == null) return false;
+        String c = content;
+        int len = c.length();
+        boolean longEnough = len >= 200;
+        boolean hasAllSections = c.contains("闂姒傝堪") && c.contains("鍙枒璁惧鍒楄〃") && c.contains("鍙兘鍘熷洜") && c.contains("寤鸿鎺掓煡姝ラ") && c.contains("椋庨櫓璇勪及");
+        boolean hasExplicitConclusion = (c.contains("缁撹") || c.contains("璇婃柇缁撴灉")) && longEnough;
+        return hasAllSections || hasExplicitConclusion;
+    }
 }

--
Gitblit v1.9.1