#
Junjie
4 天以前 0c1110daa59bf77ddcff2704641280f417158c10
src/main/java/com/zy/ai/service/WcsDiagnosisService.java
@@ -2,16 +2,17 @@
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.zy.ai.entity.AiPromptTemplate;
import com.zy.ai.entity.ChatCompletionRequest;
import com.zy.ai.entity.ChatCompletionResponse;
import com.zy.ai.entity.WcsDiagnosisRequest;
import com.zy.ai.mcp.controller.McpController;
import com.zy.ai.utils.AiPromptUtils;
import com.zy.ai.enums.AiPromptScene;
import com.zy.ai.mcp.service.SpringAiMcpToolManager;
import com.zy.ai.service.AiPromptTemplateService;
import com.zy.ai.utils.AiUtils;
import com.zy.common.utils.RedisUtil;
import com.zy.core.enums.RedisKeyType;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
@@ -28,78 +29,36 @@
    // Lifetime applied to the Redis chat-history list and chat-meta hash: 7 days, in seconds.
    private static final long CHAT_TTL_SECONDS = 7L * 24 * 3600;
    // LLM platform selector read from the `llm.platform` property; in what appears to be the
    // pre-change askStream, the value "python" routes the request to pythonService.
    @Value("${llm.platform}")
    private String platform;
    @Autowired
    private LlmChatService llmChatService;
    // Used to persist chat history (list) and chat metadata (hash) with CHAT_TTL_SECONDS.
    @Autowired
    private RedisUtil redisUtil;
    @Autowired
    private AiPromptUtils aiPromptUtils;
    @Autowired
    private AiUtils aiUtils;
    // Optional bean (required = false); the pre-change MCP path returns early when it is null.
    // NOTE(review): this field appears to be removed by the diff and superseded by mcpToolManager.
    @Autowired(required = false)
    private McpController mcpController;
    @Autowired
    private PythonService pythonService;
    // NOTE(review): as rendered, this field carries no @Autowired of its own. That is only
    // correct if the @Autowired preceding the pythonService line is an unchanged context line
    // whose removed neighbour is pythonService. If the field is not injected, mcpToolManager
    // stays null and runMcpStreamingDiagnosis takes its "Spring AI MCP tool manager is
    // unavailable" failure branch on every call. Confirm against the real file.
    private SpringAiMcpToolManager mcpToolManager;
    // Resolves the published prompt template for an AiPromptScene code.
    @Autowired
    private AiPromptTemplateService aiPromptTemplateService;
    /**
     * Streams an AI diagnosis for {@code request} to {@code emitter} as SSE events.
     *
     * <p>NOTE(review): this text is a unified diff rendered without +/- markers. The removed and
     * the added lines of the same commit are interleaved, so the method as shown does not compile.
     * Judging from the duplicated statements, the post-change version appears to consist of:
     * <ul>
     *   <li>the prompt-template lookup;</li>
     *   <li>the template-based system prompt;</li>
     *   <li>the MCP user message;</li>
     *   <li>the final 8-argument {@code runMcpStreamingDiagnosis(...)} call.</li>
     * </ul>
     * The hard-coded MCP prompt, the boolean MCP attempt and the plain
     * {@code llmChatService.chatStream} fallback appear to be the removed lines. Confirm against
     * the real commit.
     *
     * <p>NOTE(review): {@code promptTemplate.getContent()} is dereferenced unconditionally, while
     * {@code runMcpStreamingDiagnosis} guards {@code promptTemplate != null}. If
     * {@code resolvePublished} can return {@code null}, this throws before any SSE event is sent,
     * and this method does not complete the emitter.
     *
     * @param request diagnosis input, rendered into the user message by {@code aiUtils}
     * @param emitter SSE sink that receives the streamed output
     */
    public void diagnoseStream(WcsDiagnosisRequest request, SseEmitter emitter) {
        List<ChatCompletionRequest.Message> messages = new ArrayList<>();
        // Published prompt template for the DIAGNOSE_STREAM scene.
        AiPromptTemplate promptTemplate = aiPromptTemplateService.resolvePublished(AiPromptScene.DIAGNOSE_STREAM.getCode());
        ChatCompletionRequest.Message mcpSystem = new ChatCompletionRequest.Message();
        mcpSystem.setRole("system");
        // NOTE(review): the two setContent calls below are alternatives from the two sides of the
        // diff (hard-coded MCP prompt vs. template content); only one exists in each version.
        mcpSystem.setContent(aiPromptUtils.getAiDiagnosePromptMcp());
        mcpSystem.setContent(promptTemplate.getContent());
        ChatCompletionRequest.Message mcpUser = new ChatCompletionRequest.Message();
        mcpUser.setRole("user");
        mcpUser.setContent(aiUtils.buildDiagnosisUserContentMcp(request));
        // Appears to be the removed flow: try MCP tool-calling first and stop if it handled the
        // request; otherwise fall through to the plain streaming chat below.
        if (runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, 0.3, 2048, emitter, null)) {
            return;
        }
        // Plain (non-MCP) fallback: fresh message list with the non-MCP system prompt and user content.
        messages = new ArrayList<>();
        ChatCompletionRequest.Message system = new ChatCompletionRequest.Message();
        system.setRole("system");
        system.setContent(aiPromptUtils.getAiDiagnosePrompt());
        messages.add(system);
        ChatCompletionRequest.Message user = new ChatCompletionRequest.Message();
        user.setRole("user");
        user.setContent(aiUtils.buildDiagnosisUserContent(request));
        messages.add(user);
        llmChatService.chatStream(messages, 0.3, 2048, s -> {
            // Per-chunk callback: drop CR, escape LF as a literal backslash-n (presumably so each
            // chunk stays a single-line SSE data field), skip empty chunks; send failures are ignored.
            try {
                String safe = s == null ? "" : s.replace("\r", "").replace("\n", "\\n");
                if (!safe.isEmpty()) {
                    emitter.send(SseEmitter.event().data(safe));
                }
            } catch (Exception ignore) {}
        }, () -> {
            // Normal end of stream: log and complete the SSE emitter.
            try {
                log.info("AI diagnose stream stopped: normal end");
                emitter.complete();
            } catch (Exception ignore) {}
        }, e -> {
            // Error: best-effort notice to the client ("[AI] run stopped (exception)"),
            // then log and complete the emitter with the error.
            try {
                try { emitter.send(SseEmitter.event().data("【AI】运行已停止(异常)")); } catch (Exception ignore) {}
                log.error("AI diagnose stream stopped: error", e);
                emitter.completeWithError(e);
            } catch (Exception ignore) {}
        });
        // Appears to be the added flow: a single delegation to the template-aware MCP overload.
        runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, promptTemplate, 0.3, 2048, emitter, null);
    }
    public void askStream(WcsDiagnosisRequest request,
                          String prompt,
    public void askStream(String prompt,
                          String chatId,
                          boolean reset,
                          SseEmitter emitter) {
        if (platform.equals("python")) {
            pythonService.runPython(prompt, chatId, emitter);
            return;
        }
        List<ChatCompletionRequest.Message> messages = new ArrayList<>();
        List<ChatCompletionRequest.Message> history = null;
@@ -124,71 +83,18 @@
            }
        }
        StringBuilder assistantBuffer = new StringBuilder();
        final String finalChatId = chatId;
        final String finalHistoryKey = historyKey;
        final String finalMetaKey = metaKey;
        final String finalPrompt = prompt;
        AiPromptTemplate promptTemplate = aiPromptTemplateService.resolvePublished(AiPromptScene.SENSOR_CHAT.getCode());
        ChatCompletionRequest.Message mcpSystem = new ChatCompletionRequest.Message();
        mcpSystem.setRole("system");
        mcpSystem.setContent(aiPromptUtils.getWcsSensorPromptMcp());
        mcpSystem.setContent(promptTemplate.getContent());
        ChatCompletionRequest.Message mcpUser = new ChatCompletionRequest.Message();
        mcpUser.setRole("user");
        mcpUser.setContent("【用户提问】\n" + (prompt == null ? "" : prompt));
        if (runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, 0.3, 2048, emitter, finalChatId)) {
            return;
        }
        messages = new ArrayList<>();
        ChatCompletionRequest.Message system = new ChatCompletionRequest.Message();
        system.setRole("system");
        system.setContent(aiPromptUtils.getWcsSensorPrompt());
        messages.add(system);
        ChatCompletionRequest.Message questionMsg = new ChatCompletionRequest.Message();
        questionMsg.setRole("user");
        questionMsg.setContent("【用户提问】\n" + (prompt == null ? "" : prompt));
        messages.add(questionMsg);
        llmChatService.chatStream(messages, 0.3, 2048, s -> {
            try {
                String safe = s == null ? "" : s.replace("\r", "").replace("\n", "\\n");
                if (!safe.isEmpty()) {
                    emitter.send(SseEmitter.event().data(safe));
                    assistantBuffer.append(s);
                }
            } catch (Exception ignore) {}
        }, () -> {
            try {
                if (finalChatId != null && !finalChatId.isEmpty()) {
                    ChatCompletionRequest.Message q = new ChatCompletionRequest.Message();
                    q.setRole("user");
                    q.setContent(finalPrompt == null ? "" : finalPrompt);
                    ChatCompletionRequest.Message a = new ChatCompletionRequest.Message();
                    a.setRole("assistant");
                    a.setContent(assistantBuffer.toString());
                    redisUtil.lSet(finalHistoryKey, q);
                    redisUtil.lSet(finalHistoryKey, a);
                    redisUtil.expire(finalHistoryKey, CHAT_TTL_SECONDS);
                    Map<Object, Object> old = redisUtil.hmget(finalMetaKey);
                    Long createdAt = old != null && old.get("createdAt") != null ?
                            (old.get("createdAt") instanceof Number ? ((Number) old.get("createdAt")).longValue() : Long.valueOf(String.valueOf(old.get("createdAt"))))
                            : System.currentTimeMillis();
                    Map<String, Object> meta = new java.util.HashMap<>();
                    meta.put("chatId", finalChatId);
                    meta.put("title", buildTitleFromPrompt(finalPrompt));
                    meta.put("createdAt", createdAt);
                    meta.put("updatedAt", System.currentTimeMillis());
                    redisUtil.hmset(finalMetaKey, meta, CHAT_TTL_SECONDS);
                }
                emitter.complete();
            } catch (Exception ignore) {}
        }, e -> {
            try { emitter.completeWithError(e); } catch (Exception ignore) {}
        });
        runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, promptTemplate, 0.3, 2048, emitter, finalChatId);
    }
    public List<Map<String, Object>> listChats() {
@@ -256,17 +162,22 @@
        return p.length() > 20 ? p.substring(0, 20) : p;
    }
 
    private boolean runMcpStreamingDiagnosis(List<ChatCompletionRequest.Message> baseMessages,
                                             ChatCompletionRequest.Message systemPrompt,
                                             ChatCompletionRequest.Message userQuestion,
                                             Double temperature,
                                             Integer maxTokens,
                                             SseEmitter emitter,
                                             String chatId) {
    private void runMcpStreamingDiagnosis(List<ChatCompletionRequest.Message> baseMessages,
                                          ChatCompletionRequest.Message systemPrompt,
                                          ChatCompletionRequest.Message userQuestion,
                                          AiPromptTemplate promptTemplate,
                                          Double temperature,
                                          Integer maxTokens,
                                          SseEmitter emitter,
                                          String chatId) {
        try {
            if (mcpController == null) return false;
            List<Object> tools = buildOpenAiTools();
            if (tools.isEmpty()) return false;
            if (mcpToolManager == null) {
                throw new IllegalStateException("Spring AI MCP tool manager is unavailable");
            }
            List<Object> tools = mcpToolManager.buildOpenAiTools();
            if (tools.isEmpty()) {
                throw new IllegalStateException("No MCP tools registered");
            }
            baseMessages.add(systemPrompt);
            baseMessages.add(userQuestion);
@@ -282,8 +193,7 @@
                sse(emitter, "\\n正在分析(第" + (i + 1) + "轮)...\\n");
                ChatCompletionResponse resp = llmChatService.chatCompletion(messages, temperature, maxTokens, tools);
                if (resp == null || resp.getChoices() == null || resp.getChoices().isEmpty() || resp.getChoices().get(0).getMessage() == null) {
                    sse(emitter, "\\n分析出错,正在回退...\\n");
                    return false;
                    throw new IllegalStateException("LLM returned empty response");
                }
                ChatCompletionRequest.Message assistant = resp.getChoices().get(0).getMessage();
@@ -310,7 +220,7 @@
                    }
                    Object output;
                    try {
                        output = mcpController.callTool(toolName, args);
                        output = mcpToolManager.callTool(toolName, args);
                    } catch (Exception e) {
                        java.util.LinkedHashMap<String, Object> err = new java.util.LinkedHashMap<String, Object>();
                        err.put("tool", toolName);
@@ -367,22 +277,30 @@
                        Map<String, Object> meta = new java.util.HashMap<>();
                        meta.put("chatId", chatId);
                        meta.put("title", buildTitleFromPrompt(userQuestion.getContent()));
                        if (promptTemplate != null) {
                            meta.put("promptTemplateId", promptTemplate.getId());
                            meta.put("promptSceneCode", promptTemplate.getSceneCode());
                            meta.put("promptVersion", promptTemplate.getVersion());
                            meta.put("promptName", promptTemplate.getName());
                        }
                        meta.put("createdAt", createdAt);
                        meta.put("updatedAt", System.currentTimeMillis());
                        redisUtil.hmset(metaKey, meta, CHAT_TTL_SECONDS);
                    }
                } catch (Exception ignore) {}
            }, e -> {
                sse(emitter, "\\n\\n【AI】分析出错,正在回退...\\n\\n");
                try {
                    sse(emitter, "\\n\\n【AI】分析出错,运行已停止(异常)\\n\\n");
                    log.error("AI MCP diagnose stopped: stream error", e);
                    emitter.complete();
                } catch (Exception ignore) {}
            });
            return true;
        } catch (Exception e) {
            try {
                sse(emitter, "\\n\\n【AI】运行已停止(异常)\\n\\n");
                log.error("AI MCP diagnose stopped: error", e);
                emitter.completeWithError(e);
                emitter.complete();
            } catch (Exception ignore) {}
            return true;
        }
    }
@@ -393,31 +311,6 @@
        } catch (Exception e) {
            log.warn("SSE send failed", e);
        }
    }
    private List<Object> buildOpenAiTools() {
        if (mcpController == null) return java.util.Collections.emptyList();
        List<Map<String, Object>> mcpTools = mcpController.listTools();
        if (mcpTools == null || mcpTools.isEmpty()) return java.util.Collections.emptyList();
        List<Object> tools = new ArrayList<>();
        for (Map<String, Object> t : mcpTools) {
            if (t == null) continue;
            Object name = t.get("name");
            if (name == null) continue;
            Object inputSchema = t.get("inputSchema");
            java.util.LinkedHashMap<String, Object> function = new java.util.LinkedHashMap<String, Object>();
            function.put("name", String.valueOf(name));
            Object desc = t.get("description");
            if (desc != null) function.put("description", String.valueOf(desc));
            function.put("parameters", inputSchema == null ? new java.util.LinkedHashMap<String, Object>() : inputSchema);
            java.util.LinkedHashMap<String, Object> tool = new java.util.LinkedHashMap<String, Object>();
            tool.put("type", "function");
            tool.put("function", function);
            tools.add(tool);
        }
        return tools;
    }
    private void sendLargeText(SseEmitter emitter, String text) {