package com.zy.ai.service;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.zy.ai.entity.AiPromptTemplate;
import com.zy.ai.entity.ChatCompletionRequest;
import com.zy.ai.entity.ChatCompletionResponse;
import com.zy.ai.entity.WcsDiagnosisRequest;
import com.zy.ai.enums.AiPromptScene;
import com.zy.ai.mcp.service.SpringAiMcpToolManager;
import com.zy.ai.utils.AiUtils;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Runs LLM-backed diagnosis and chat sessions and streams progress and results
 * to the client over an {@link SseEmitter}.
 *
 * <p>Each request first goes through a bounded number of non-streaming LLM rounds
 * in which the model may request MCP tool calls; the tool results are appended to
 * the conversation. A final streaming LLM call then produces the answer that is
 * forwarded chunk by chunk to the client.
 *
 * <p>NOTE(review): all collaborators are non-final {@code @Autowired} fields, so
 * {@code @RequiredArgsConstructor} has no required fields to put in a constructor.
 * The annotation is kept to leave the wiring untouched.
 */
@Service
@RequiredArgsConstructor
@Slf4j
public class WcsDiagnosisService {

    /** Upper bound on LLM rounds that are allowed to request tool calls. */
    private static final int MAX_TOOL_ROUNDS = 10;

    /** Sampling temperature used by both public entry points. */
    private static final double DEFAULT_TEMPERATURE = 0.3;

    /** Completion token limit used by both public entry points. */
    private static final int DEFAULT_MAX_TOKENS = 2048;

    /** Maximum number of characters kept when deriving a chat title from a prompt. */
    private static final int TITLE_MAX_LENGTH = 20;

    /** Title used when the prompt has no usable text. */
    private static final String UNTITLED_CHAT = "未命名会话";

    /** Chunk size, in characters, used by {@link #sendLargeText(SseEmitter, String)}. */
    private static final int SSE_CHUNK_SIZE = 256;

    /**
     * Matches one DSML "invoke" element (the tag delimiter is U+FF5C, a full-width
     * vertical bar). Group 1 is the tool name, group 2 the element body.
     *
     * <p>The closing tag is required as a terminator: with a reluctant body group at
     * the very end of the pattern, the body would always match the empty string.
     * NOTE(review): the closing-tag shape is inferred from the opening tag; confirm
     * against real model output.
     */
    private static final Pattern DSML_INVOKE_PATTERN = Pattern.compile(
            "<\\uFF5CDSML\\uFF5Cinvoke\\s+name=\\\"([^\\\"]+)\\\"[^>]*>([\\s\\S]*?)</\\uFF5CDSML\\uFF5Cinvoke>",
            Pattern.MULTILINE);

    /** Matches the widest brace-delimited span, used to find an inline JSON argument object. */
    private static final Pattern JSON_OBJECT_PATTERN = Pattern.compile("\\{[\\s\\S]*\\}");

    /**
     * Matches one DSML "parameter" element. Group 1 is the parameter name, group 2
     * the remaining attributes, group 3 the value text. The closing tag terminates
     * the reluctant value group for the same reason as in {@link #DSML_INVOKE_PATTERN}.
     */
    private static final Pattern DSML_PARAM_PATTERN = Pattern.compile(
            "<\\uFF5CDSML\\uFF5Cparameter\\s+name=\\\"([^\\\"]+)\\\"\\s*([^>]*)>([\\s\\S]*?)</\\uFF5CDSML\\uFF5Cparameter>",
            Pattern.MULTILINE);

    @Autowired
    private LlmChatService llmChatService;
    @Autowired
    private AiUtils aiUtils;
    @Autowired
    private SpringAiMcpToolManager mcpToolManager;
    @Autowired
    private AiPromptTemplateService aiPromptTemplateService;
    @Autowired
    private AiChatStoreService aiChatStoreService;

