#
Junjie
2 天以前 338f3b81425ab96d8c856909a775124af5365e3c
src/main/java/com/zy/ai/service/WcsDiagnosisService.java
@@ -35,8 +35,14 @@
    @Autowired
    private AiChatStoreService aiChatStoreService;
    public void diagnoseStream(WcsDiagnosisRequest request, SseEmitter emitter) {
    public void diagnoseStream(WcsDiagnosisRequest request,
                               String chatId,
                               boolean reset,
                               SseEmitter emitter) {
        List<ChatCompletionRequest.Message> messages = new ArrayList<>();
        if (chatId != null && !chatId.isEmpty() && reset) {
            aiChatStoreService.deleteChat(chatId);
        }
        AiPromptTemplate promptTemplate = aiPromptTemplateService.resolvePublished(AiPromptScene.DIAGNOSE_STREAM.getCode());
        ChatCompletionRequest.Message mcpSystem = new ChatCompletionRequest.Message();
@@ -47,7 +53,11 @@
        mcpUser.setRole("user");
        mcpUser.setContent(aiUtils.buildDiagnosisUserContentMcp(request));
        runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, promptTemplate, 0.3, 2048, emitter, null);
        ChatCompletionRequest.Message storedUser = new ChatCompletionRequest.Message();
        storedUser.setRole("user");
        storedUser.setContent(buildDiagnoseDisplayPrompt(request));
        runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, storedUser, promptTemplate, 0.3, 2048, emitter, chatId);
    }
    public void askStream(String prompt,
@@ -77,7 +87,7 @@
        mcpUser.setRole("user");
        mcpUser.setContent(prompt == null ? "" : prompt);
        runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, promptTemplate, 0.3, 2048, emitter, finalChatId);
        runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, mcpUser, promptTemplate, 0.3, 2048, emitter, finalChatId);
    }
    public List<Map<String, Object>> listChats() {
@@ -101,6 +111,7 @@
    private void runMcpStreamingDiagnosis(List<ChatCompletionRequest.Message> baseMessages,
                                          ChatCompletionRequest.Message systemPrompt,
                                          ChatCompletionRequest.Message userQuestion,
                                          ChatCompletionRequest.Message storedUserQuestion,
                                          AiPromptTemplate promptTemplate,
                                          Double temperature,
                                          Integer maxTokens,
@@ -115,6 +126,7 @@
                throw new IllegalStateException("No MCP tools registered");
            }
            AgentUsageStats usageStats = new AgentUsageStats();
            StringBuilder reasoningBuffer = new StringBuilder();
            baseMessages.add(systemPrompt);
            baseMessages.add(userQuestion);
@@ -123,11 +135,13 @@
            messages.addAll(baseMessages);
 
            sse(emitter, "<think>\\n正在初始化诊断与工具环境...\\n");
            appendReasoning(reasoningBuffer, "正在初始化诊断与工具环境...\n");
            int maxRound = 10;
            int i = 0;
            while(true) {
                sse(emitter, "\\n正在分析(第" + (i + 1) + "轮)...\\n");
                appendReasoning(reasoningBuffer, "\n正在分析(第" + (i + 1) + "轮)...\n");
                ChatCompletionResponse resp = llmChatService.chatCompletion(messages, temperature, maxTokens, tools);
                if (resp == null || resp.getChoices() == null || resp.getChoices().isEmpty() || resp.getChoices().get(0).getMessage() == null) {
                    throw new IllegalStateException("LLM returned empty response");
@@ -137,6 +151,7 @@
                ChatCompletionRequest.Message assistant = resp.getChoices().get(0).getMessage();
                messages.add(assistant);
                sse(emitter, assistant.getContent());
                appendReasoning(reasoningBuffer, assistant == null ? null : assistant.getContent());
                List<ChatCompletionRequest.ToolCall> toolCalls = assistant.getTool_calls();
                if (toolCalls == null || toolCalls.isEmpty()) {
@@ -147,6 +162,7 @@
                    String toolName = tc != null && tc.getFunction() != null ? tc.getFunction().getName() : null;
                    if (toolName == null || toolName.trim().isEmpty()) continue;
                    sse(emitter, "\\n准备调用工具:" + toolName + "\\n");
                    appendReasoning(reasoningBuffer, "\n准备调用工具:" + toolName + "\n");
                    JSONObject args = new JSONObject();
                    if (tc.getFunction() != null && tc.getFunction().getArguments() != null && !tc.getFunction().getArguments().trim().isEmpty()) {
                        try {
@@ -166,6 +182,7 @@
                        output = err;
                    }
                    sse(emitter, "\\n工具返回,正在继续推理...\\n");
                    appendReasoning(reasoningBuffer, "\n工具返回,正在继续推理...\n");
                    ChatCompletionRequest.Message toolMsg = new ChatCompletionRequest.Message();
                    toolMsg.setRole("tool");
                    toolMsg.setTool_call_id(tc == null ? null : tc.getId());
@@ -177,6 +194,7 @@
            }
            sse(emitter, "\\n正在根据数据进行分析...\\n</think>\\n\\n");
            appendReasoning(reasoningBuffer, "\n正在根据数据进行分析...\n");
            ChatCompletionRequest.Message diagnosisMessage = new ChatCompletionRequest.Message();
            diagnosisMessage.setRole("system");
@@ -203,9 +221,10 @@
                        ChatCompletionRequest.Message a = new ChatCompletionRequest.Message();
                        a.setRole("assistant");
                        a.setContent(assistantBuffer.toString());
                        a.setReasoningContent(reasoningBuffer.toString());
                        aiChatStoreService.saveConversation(chatId,
                                buildTitleFromPrompt(userQuestion.getContent()),
                                userQuestion,
                                buildTitleFromPrompt(storedUserQuestion == null ? null : storedUserQuestion.getContent()),
                                storedUserQuestion == null ? userQuestion : storedUserQuestion,
                                a,
                                promptTemplate,
                                usageStats.getPromptTokens(),
@@ -260,6 +279,13 @@
        payload.put("totalTokens", usageStats.getTotalTokens());
        payload.put("llmCallCount", usageStats.getLlmCallCount());
        return payload;
    }
    private void appendReasoning(StringBuilder reasoningBuffer, String text) {
        if (reasoningBuffer == null || text == null || text.isEmpty()) {
            return;
        }
        reasoningBuffer.append(text);
    }
    private void sendLargeText(SseEmitter emitter, String text) {
@@ -378,6 +404,13 @@
        }
    }
    private String buildDiagnoseDisplayPrompt(WcsDiagnosisRequest request) {
        if (request == null || request.getAlarmMessage() == null || request.getAlarmMessage().trim().isEmpty()) {
            return "对当前系统进行巡检";
        }
        return request.getAlarmMessage().trim();
    }
    private static class AgentUsageStats {
        private long promptTokens;
        private long completionTokens;