package com.zy.ai.service;
|
|
import com.alibaba.fastjson.JSON;
|
import com.alibaba.fastjson.JSONObject;
|
import com.zy.ai.entity.AiPromptTemplate;
|
import com.zy.ai.entity.ChatCompletionRequest;
|
import com.zy.ai.entity.ChatCompletionResponse;
|
import com.zy.ai.entity.WcsDiagnosisRequest;
|
import com.zy.ai.enums.AiPromptScene;
|
import com.zy.ai.mcp.service.SpringAiMcpToolManager;
|
import com.zy.ai.utils.AiUtils;
|
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
import lombok.RequiredArgsConstructor;
|
import lombok.extern.slf4j.Slf4j;
|
import org.springframework.stereotype.Service;
|
|
import java.util.ArrayList;
|
import java.util.List;
|
import java.util.Map;
|
|
@Service
|
@RequiredArgsConstructor
|
@Slf4j
|
public class WcsDiagnosisService {
|
|
@Autowired
|
private LlmChatService llmChatService;
|
@Autowired
|
private AiUtils aiUtils;
|
@Autowired
|
private SpringAiMcpToolManager mcpToolManager;
|
@Autowired
|
private AiPromptTemplateService aiPromptTemplateService;
|
@Autowired
|
private AiChatStoreService aiChatStoreService;
|
|
public void diagnoseStream(WcsDiagnosisRequest request,
|
String chatId,
|
boolean reset,
|
SseEmitter emitter) {
|
List<ChatCompletionRequest.Message> messages = new ArrayList<>();
|
if (chatId != null && !chatId.isEmpty() && reset) {
|
aiChatStoreService.deleteChat(chatId);
|
}
|
AiPromptTemplate promptTemplate = aiPromptTemplateService.resolvePublished(AiPromptScene.DIAGNOSE_STREAM.getCode());
|
|
ChatCompletionRequest.Message mcpSystem = new ChatCompletionRequest.Message();
|
mcpSystem.setRole("system");
|
mcpSystem.setContent(promptTemplate.getContent());
|
|
ChatCompletionRequest.Message mcpUser = new ChatCompletionRequest.Message();
|
mcpUser.setRole("user");
|
mcpUser.setContent(aiUtils.buildDiagnosisUserContentMcp(request));
|
|
ChatCompletionRequest.Message storedUser = new ChatCompletionRequest.Message();
|
storedUser.setRole("user");
|
storedUser.setContent(buildDiagnoseDisplayPrompt(request));
|
|
runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, storedUser, promptTemplate, 0.3, 2048, emitter, chatId);
|
}
|
|
public void askStream(String prompt,
|
String chatId,
|
boolean reset,
|
SseEmitter emitter) {
|
List<ChatCompletionRequest.Message> messages = new ArrayList<>();
|
|
if (chatId != null && !chatId.isEmpty()) {
|
if (reset) {
|
aiChatStoreService.deleteChat(chatId);
|
}
|
List<ChatCompletionRequest.Message> history = aiChatStoreService.getChatHistory(chatId);
|
if (history != null && !history.isEmpty()) {
|
messages.addAll(history);
|
}
|
}
|
|
final String finalChatId = chatId;
|
AiPromptTemplate promptTemplate = aiPromptTemplateService.resolvePublished(AiPromptScene.SENSOR_CHAT.getCode());
|
|
ChatCompletionRequest.Message mcpSystem = new ChatCompletionRequest.Message();
|
mcpSystem.setRole("system");
|
mcpSystem.setContent(promptTemplate.getContent());
|
|
ChatCompletionRequest.Message mcpUser = new ChatCompletionRequest.Message();
|
mcpUser.setRole("user");
|
mcpUser.setContent(prompt == null ? "" : prompt);
|
|
runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, mcpUser, promptTemplate, 0.3, 2048, emitter, finalChatId);
|
}
|
|
public List<Map<String, Object>> listChats() {
|
return aiChatStoreService.listChats();
|
}
|
|
public boolean deleteChat(String chatId) {
|
return aiChatStoreService.deleteChat(chatId);
|
}
|
|
public List<ChatCompletionRequest.Message> getChatHistory(String chatId) {
|
return aiChatStoreService.getChatHistory(chatId);
|
}
|
|
private String buildTitleFromPrompt(String prompt) {
|
if (prompt == null || prompt.isEmpty()) return "未命名会话";
|
String p = prompt.replaceAll("\n", " ").trim();
|
return p.length() > 20 ? p.substring(0, 20) : p;
|
}
|
|
private void runMcpStreamingDiagnosis(List<ChatCompletionRequest.Message> baseMessages,
|
ChatCompletionRequest.Message systemPrompt,
|
ChatCompletionRequest.Message userQuestion,
|
ChatCompletionRequest.Message storedUserQuestion,
|
AiPromptTemplate promptTemplate,
|
Double temperature,
|
Integer maxTokens,
|
SseEmitter emitter,
|
String chatId) {
|
try {
|
if (mcpToolManager == null) {
|
throw new IllegalStateException("Spring AI MCP tool manager is unavailable");
|
}
|
List<Object> tools = mcpToolManager.buildOpenAiTools();
|
if (tools.isEmpty()) {
|
throw new IllegalStateException("No MCP tools registered");
|
}
|
AgentUsageStats usageStats = new AgentUsageStats();
|
StringBuilder reasoningBuffer = new StringBuilder();
|
|
baseMessages.add(systemPrompt);
|
baseMessages.add(userQuestion);
|
|
List<ChatCompletionRequest.Message> messages = new ArrayList<>(baseMessages.size() + 8);
|
messages.addAll(baseMessages);
|
|
sse(emitter, "<think>\\n正在初始化诊断与工具环境...\\n");
|
appendReasoning(reasoningBuffer, "正在初始化诊断与工具环境...\n");
|
|
int maxRound = 10;
|
int i = 0;
|
while(true) {
|
sse(emitter, "\\n正在分析(第" + (i + 1) + "轮)...\\n");
|
appendReasoning(reasoningBuffer, "\n正在分析(第" + (i + 1) + "轮)...\n");
|
ChatCompletionResponse resp = llmChatService.chatCompletion(messages, temperature, maxTokens, tools);
|
if (resp == null || resp.getChoices() == null || resp.getChoices().isEmpty() || resp.getChoices().get(0).getMessage() == null) {
|
throw new IllegalStateException("LLM returned empty response");
|
}
|
usageStats.add(resp.getUsage());
|
|
ChatCompletionRequest.Message assistant = resp.getChoices().get(0).getMessage();
|
messages.add(assistant);
|
sse(emitter, assistant.getContent());
|
appendReasoning(reasoningBuffer, assistant == null ? null : assistant.getContent());
|
|
List<ChatCompletionRequest.ToolCall> toolCalls = assistant.getTool_calls();
|
if (toolCalls == null || toolCalls.isEmpty()) {
|
break;
|
}
|
|
for (ChatCompletionRequest.ToolCall tc : toolCalls) {
|
String toolName = tc != null && tc.getFunction() != null ? tc.getFunction().getName() : null;
|
if (toolName == null || toolName.trim().isEmpty()) continue;
|
sse(emitter, "\\n准备调用工具:" + toolName + "\\n");
|
appendReasoning(reasoningBuffer, "\n准备调用工具:" + toolName + "\n");
|
JSONObject args = new JSONObject();
|
if (tc.getFunction() != null && tc.getFunction().getArguments() != null && !tc.getFunction().getArguments().trim().isEmpty()) {
|
try {
|
args = JSON.parseObject(tc.getFunction().getArguments());
|
} catch (Exception ignore) {
|
args = new JSONObject();
|
args.put("_raw", tc.getFunction().getArguments());
|
}
|
}
|
Object output;
|
try {
|
output = mcpToolManager.callTool(toolName, args);
|
} catch (Exception e) {
|
java.util.LinkedHashMap<String, Object> err = new java.util.LinkedHashMap<String, Object>();
|
err.put("tool", toolName);
|
err.put("error", e.getMessage());
|
output = err;
|
}
|
sse(emitter, "\\n工具返回,正在继续推理...\\n");
|
appendReasoning(reasoningBuffer, "\n工具返回,正在继续推理...\n");
|
ChatCompletionRequest.Message toolMsg = new ChatCompletionRequest.Message();
|
toolMsg.setRole("tool");
|
toolMsg.setTool_call_id(tc == null ? null : tc.getId());
|
toolMsg.setContent(JSON.toJSONString(output));
|
messages.add(toolMsg);
|
}
|
|
if(i++ >= maxRound) break;
|
}
|
|
sse(emitter, "\\n正在根据数据进行分析...\\n</think>\\n\\n");
|
appendReasoning(reasoningBuffer, "\n正在根据数据进行分析...\n");
|
|
ChatCompletionRequest.Message diagnosisMessage = new ChatCompletionRequest.Message();
|
diagnosisMessage.setRole("system");
|
diagnosisMessage.setContent("根据以上信息进行分析,并给出完整的诊断结论。");
|
messages.add(diagnosisMessage);
|
|
StringBuilder assistantBuffer = new StringBuilder();
|
llmChatService.chatStreamWithTools(messages, temperature, maxTokens, tools, s -> {
|
try {
|
String safe = s == null ? "" : s.replace("\r", "").replace("\n", "\\n");
|
if (!safe.isEmpty()) {
|
sse(emitter, safe);
|
assistantBuffer.append(safe);
|
}
|
} catch (Exception ignore) {}
|
}, () -> {
|
try {
|
emitTokenUsage(emitter, usageStats);
|
sse(emitter, "\\n\\n【AI】运行已停止(正常结束)\\n\\n");
|
log.info("AI MCP diagnose stopped: final end");
|
emitter.complete();
|
|
if (chatId != null) {
|
ChatCompletionRequest.Message a = new ChatCompletionRequest.Message();
|
a.setRole("assistant");
|
a.setContent(assistantBuffer.toString());
|
a.setReasoningContent(reasoningBuffer.toString());
|
aiChatStoreService.saveConversation(chatId,
|
buildTitleFromPrompt(storedUserQuestion == null ? null : storedUserQuestion.getContent()),
|
storedUserQuestion == null ? userQuestion : storedUserQuestion,
|
a,
|
promptTemplate,
|
usageStats.getPromptTokens(),
|
usageStats.getCompletionTokens(),
|
usageStats.getTotalTokens(),
|
usageStats.getLlmCallCount());
|
}
|
} catch (Exception ignore) {}
|
}, e -> {
|
try {
|
emitTokenUsage(emitter, usageStats);
|
sse(emitter, "\\n\\n【AI】分析出错,运行已停止(异常)\\n\\n");
|
log.error("AI MCP diagnose stopped: stream error", e);
|
emitter.complete();
|
} catch (Exception ignore) {}
|
}, usageStats::add);
|
} catch (Exception e) {
|
try {
|
sse(emitter, "\\n\\n【AI】运行已停止(异常)\\n\\n");
|
log.error("AI MCP diagnose stopped: error", e);
|
emitter.complete();
|
} catch (Exception ignore) {}
|
}
|
}
|
|
private void sse(SseEmitter emitter, String data) {
|
if (data == null) return;
|
try {
|
emitter.send(SseEmitter.event().data(data));
|
} catch (Exception e) {
|
log.warn("SSE send failed", e);
|
}
|
}
|
|
private void emitTokenUsage(SseEmitter emitter, AgentUsageStats usageStats) {
|
if (emitter == null || usageStats == null || usageStats.getTotalTokens() <= 0) {
|
return;
|
}
|
try {
|
emitter.send(SseEmitter.event()
|
.name("token_usage")
|
.data(JSON.toJSONString(buildTokenUsagePayload(usageStats))));
|
} catch (Exception e) {
|
log.warn("SSE token usage send failed", e);
|
}
|
}
|
|
private Map<String, Object> buildTokenUsagePayload(AgentUsageStats usageStats) {
|
java.util.LinkedHashMap<String, Object> payload = new java.util.LinkedHashMap<>();
|
payload.put("promptTokens", usageStats.getPromptTokens());
|
payload.put("completionTokens", usageStats.getCompletionTokens());
|
payload.put("totalTokens", usageStats.getTotalTokens());
|
payload.put("llmCallCount", usageStats.getLlmCallCount());
|
return payload;
|
}
|
|
private void appendReasoning(StringBuilder reasoningBuffer, String text) {
|
if (reasoningBuffer == null || text == null || text.isEmpty()) {
|
return;
|
}
|
reasoningBuffer.append(text);
|
}
|
|
private void sendLargeText(SseEmitter emitter, String text) {
|
if (text == null) return;
|
String safe = text.replace("\r", "").replace("\n", "\\n");
|
int chunkSize = 256;
|
int i = 0;
|
while (i < safe.length()) {
|
int end = Math.min(i + chunkSize, safe.length());
|
String part = safe.substring(i, end);
|
if (!part.isEmpty()) {
|
try { emitter.send(SseEmitter.event().data(part)); } catch (Exception ignore) {}
|
}
|
i = end;
|
}
|
}
|
|
private static final java.util.regex.Pattern DSML_INVOKE_PATTERN =
|
java.util.regex.Pattern.compile("<\\uFF5CDSML\\uFF5Cinvoke\\s+name=\\\"([^\\\"]+)\\\"[^>]*>([\\s\\S]*?)</\\uFF5CDSML\\uFF5Cinvoke>", java.util.regex.Pattern.MULTILINE);
|
private static final java.util.regex.Pattern JSON_OBJECT_PATTERN =
|
java.util.regex.Pattern.compile("\\{[\\s\\S]*\\}");
|
private static final java.util.regex.Pattern DSML_PARAM_PATTERN =
|
java.util.regex.Pattern.compile("<\\uFF5CDSML\\uFF5Cparameter\\s+name=\\\"([^\\\"]+)\\\"\\s*([^>]*)>([\\s\\S]*?)</\\uFF5CDSML\\uFF5Cparameter>", java.util.regex.Pattern.MULTILINE);
|
|
private java.util.List<DsmlInvocation> parseDsmlInvocations(String content) {
|
java.util.List<DsmlInvocation> list = new java.util.ArrayList<>();
|
if (content == null || content.isEmpty()) return list;
|
java.util.regex.Matcher m = DSML_INVOKE_PATTERN.matcher(content);
|
while (m.find()) {
|
String name = m.group(1);
|
String inner = m.group(2);
|
com.alibaba.fastjson.JSONObject args = null;
|
if (inner != null) {
|
java.util.regex.Matcher jm = JSON_OBJECT_PATTERN.matcher(inner);
|
if (jm.find()) {
|
String json = jm.group();
|
try { args = com.alibaba.fastjson.JSON.parseObject(json); } catch (Exception ignore) {}
|
}
|
java.util.regex.Matcher pm = DSML_PARAM_PATTERN.matcher(inner);
|
while (pm.find()) {
|
if (args == null) args = new com.alibaba.fastjson.JSONObject();
|
String pName = pm.group(1);
|
String attr = pm.group(2);
|
String valText = pm.group(3);
|
boolean isString = attr != null && attr.toLowerCase().contains("string=\"true\"");
|
String t = valText == null ? "" : valText.trim();
|
if (isString) {
|
args.put(pName, t);
|
} else {
|
if ("true".equalsIgnoreCase(t) || "false".equalsIgnoreCase(t)) {
|
args.put(pName, Boolean.valueOf(t));
|
} else {
|
try {
|
if (t.contains(".")) {
|
args.put(pName, Double.valueOf(t));
|
} else {
|
args.put(pName, Long.valueOf(t));
|
}
|
} catch (Exception ex) {
|
args.put(pName, t);
|
}
|
}
|
}
|
}
|
}
|
DsmlInvocation inv = new DsmlInvocation();
|
inv.name = name;
|
inv.arguments = args;
|
list.add(inv);
|
}
|
return list;
|
}
|
|
private static class DsmlInvocation {
|
String name;
|
com.alibaba.fastjson.JSONObject arguments;
|
}
|
|
private List<DsmlInvocation> buildDefaultStatusInvocations() {
|
List<DsmlInvocation> list = new ArrayList<>();
|
DsmlInvocation crn = new DsmlInvocation();
|
crn.name = "device.get_crn_status";
|
com.alibaba.fastjson.JSONObject a1 = new com.alibaba.fastjson.JSONObject();
|
a1.put("limit", 20);
|
crn.arguments = a1;
|
list.add(crn);
|
|
DsmlInvocation st = new DsmlInvocation();
|
st.name = "device.get_station_status";
|
com.alibaba.fastjson.JSONObject a2 = new com.alibaba.fastjson.JSONObject();
|
a2.put("limit", 20);
|
st.arguments = a2;
|
list.add(st);
|
|
DsmlInvocation rgv = new DsmlInvocation();
|
rgv.name = "device.get_rgv_status";
|
com.alibaba.fastjson.JSONObject a3 = new com.alibaba.fastjson.JSONObject();
|
a3.put("limit", 20);
|
rgv.arguments = a3;
|
list.add(rgv);
|
|
return list;
|
}
|
|
private void ensureStatusCoverage(List<DsmlInvocation> invs) {
|
if (invs == null) return;
|
java.util.Set<String> names = new java.util.HashSet<String>();
|
for (DsmlInvocation d : invs) {
|
if (d != null && d.name != null) names.add(d.name);
|
}
|
if (!names.contains("device.get_crn_status") || !names.contains("device.get_station_status") || !names.contains("device.get_rgv_status")) {
|
List<DsmlInvocation> defaults = buildDefaultStatusInvocations();
|
for (DsmlInvocation d : defaults) {
|
if (!names.contains(d.name)) invs.add(d);
|
}
|
}
|
}
|
|
private String buildDiagnoseDisplayPrompt(WcsDiagnosisRequest request) {
|
if (request == null || request.getAlarmMessage() == null || request.getAlarmMessage().trim().isEmpty()) {
|
return "对当前系统进行巡检";
|
}
|
return request.getAlarmMessage().trim();
|
}
|
|
private static class AgentUsageStats {
|
private long promptTokens;
|
private long completionTokens;
|
private long totalTokens;
|
private int llmCallCount;
|
|
void add(ChatCompletionResponse.Usage usage) {
|
if (usage == null) {
|
return;
|
}
|
promptTokens += usage.getPromptTokens() == null ? 0L : usage.getPromptTokens();
|
completionTokens += usage.getCompletionTokens() == null ? 0L : usage.getCompletionTokens();
|
totalTokens += usage.getTotalTokens() == null ? 0L : usage.getTotalTokens();
|
llmCallCount++;
|
}
|
|
long getPromptTokens() {
|
return promptTokens;
|
}
|
|
long getCompletionTokens() {
|
return completionTokens;
|
}
|
|
long getTotalTokens() {
|
return totalTokens;
|
}
|
|
int getLlmCallCount() {
|
return llmCallCount;
|
}
|
}
|
|
private boolean isConclusionText(String content) {
|
if (content == null) return false;
|
String c = content;
|
int len = c.length();
|
boolean longEnough = len >= 200;
|
boolean hasAllSections = c.contains("问题概述") && c.contains("可疑设备列表") && c.contains("可能原因") && c.contains("建议排查步骤") && c.contains("风险评估");
|
boolean hasExplicitConclusion = (c.contains("结论") || c.contains("诊断结果")) && longEnough;
|
return hasAllSections || hasExplicitConclusion;
|
}
|
}
|