    /**
     * Starts a one-shot diagnosis for the given request and streams it to the client.
     * The conversation is not persisted (no chat id is passed on).
     *
     * @param request diagnosis input, turned into the user message by {@code AiUtils}
     * @param emitter SSE channel that receives progress text and the final answer
     */
    public void diagnoseStream(WcsDiagnosisRequest request, SseEmitter emitter) {
        List<ChatCompletionRequest.Message> messages = new ArrayList<>();
        AiPromptTemplate promptTemplate =
                aiPromptTemplateService.resolvePublished(AiPromptScene.DIAGNOSE_STREAM.getCode());

        ChatCompletionRequest.Message mcpSystem = newMessage("system", promptTemplate.getContent());
        ChatCompletionRequest.Message mcpUser =
                newMessage("user", aiUtils.buildDiagnosisUserContentMcp(request));

        runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, promptTemplate,
                DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS, emitter, null);
    }

    /**
     * Answers a free-form prompt, optionally continuing a stored chat.
     *
     * @param prompt  user question; {@code null} is sent as an empty string
     * @param chatId  chat to continue; when non-empty its stored history is prepended
     * @param reset   when {@code true} and {@code chatId} is non-empty, the stored chat
     *                is deleted before its history is read
     * @param emitter SSE channel that receives progress text and the final answer
     */
    public void askStream(String prompt, String chatId, boolean reset, SseEmitter emitter) {
        List<ChatCompletionRequest.Message> messages = new ArrayList<>();
        if (chatId != null && !chatId.isEmpty()) {
            if (reset) {
                aiChatStoreService.deleteChat(chatId);
            }
            List<ChatCompletionRequest.Message> history = aiChatStoreService.getChatHistory(chatId);
            if (history != null && !history.isEmpty()) {
                messages.addAll(history);
            }
        }

        AiPromptTemplate promptTemplate =
                aiPromptTemplateService.resolvePublished(AiPromptScene.SENSOR_CHAT.getCode());
        ChatCompletionRequest.Message mcpSystem = newMessage("system", promptTemplate.getContent());
        ChatCompletionRequest.Message mcpUser = newMessage("user", prompt == null ? "" : prompt);

        runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, promptTemplate,
                DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS, emitter, chatId);
    }

    /**
     * Lists the stored chats by delegating to the chat store.
     *
     * <p>NOTE(review): the element type is assumed to be {@code Map<String, Object>};
     * confirm against {@code AiChatStoreService.listChats()}.
     *
     * @return whatever the chat store returns
     */
    public List<Map<String, Object>> listChats() {
        return aiChatStoreService.listChats();
    }

    /**
     * Deletes a stored chat by delegating to the chat store.
     *
     * @param chatId id of the chat to delete
     * @return the chat store's result
     */
    public boolean deleteChat(String chatId) {
        return aiChatStoreService.deleteChat(chatId);
    }

    /**
     * Returns the stored messages of a chat by delegating to the chat store.
     *
     * <p>NOTE(review): the element type is assumed from how {@link #askStream} feeds
     * the history into the LLM message list; confirm against the chat store.
     *
     * @param chatId id of the chat to read
     * @return whatever the chat store returns
     */
    public List<ChatCompletionRequest.Message> getChatHistory(String chatId) {
        return aiChatStoreService.getChatHistory(chatId);
    }

    /** Creates a chat message with the given role and content. */
    private static ChatCompletionRequest.Message newMessage(String role, String content) {
        ChatCompletionRequest.Message message = new ChatCompletionRequest.Message();
        message.setRole(role);
        message.setContent(content);
        return message;
    }

    /**
     * Derives a short chat title from the prompt: newlines become spaces, the text is
     * trimmed and cut to {@value #TITLE_MAX_LENGTH} characters. A {@code null}, empty
     * or whitespace-only prompt yields the default title.
     */
    private String buildTitleFromPrompt(String prompt) {
        if (prompt == null) {
            return UNTITLED_CHAT;
        }
        String flattened = prompt.replace('\n', ' ').trim();
        if (flattened.isEmpty()) {
            return UNTITLED_CHAT;
        }
        return flattened.length() > TITLE_MAX_LENGTH
                ? flattened.substring(0, TITLE_MAX_LENGTH)
                : flattened;
    }

    /**
     * Drives the whole session: tool rounds, then the final streamed answer.
     *
     * <p>The conversation sent to the LLM is {@code baseMessages} followed by the
     * system prompt and the user question. Any exception raised before or while
     * starting the stream is reported to the client and the emitter is completed.
     *
     * @param baseMessages   earlier messages (may be empty); not modified
     * @param systemPrompt   system message for this session
     * @param userQuestion   user message for this session
     * @param promptTemplate template stored together with the conversation
     * @param temperature    sampling temperature passed to the LLM
     * @param maxTokens      completion token limit passed to the LLM
     * @param emitter        SSE channel to the client
     * @param chatId         chat to persist into, or {@code null} to skip persistence
     */
    private void runMcpStreamingDiagnosis(List<ChatCompletionRequest.Message> baseMessages,
                                          ChatCompletionRequest.Message systemPrompt,
                                          ChatCompletionRequest.Message userQuestion,
                                          AiPromptTemplate promptTemplate,
                                          Double temperature,
                                          Integer maxTokens,
                                          SseEmitter emitter,
                                          String chatId) {
        // Created outside the try block so the error path can still report usage.
        final AgentUsageStats usageStats = new AgentUsageStats();
        try {
            if (mcpToolManager == null) {
                throw new IllegalStateException("Spring AI MCP tool manager is unavailable");
            }
            // NOTE(review): element type assumed; confirm against
            // SpringAiMcpToolManager.buildOpenAiTools() and the LlmChatService signatures.
            List<Object> tools = mcpToolManager.buildOpenAiTools();
            if (tools == null || tools.isEmpty()) {
                throw new IllegalStateException("No MCP tools registered");
            }

            List<ChatCompletionRequest.Message> messages = new ArrayList<>(baseMessages.size() + 10);
            messages.addAll(baseMessages);
            messages.add(systemPrompt);
            messages.add(userQuestion);

            sse(emitter, "\\n正在初始化诊断与工具环境...\\n");

            // Bounded so a model that keeps requesting tools cannot loop forever.
            for (int round = 1; round <= MAX_TOOL_ROUNDS; round++) {
                sse(emitter, "\\n正在分析(第" + round + "轮)...\\n");
                ChatCompletionResponse resp =
                        llmChatService.chatCompletion(messages, temperature, maxTokens, tools);
                if (resp == null
                        || resp.getChoices() == null
                        || resp.getChoices().isEmpty()
                        || resp.getChoices().get(0).getMessage() == null) {
                    throw new IllegalStateException("LLM returned empty response");
                }
                usageStats.add(resp.getUsage());

                ChatCompletionRequest.Message assistant = resp.getChoices().get(0).getMessage();
                messages.add(assistant);
                sse(emitter, assistant.getContent());

                List<ChatCompletionRequest.ToolCall> toolCalls = assistant.getTool_calls();
                if (toolCalls == null || toolCalls.isEmpty()) {
                    break; // the model is done gathering data
                }
                for (ChatCompletionRequest.ToolCall call : toolCalls) {
                    String toolName = call != null && call.getFunction() != null
                            ? call.getFunction().getName()
                            : null;
                    if (toolName == null || toolName.trim().isEmpty()) {
                        continue;
                    }
                    messages.add(invokeTool(call, toolName, emitter));
                }
            }

            sse(emitter, "\\n正在根据数据进行分析...\\n\\n\\n");
            messages.add(newMessage("system", "根据以上信息进行分析,并给出完整的诊断结论。"));

            StringBuilder assistantBuffer = new StringBuilder();
            llmChatService.chatStreamWithTools(messages, temperature, maxTokens, tools,
                    chunk -> onStreamChunk(emitter, assistantBuffer, chunk),
                    () -> onStreamComplete(emitter, usageStats, chatId, userQuestion,
                            assistantBuffer, promptTemplate),
                    e -> onStreamError(emitter, usageStats, e),
                    usageStats::add);
        } catch (Exception e) {
            emitTokenUsage(emitter, usageStats);
            sse(emitter, "\\n\\n【AI】运行已停止(异常)\\n\\n");
            log.error("AI MCP diagnose stopped: error", e);
            completeQuietly(emitter);
        }
    }

    /**
     * Executes one tool call and wraps its result as a {@code tool} message. A failing
     * tool does not abort the session: the failure is reported back to the model as a
     * JSON object with {@code tool} and {@code error} entries.
     *
     * @param call     tool call whose function is non-null (checked by the caller)
     * @param toolName non-blank tool name taken from {@code call}
     * @param emitter  SSE channel used for progress text
     */
    private ChatCompletionRequest.Message invokeTool(ChatCompletionRequest.ToolCall call,
                                                     String toolName,
                                                     SseEmitter emitter) {
        sse(emitter, "\\n准备调用工具:" + toolName + "\\n");
        JSONObject args = parseToolArguments(call.getFunction().getArguments());

        Object output;
        try {
            output = mcpToolManager.callTool(toolName, args);
        } catch (Exception e) {
            log.warn("MCP tool call failed: {}", toolName, e);
            Map<String, Object> err = new LinkedHashMap<>();
            err.put("tool", toolName);
            err.put("error", e.getMessage());
            output = err;
        }
        sse(emitter, "\\n工具返回,正在继续推理...\\n");

        ChatCompletionRequest.Message toolMsg = new ChatCompletionRequest.Message();
        toolMsg.setRole("tool");
        toolMsg.setTool_call_id(call.getId());
        toolMsg.setContent(JSON.toJSONString(output));
        return toolMsg;
    }

    /**
     * Parses the raw argument string of a tool call. Blank input or a parse result of
     * {@code null} yields an empty object; unparseable input is kept under {@code _raw}.
     */
    private JSONObject parseToolArguments(String rawArguments) {
        if (rawArguments == null || rawArguments.trim().isEmpty()) {
            return new JSONObject();
        }
        try {
            JSONObject parsed = JSON.parseObject(rawArguments);
            return parsed != null ? parsed : new JSONObject();
        } catch (Exception e) {
            log.debug("Tool arguments are not a JSON object, passing them through as _raw", e);
            JSONObject fallback = new JSONObject();
            fallback.put("_raw", rawArguments);
            return fallback;
        }
    }

    /**
     * Forwards one streamed chunk to the client and records it for persistence.
     * Carriage returns are dropped and newlines are sent as the two characters
     * backslash and {@code n}; the recorded text uses the same escaped form.
     */
    private void onStreamChunk(SseEmitter emitter, StringBuilder assistantBuffer, String chunk) {
        if (chunk == null) {
            return;
        }
        String safe = chunk.replace("\r", "").replace("\n", "\\n");
        if (!safe.isEmpty()) {
            sse(emitter, safe);
            assistantBuffer.append(safe);
        }
    }

    /**
     * Finishes a successful stream: reports token usage, completes the emitter and,
     * when a chat id was supplied, persists the exchange. Persistence runs even if
     * closing the SSE channel fails.
     */
    private void onStreamComplete(SseEmitter emitter,
                                  AgentUsageStats usageStats,
                                  String chatId,
                                  ChatCompletionRequest.Message userQuestion,
                                  StringBuilder assistantBuffer,
                                  AiPromptTemplate promptTemplate) {
        try {
            emitTokenUsage(emitter, usageStats);
            sse(emitter, "\\n\\n【AI】运行已停止(正常结束)\\n\\n");
            log.info("AI MCP diagnose stopped: final end");
            emitter.complete();
        } catch (Exception e) {
            log.warn("Failed to finish SSE stream cleanly", e);
        }

        if (chatId == null) {
            return;
        }
        try {
            ChatCompletionRequest.Message answer = newMessage("assistant", assistantBuffer.toString());
            aiChatStoreService.saveConversation(chatId,
                    buildTitleFromPrompt(userQuestion.getContent()),
                    userQuestion,
                    answer,
                    promptTemplate,
                    usageStats.getPromptTokens(),
                    usageStats.getCompletionTokens(),
                    usageStats.getTotalTokens(),
                    usageStats.getLlmCallCount());
        } catch (Exception e) {
            log.error("Failed to persist AI chat conversation, chatId={}", chatId, e);
        }
    }

    /** Finishes a failed stream: reports token usage, tells the client and completes the emitter. */
    private void onStreamError(SseEmitter emitter, AgentUsageStats usageStats, Throwable error) {
        emitTokenUsage(emitter, usageStats);
        sse(emitter, "\\n\\n【AI】分析出错,运行已停止(异常)\\n\\n");
        log.error("AI MCP diagnose stopped: stream error", error);
        completeQuietly(emitter);
    }

    /** Completes the emitter, logging instead of propagating any failure. */
    private void completeQuietly(SseEmitter emitter) {
        try {
            emitter.complete();
        } catch (Exception e) {
            log.warn("SSE complete failed", e);
        }
    }

    /** Sends one unnamed SSE data event; {@code null} data is skipped and send failures are only logged. */
    private void sse(SseEmitter emitter, String data) {
        if (data == null) {
            return;
        }
        try {
            emitter.send(SseEmitter.event().data(data));
        } catch (Exception e) {
            log.warn("SSE send failed", e);
        }
    }

    /**
     * Sends the accumulated usage as a {@code token_usage} SSE event. Nothing is sent
     * when the emitter or stats are {@code null} or no tokens were counted.
     */
    private void emitTokenUsage(SseEmitter emitter, AgentUsageStats usageStats) {
        if (emitter == null || usageStats == null || usageStats.getTotalTokens() <= 0) {
            return;
        }
        try {
            emitter.send(SseEmitter.event()
                    .name("token_usage")
                    .data(JSON.toJSONString(buildTokenUsagePayload(usageStats))));
        } catch (Exception e) {
            log.warn("SSE token usage send failed", e);
        }
    }

    /** Builds the {@code token_usage} payload; insertion order fixes the JSON key order. */
    private Map<String, Object> buildTokenUsagePayload(AgentUsageStats usageStats) {
        Map<String, Object> payload = new LinkedHashMap<>();
        payload.put("promptTokens", usageStats.getPromptTokens());
        payload.put("completionTokens", usageStats.getCompletionTokens());
        payload.put("totalTokens", usageStats.getTotalTokens());
        payload.put("llmCallCount", usageStats.getLlmCallCount());
        return payload;
    }

    /**
     * Sends a long text as consecutive SSE data events of at most
     * {@value #SSE_CHUNK_SIZE} characters, using the same newline escaping as the
     * streamed answer. Sending stops at the first failed send.
     *
     * <p>Not referenced from within this class.
     */
    private void sendLargeText(SseEmitter emitter, String text) {
        if (text == null) {
            return;
        }
        String safe = text.replace("\r", "").replace("\n", "\\n");
        for (int start = 0; start < safe.length(); start += SSE_CHUNK_SIZE) {
            int end = Math.min(start + SSE_CHUNK_SIZE, safe.length());
            try {
                emitter.send(SseEmitter.event().data(safe.substring(start, end)));
            } catch (Exception e) {
                log.warn("SSE send failed, dropping remaining text", e);
                return;
            }
        }
    }

    /**
     * Extracts DSML tool invocations from model text. For each invoke element the
     * arguments start from an inline JSON object, if one parses, and are then
     * extended or overridden by parameter elements. An invocation without any
     * arguments keeps {@code arguments == null}.
     *
     * <p>Not referenced from within this class.
     *
     * @return the invocations in order of appearance; empty for {@code null} or empty input
     */
    private List<DsmlInvocation> parseDsmlInvocations(String content) {
        List<DsmlInvocation> invocations = new ArrayList<>();
        if (content == null || content.isEmpty()) {
            return invocations;
        }
        Matcher invoke = DSML_INVOKE_PATTERN.matcher(content);
        while (invoke.find()) {
            String inner = invoke.group(2);
            JSONObject args = null;
            if (inner != null) {
                Matcher json = JSON_OBJECT_PATTERN.matcher(inner);
                if (json.find()) {
                    try {
                        args = JSON.parseObject(json.group());
                    } catch (Exception e) {
                        log.debug("Ignoring unparseable JSON arguments in DSML invoke", e);
                    }
                }
                Matcher param = DSML_PARAM_PATTERN.matcher(inner);
                while (param.find()) {
                    if (args == null) {
                        args = new JSONObject();
                    }
                    String attributes = param.group(2);
                    // Locale.ROOT keeps the attribute check independent of the JVM locale.
                    boolean isString = attributes != null
                            && attributes.toLowerCase(Locale.ROOT).contains("string=\"true\"");
                    args.put(param.group(1), parseParameterValue(param.group(3), isString));
                }
            }
            DsmlInvocation invocation = new DsmlInvocation();
            invocation.name = invoke.group(1);
            invocation.arguments = args;
            invocations.add(invocation);
        }
        return invocations;
    }

    /**
     * Converts a DSML parameter value. Values flagged as string stay strings;
     * otherwise {@code true}/{@code false} become booleans, text containing a dot is
     * parsed as {@code Double}, other text as {@code Long}, and anything that fails to
     * parse is kept as the trimmed string.
     */
    private static Object parseParameterValue(String valueText, boolean isString) {
        String trimmed = valueText == null ? "" : valueText.trim();
        if (isString) {
            return trimmed;
        }
        if ("true".equalsIgnoreCase(trimmed) || "false".equalsIgnoreCase(trimmed)) {
            return Boolean.valueOf(trimmed);
        }
        try {
            if (trimmed.contains(".")) {
                return Double.valueOf(trimmed);
            }
            return Long.valueOf(trimmed);
        } catch (NumberFormatException e) {
            return trimmed;
        }
    }

    /** A tool invocation parsed from DSML text: tool name plus optional arguments. */
    private static class DsmlInvocation {
        String name;
        JSONObject arguments;
    }

    /** Builds the three default device-status invocations, each with {@code limit = 20}. */
    private List<DsmlInvocation> buildDefaultStatusInvocations() {
        List<DsmlInvocation> defaults = new ArrayList<>();
        defaults.add(statusInvocation("device.get_crn_status"));
        defaults.add(statusInvocation("device.get_station_status"));
        defaults.add(statusInvocation("device.get_rgv_status"));
        return defaults;
    }

    /** Creates an invocation of the named tool with {@code limit = 20}. */
    private static DsmlInvocation statusInvocation(String toolName) {
        JSONObject arguments = new JSONObject();
        arguments.put("limit", 20);
        DsmlInvocation invocation = new DsmlInvocation();
        invocation.name = toolName;
        invocation.arguments = arguments;
        return invocation;
    }

    /**
     * Appends to {@code invs} every default status invocation whose tool name is not
     * already present. A {@code null} list is left alone.
     *
     * <p>Not referenced from within this class.
     */
    private void ensureStatusCoverage(List<DsmlInvocation> invs) {
        if (invs == null) {
            return;
        }
        Set<String> present = new HashSet<>();
        for (DsmlInvocation existing : invs) {
            if (existing != null && existing.name != null) {
                present.add(existing.name);
            }
        }
        for (DsmlInvocation candidate : buildDefaultStatusInvocations()) {
            if (!present.contains(candidate.name)) {
                invs.add(candidate);
            }
        }
    }

    /** Running totals of token usage across all LLM calls of one session. */
    private static class AgentUsageStats {
        private long promptTokens;
        private long completionTokens;
        private long totalTokens;
        private int llmCallCount;

        /**
         * Adds one usage record and counts it as one LLM call. A {@code null} record is
         * ignored entirely; {@code null} counters inside a record count as zero.
         */
        void add(ChatCompletionResponse.Usage usage) {
            if (usage == null) {
                return;
            }
            promptTokens += usage.getPromptTokens() == null ? 0L : usage.getPromptTokens();
            completionTokens += usage.getCompletionTokens() == null ? 0L : usage.getCompletionTokens();
            totalTokens += usage.getTotalTokens() == null ? 0L : usage.getTotalTokens();
            llmCallCount++;
        }

        long getPromptTokens() {
            return promptTokens;
        }

        long getCompletionTokens() {
            return completionTokens;
        }

        long getTotalTokens() {
            return totalTokens;
        }

        int getLlmCallCount() {
            return llmCallCount;
        }
    }

    /**
     * Heuristic check for a finished diagnosis report: either all five section
     * headings are present, or the text has at least 200 characters and mentions a
     * conclusion or diagnosis-result keyword.
     *
     * <p>Not referenced from within this class.
     */
    private boolean isConclusionText(String content) {
        if (content == null) {
            return false;
        }
        boolean hasAllSections = content.contains("问题概述")
                && content.contains("可疑设备列表")
                && content.contains("可能原因")
                && content.contains("建议排查步骤")
                && content.contains("风险评估");
        boolean hasExplicitConclusion = (content.contains("结论") || content.contains("诊断结果"))
                && content.length() >= 200;
        return hasAllSections || hasExplicitConclusion;
    }
}