#
Junjie
16 小时以前 83af5944a32527fd8aa83537dd840d428af7f577
src/main/java/com/zy/ai/service/WcsDiagnosisService.java
@@ -8,10 +8,7 @@
import com.zy.ai.entity.WcsDiagnosisRequest;
import com.zy.ai.enums.AiPromptScene;
import com.zy.ai.mcp.service.SpringAiMcpToolManager;
import com.zy.ai.service.AiPromptTemplateService;
import com.zy.ai.utils.AiUtils;
import com.zy.common.utils.RedisUtil;
import com.zy.core.enums.RedisKeyType;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import lombok.RequiredArgsConstructor;
@@ -27,21 +24,25 @@
@Slf4j
public class WcsDiagnosisService {
    // Expiry, in seconds, passed to Redis when chat history and chat metadata are written:
    // 7 days expressed as 7 * 24 * 3600.
    private static final long CHAT_TTL_SECONDS = 7L * 24 * 3600;
    // LLM client; chatCompletion(...) is called once per analysis round.
    @Autowired
    private LlmChatService llmChatService;
    // Redis helper used by the Redis-backed chat history/metadata code paths in this class.
    @Autowired
    private RedisUtil redisUtil;
    // Builds the user-message content for a diagnosis request.
    @Autowired
    private AiUtils aiUtils;
    // MCP tool manager. NOTE(review): its call sites are outside the visible part of the file.
    @Autowired
    private SpringAiMcpToolManager mcpToolManager;
    // Resolves the published prompt template for a given scene code.
    @Autowired
    private AiPromptTemplateService aiPromptTemplateService;
    // Chat store used to list, read, delete and save conversations.
    @Autowired
    private AiChatStoreService aiChatStoreService;
    public void diagnoseStream(WcsDiagnosisRequest request, SseEmitter emitter) {
    public void diagnoseStream(WcsDiagnosisRequest request,
                               String chatId,
                               boolean reset,
                               SseEmitter emitter) {
        List<ChatCompletionRequest.Message> messages = new ArrayList<>();
        if (chatId != null && !chatId.isEmpty() && reset) {
            aiChatStoreService.deleteChat(chatId);
        }
        AiPromptTemplate promptTemplate = aiPromptTemplateService.resolvePublished(AiPromptScene.DIAGNOSE_STREAM.getCode());
        ChatCompletionRequest.Message mcpSystem = new ChatCompletionRequest.Message();
@@ -52,7 +53,11 @@
        mcpUser.setRole("user");
        mcpUser.setContent(aiUtils.buildDiagnosisUserContentMcp(request));
        runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, promptTemplate, 0.3, 2048, emitter, null);
        ChatCompletionRequest.Message storedUser = new ChatCompletionRequest.Message();
        storedUser.setRole("user");
        storedUser.setContent(buildDiagnoseDisplayPrompt(request));
        runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, storedUser, promptTemplate, 0.3, 2048, emitter, chatId);
    }
    public void askStream(String prompt,
@@ -61,25 +66,13 @@
                          SseEmitter emitter) {
        List<ChatCompletionRequest.Message> messages = new ArrayList<>();
        List<ChatCompletionRequest.Message> history = null;
        String historyKey = null;
        String metaKey = null;
        if (chatId != null && !chatId.isEmpty()) {
            historyKey = RedisKeyType.AI_CHAT_HISTORY.key + chatId;
            metaKey = RedisKeyType.AI_CHAT_META.key + chatId;
            if (reset) {
                redisUtil.del(historyKey, metaKey);
                aiChatStoreService.deleteChat(chatId);
            }
            List<Object> stored = redisUtil.lGet(historyKey, 0, -1);
            if (stored != null && !stored.isEmpty()) {
                history = new ArrayList<>(stored.size());
                for (Object o : stored) {
                    ChatCompletionRequest.Message m = convertToMessage(o);
                    if (m != null) history.add(m);
                }
                if (!history.isEmpty()) messages.addAll(history);
            } else {
                history = new ArrayList<>();
            List<ChatCompletionRequest.Message> history = aiChatStoreService.getChatHistory(chatId);
            if (history != null && !history.isEmpty()) {
                messages.addAll(history);
            }
        }
@@ -92,68 +85,21 @@
        ChatCompletionRequest.Message mcpUser = new ChatCompletionRequest.Message();
        mcpUser.setRole("user");
        mcpUser.setContent("【用户提问】\n" + (prompt == null ? "" : prompt));
        mcpUser.setContent(prompt == null ? "" : prompt);
        runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, promptTemplate, 0.3, 2048, emitter, finalChatId);
        runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, mcpUser, promptTemplate, 0.3, 2048, emitter, finalChatId);
    }
    /**
     * Lists the known chat sessions, one map per session.
     * <p>
     * NOTE(review): this method currently contains two bodies back to back. The first scans
     * Redis for chat-metadata keys and builds the result by hand; the second is a one-line
     * delegation to {@code aiChatStoreService.listChats()}. The second {@code return} directly
     * follows {@code return resp;} and is therefore unreachable as written. This looks like
     * both sides of a change shown together - confirm which body is intended and drop the other.
     *
     * @return one entry per chat session
     */
    public List<Map<String, Object>> listChats() {
        // Enumerate metadata keys by prefix. NOTE(review): the meaning of the 1000 argument
        // (scan count vs. result limit) is defined by RedisUtil.scanKeys - confirm there.
        java.util.Set<String> keys = redisUtil.scanKeys(RedisKeyType.AI_CHAT_META.key, 1000);
        List<Map<String, Object>> resp = new ArrayList<>();
        if (keys != null) {
            for (String key : keys) {
                Map<Object, Object> m = redisUtil.hmget(key);
                if (m != null && !m.isEmpty()) {
                    // Copy the hash into a String-keyed map so it can be returned as-is.
                    java.util.HashMap<String, Object> item = new java.util.HashMap<>();
                    for (Map.Entry<Object, Object> e : m.entrySet()) {
                        item.put(String.valueOf(e.getKey()), e.getValue());
                    }
                    // String.valueOf turns a missing "chatId" into the literal text "null".
                    String chatId = String.valueOf(item.get("chatId"));
                    String historyKey = RedisKeyType.AI_CHAT_HISTORY.key + chatId;
                    // Attach the length of the matching history list as "size".
                    item.put("size", redisUtil.lGetListSize(historyKey));
                    resp.add(item);
                }
            }
        }
        return resp;
        // NOTE(review): unreachable after the return above - see the method comment.
        return aiChatStoreService.listChats();
    }
    /**
     * Deletes a chat session identified by {@code chatId}.
     * <p>
     * NOTE(review): two implementations are present one after the other. The first removes the
     * Redis history and metadata keys and returns {@code true}; the second delegates to
     * {@code aiChatStoreService.deleteChat(chatId)}. The delegation follows {@code return true;}
     * and is unreachable as written - confirm which variant is intended and remove the other.
     *
     * @param chatId identifier of the chat to delete
     * @return {@code false} when {@code chatId} is {@code null} or empty
     */
    public boolean deleteChat(String chatId) {
        if (chatId == null || chatId.isEmpty()) return false;
        String historyKey = RedisKeyType.AI_CHAT_HISTORY.key + chatId;
        String metaKey = RedisKeyType.AI_CHAT_META.key + chatId;
        redisUtil.del(historyKey, metaKey);
        return true;
        // NOTE(review): unreachable after the return above - see the method comment.
        return aiChatStoreService.deleteChat(chatId);
    }
    /**
     * Returns the stored messages of one chat, read from the Redis history list.
     * <p>
     * Entries that {@code convertToMessage} cannot map (it returns {@code null}) are skipped.
     * <p>
     * NOTE(review): a statement {@code return aiChatStoreService.getChatHistory(chatId);}
     * appears further down, inside {@code convertToMessage}. It looks like the intended
     * replacement for this Redis-based body - confirm which implementation should remain.
     *
     * @param chatId identifier of the chat
     * @return the stored messages; an empty list when {@code chatId} is {@code null} or empty
     */
    public List<ChatCompletionRequest.Message> getChatHistory(String chatId) {
        if (chatId == null || chatId.isEmpty()) return java.util.Collections.emptyList();
        String historyKey = RedisKeyType.AI_CHAT_HISTORY.key + chatId;
        // Range 0..-1 is passed to request the list contents from the first to the last element.
        List<Object> stored = redisUtil.lGet(historyKey, 0, -1);
        List<ChatCompletionRequest.Message> result = new ArrayList<>();
        if (stored != null) {
            for (Object o : stored) {
                ChatCompletionRequest.Message m = convertToMessage(o);
                if (m != null) result.add(m);
            }
        }
        return result;
    }
    /**
     * Maps a stored history element to a chat message.
     * <p>
     * A value that already is a {@code ChatCompletionRequest.Message} is returned unchanged.
     * A {@code Map} is converted by reading its {@code "role"} and {@code "content"} entries
     * (each stringified, or left {@code null} when absent). Anything else yields {@code null}.
     * <p>
     * NOTE(review): the last statement returns {@code aiChatStoreService.getChatHistory(chatId)}.
     * It follows {@code return null;} (unreachable), refers to {@code chatId}, which is not a
     * parameter of this method, and yields a list rather than a single message. It looks like
     * it belongs to {@code getChatHistory} - confirm and move or remove it.
     *
     * @param o element read from the stored history
     * @return the corresponding message, or {@code null} when the element type is not supported
     */
    private ChatCompletionRequest.Message convertToMessage(Object o) {
        if (o instanceof ChatCompletionRequest.Message) {
            return (ChatCompletionRequest.Message) o;
        }
        if (o instanceof Map) {
            Map<?, ?> map = (Map<?, ?>) o;
            ChatCompletionRequest.Message m = new ChatCompletionRequest.Message();
            Object role = map.get("role");
            Object content = map.get("content");
            m.setRole(role == null ? null : String.valueOf(role));
            m.setContent(content == null ? null : String.valueOf(content));
            return m;
        }
        return null;
        // NOTE(review): misplaced and unreachable - see the method comment.
        return aiChatStoreService.getChatHistory(chatId);
    }
    private String buildTitleFromPrompt(String prompt) {
@@ -165,6 +111,7 @@
    private void runMcpStreamingDiagnosis(List<ChatCompletionRequest.Message> baseMessages,
                                          ChatCompletionRequest.Message systemPrompt,
                                          ChatCompletionRequest.Message userQuestion,
                                          ChatCompletionRequest.Message storedUserQuestion,
                                          AiPromptTemplate promptTemplate,
                                          Double temperature,
                                          Integer maxTokens,
@@ -178,6 +125,8 @@
            if (tools.isEmpty()) {
                throw new IllegalStateException("No MCP tools registered");
            }
            AgentUsageStats usageStats = new AgentUsageStats();
            StringBuilder reasoningBuffer = new StringBuilder();
            baseMessages.add(systemPrompt);
            baseMessages.add(userQuestion);
@@ -186,19 +135,23 @@
            messages.addAll(baseMessages);
 
            sse(emitter, "<think>\\n正在初始化诊断与工具环境...\\n");
            appendReasoning(reasoningBuffer, "正在初始化诊断与工具环境...\n");
            int maxRound = 10;
            int i = 0;
            while(true) {
                sse(emitter, "\\n正在分析(第" + (i + 1) + "轮)...\\n");
                appendReasoning(reasoningBuffer, "\n正在分析(第" + (i + 1) + "轮)...\n");
                ChatCompletionResponse resp = llmChatService.chatCompletion(messages, temperature, maxTokens, tools);
                if (resp == null || resp.getChoices() == null || resp.getChoices().isEmpty() || resp.getChoices().get(0).getMessage() == null) {
                    throw new IllegalStateException("LLM returned empty response");
                }
                usageStats.add(resp.getUsage());
                ChatCompletionRequest.Message assistant = resp.getChoices().get(0).getMessage();
                messages.add(assistant);
                sse(emitter, assistant.getContent());
                appendReasoning(reasoningBuffer, assistant == null ? null : assistant.getContent());
                List<ChatCompletionRequest.ToolCall> toolCalls = assistant.getTool_calls();
                if (toolCalls == null || toolCalls.isEmpty()) {
@@ -209,6 +162,7 @@
                    String toolName = tc != null && tc.getFunction() != null ? tc.getFunction().getName() : null;
                    if (toolName == null || toolName.trim().isEmpty()) continue;
                    sse(emitter, "\\n准备调用工具:" + toolName + "\\n");
                    appendReasoning(reasoningBuffer, "\n准备调用工具:" + toolName + "\n");
                    JSONObject args = new JSONObject();
                    if (tc.getFunction() != null && tc.getFunction().getArguments() != null && !tc.getFunction().getArguments().trim().isEmpty()) {
                        try {
@@ -228,6 +182,7 @@
                        output = err;
                    }
                    sse(emitter, "\\n工具返回,正在继续推理...\\n");
                    appendReasoning(reasoningBuffer, "\n工具返回,正在继续推理...\n");
                    ChatCompletionRequest.Message toolMsg = new ChatCompletionRequest.Message();
                    toolMsg.setRole("tool");
                    toolMsg.setTool_call_id(tc == null ? null : tc.getId());
@@ -239,6 +194,7 @@
            }
            sse(emitter, "\\n正在根据数据进行分析...\\n</think>\\n\\n");
            appendReasoning(reasoningBuffer, "\n正在根据数据进行分析...\n");
            ChatCompletionRequest.Message diagnosisMessage = new ChatCompletionRequest.Message();
            diagnosisMessage.setRole("system");
@@ -256,45 +212,35 @@
                } catch (Exception ignore) {}
            }, () -> {
                try {
                    emitTokenUsage(emitter, usageStats);
                    sse(emitter, "\\n\\n【AI】运行已停止(正常结束)\\n\\n");
                    log.info("AI MCP diagnose stopped: final end");
                    emitter.complete();
                    if (chatId != null) {
                        String historyKey = RedisKeyType.AI_CHAT_HISTORY.key + chatId;
                        String metaKey = RedisKeyType.AI_CHAT_META.key + chatId;
                        ChatCompletionRequest.Message a = new ChatCompletionRequest.Message();
                        a.setRole("assistant");
                        a.setContent(assistantBuffer.toString());
                        redisUtil.lSet(historyKey, userQuestion);
                        redisUtil.lSet(historyKey, a);
                        redisUtil.expire(historyKey, CHAT_TTL_SECONDS);
                        Map<Object, Object> old = redisUtil.hmget(metaKey);
                        Long createdAt = old != null && old.get("createdAt") != null ?
                                (old.get("createdAt") instanceof Number ? ((Number) old.get("createdAt")).longValue() : Long.valueOf(String.valueOf(old.get("createdAt"))))
                                : System.currentTimeMillis();
                        Map<String, Object> meta = new java.util.HashMap<>();
                        meta.put("chatId", chatId);
                        meta.put("title", buildTitleFromPrompt(userQuestion.getContent()));
                        if (promptTemplate != null) {
                            meta.put("promptTemplateId", promptTemplate.getId());
                            meta.put("promptSceneCode", promptTemplate.getSceneCode());
                            meta.put("promptVersion", promptTemplate.getVersion());
                            meta.put("promptName", promptTemplate.getName());
                        }
                        meta.put("createdAt", createdAt);
                        meta.put("updatedAt", System.currentTimeMillis());
                        redisUtil.hmset(metaKey, meta, CHAT_TTL_SECONDS);
                        a.setReasoningContent(reasoningBuffer.toString());
                        aiChatStoreService.saveConversation(chatId,
                                buildTitleFromPrompt(storedUserQuestion == null ? null : storedUserQuestion.getContent()),
                                storedUserQuestion == null ? userQuestion : storedUserQuestion,
                                a,
                                promptTemplate,
                                usageStats.getPromptTokens(),
                                usageStats.getCompletionTokens(),
                                usageStats.getTotalTokens(),
                                usageStats.getLlmCallCount());
                    }
                } catch (Exception ignore) {}
            }, e -> {
                try {
                    emitTokenUsage(emitter, usageStats);
                    sse(emitter, "\\n\\n【AI】分析出错,运行已停止(异常)\\n\\n");
                    log.error("AI MCP diagnose stopped: stream error", e);
                    emitter.complete();
                } catch (Exception ignore) {}
            });
            }, usageStats::add);
        } catch (Exception e) {
            try {
                sse(emitter, "\\n\\n【AI】运行已停止(异常)\\n\\n");
@@ -311,6 +257,35 @@
        } catch (Exception e) {
            log.warn("SSE send failed", e);
        }
    }
    private void emitTokenUsage(SseEmitter emitter, AgentUsageStats usageStats) {
        if (emitter == null || usageStats == null || usageStats.getTotalTokens() <= 0) {
            return;
        }
        try {
            emitter.send(SseEmitter.event()
                    .name("token_usage")
                    .data(JSON.toJSONString(buildTokenUsagePayload(usageStats))));
        } catch (Exception e) {
            log.warn("SSE token usage send failed", e);
        }
    }
    private Map<String, Object> buildTokenUsagePayload(AgentUsageStats usageStats) {
        java.util.LinkedHashMap<String, Object> payload = new java.util.LinkedHashMap<>();
        payload.put("promptTokens", usageStats.getPromptTokens());
        payload.put("completionTokens", usageStats.getCompletionTokens());
        payload.put("totalTokens", usageStats.getTotalTokens());
        payload.put("llmCallCount", usageStats.getLlmCallCount());
        return payload;
    }
    private void appendReasoning(StringBuilder reasoningBuffer, String text) {
        if (reasoningBuffer == null || text == null || text.isEmpty()) {
            return;
        }
        reasoningBuffer.append(text);
    }
    private void sendLargeText(SseEmitter emitter, String text) {
@@ -429,6 +404,46 @@
        }
    }
    private String buildDiagnoseDisplayPrompt(WcsDiagnosisRequest request) {
        if (request == null || request.getAlarmMessage() == null || request.getAlarmMessage().trim().isEmpty()) {
            return "对当前系统进行巡检";
        }
        return request.getAlarmMessage().trim();
    }
    private static class AgentUsageStats {
        private long promptTokens;
        private long completionTokens;
        private long totalTokens;
        private int llmCallCount;
        void add(ChatCompletionResponse.Usage usage) {
            if (usage == null) {
                return;
            }
            promptTokens += usage.getPromptTokens() == null ? 0L : usage.getPromptTokens();
            completionTokens += usage.getCompletionTokens() == null ? 0L : usage.getCompletionTokens();
            totalTokens += usage.getTotalTokens() == null ? 0L : usage.getTotalTokens();
            llmCallCount++;
        }
        long getPromptTokens() {
            return promptTokens;
        }
        long getCompletionTokens() {
            return completionTokens;
        }
        long getTotalTokens() {
            return totalTokens;
        }
        int getLlmCallCount() {
            return llmCallCount;
        }
    }
    private boolean isConclusionText(String content) {
        if (content == null) return false;
        String c = content;