| | |
| | | |
| | | import com.alibaba.fastjson.JSON; |
| | | import com.alibaba.fastjson.JSONObject; |
| | | import com.zy.ai.entity.AiPromptTemplate; |
| | | import com.zy.ai.entity.ChatCompletionRequest; |
| | | import com.zy.ai.entity.ChatCompletionResponse; |
| | | import com.zy.ai.entity.WcsDiagnosisRequest; |
| | | import com.zy.ai.mcp.controller.McpController; |
| | | import com.zy.ai.utils.AiPromptUtils; |
| | | import com.zy.ai.enums.AiPromptScene; |
| | | import com.zy.ai.mcp.service.SpringAiMcpToolManager; |
| | | import com.zy.ai.utils.AiUtils; |
| | | import com.zy.common.utils.RedisUtil; |
| | | import com.zy.core.enums.RedisKeyType; |
| | | import org.springframework.beans.factory.annotation.Autowired; |
| | | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; |
| | | import lombok.RequiredArgsConstructor; |
| | |
| | | @Slf4j |
| | | public class WcsDiagnosisService { |
| | | |
| | | private static final long CHAT_TTL_SECONDS = 7L * 24 * 3600; |
| | | |
| | | @Autowired |
| | | private LlmChatService llmChatService; |
| | | @Autowired |
| | | private RedisUtil redisUtil; |
| | | @Autowired |
| | | private AiPromptUtils aiPromptUtils; |
| | | @Autowired |
| | | private AiUtils aiUtils; |
| | | @Autowired(required = false) |
| | | private McpController mcpController; |
| | | @Autowired |
| | | private SpringAiMcpToolManager mcpToolManager; |
| | | @Autowired |
| | | private AiPromptTemplateService aiPromptTemplateService; |
| | | @Autowired |
| | | private AiChatStoreService aiChatStoreService; |
| | | |
| | | public void diagnoseStream(WcsDiagnosisRequest request, SseEmitter emitter) { |
| | | List<ChatCompletionRequest.Message> messages = new ArrayList<>(); |
| | | AiPromptTemplate promptTemplate = aiPromptTemplateService.resolvePublished(AiPromptScene.DIAGNOSE_STREAM.getCode()); |
| | | |
| | | ChatCompletionRequest.Message mcpSystem = new ChatCompletionRequest.Message(); |
| | | mcpSystem.setRole("system"); |
| | | mcpSystem.setContent(aiPromptUtils.getAiDiagnosePromptMcp()); |
| | | mcpSystem.setContent(promptTemplate.getContent()); |
| | | |
| | | ChatCompletionRequest.Message mcpUser = new ChatCompletionRequest.Message(); |
| | | mcpUser.setRole("user"); |
| | | mcpUser.setContent(aiUtils.buildDiagnosisUserContentMcp(request)); |
| | | |
| | | if (runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, 0.3, 2048, emitter, null)) { |
| | | return; |
| | | } |
| | | |
| | | messages = new ArrayList<>(); |
| | | ChatCompletionRequest.Message system = new ChatCompletionRequest.Message(); |
| | | system.setRole("system"); |
| | | system.setContent(aiPromptUtils.getAiDiagnosePrompt()); |
| | | messages.add(system); |
| | | |
| | | ChatCompletionRequest.Message user = new ChatCompletionRequest.Message(); |
| | | user.setRole("user"); |
| | | user.setContent(aiUtils.buildDiagnosisUserContent(request)); |
| | | messages.add(user); |
| | | |
| | | llmChatService.chatStream(messages, 0.3, 2048, s -> { |
| | | try { |
| | | String safe = s == null ? "" : s.replace("\r", "").replace("\n", "\\n"); |
| | | if (!safe.isEmpty()) { |
| | | emitter.send(SseEmitter.event().data(safe)); |
| | | } |
| | | } catch (Exception ignore) {} |
| | | }, () -> { |
| | | try { |
| | | log.info("AI diagnose stream stopped: normal end"); |
| | | emitter.complete(); |
| | | } catch (Exception ignore) {} |
| | | }, e -> { |
| | | try { |
| | | try { emitter.send(SseEmitter.event().data("【AI】运行已停止(异常)")); } catch (Exception ignore) {} |
| | | log.error("AI diagnose stream stopped: error", e); |
| | | emitter.complete(); |
| | | } catch (Exception ignore) {} |
| | | }); |
| | | runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, promptTemplate, 0.3, 2048, emitter, null); |
| | | } |
| | | |
| | | public void askStream(WcsDiagnosisRequest request, |
| | | String prompt, |
| | | public void askStream(String prompt, |
| | | String chatId, |
| | | boolean reset, |
| | | SseEmitter emitter) { |
| | | List<ChatCompletionRequest.Message> messages = new ArrayList<>(); |
| | | |
| | | List<ChatCompletionRequest.Message> history = null; |
| | | String historyKey = null; |
| | | String metaKey = null; |
| | | if (chatId != null && !chatId.isEmpty()) { |
| | | historyKey = RedisKeyType.AI_CHAT_HISTORY.key + chatId; |
| | | metaKey = RedisKeyType.AI_CHAT_META.key + chatId; |
| | | if (reset) { |
| | | redisUtil.del(historyKey, metaKey); |
| | | aiChatStoreService.deleteChat(chatId); |
| | | } |
| | | List<Object> stored = redisUtil.lGet(historyKey, 0, -1); |
| | | if (stored != null && !stored.isEmpty()) { |
| | | history = new ArrayList<>(stored.size()); |
| | | for (Object o : stored) { |
| | | ChatCompletionRequest.Message m = convertToMessage(o); |
| | | if (m != null) history.add(m); |
| | | } |
| | | if (!history.isEmpty()) messages.addAll(history); |
| | | } else { |
| | | history = new ArrayList<>(); |
| | | List<ChatCompletionRequest.Message> history = aiChatStoreService.getChatHistory(chatId); |
| | | if (history != null && !history.isEmpty()) { |
| | | messages.addAll(history); |
| | | } |
| | | } |
| | | |
| | | StringBuilder assistantBuffer = new StringBuilder(); |
| | | final String finalChatId = chatId; |
| | | final String finalHistoryKey = historyKey; |
| | | final String finalMetaKey = metaKey; |
| | | final String finalPrompt = prompt; |
| | | AiPromptTemplate promptTemplate = aiPromptTemplateService.resolvePublished(AiPromptScene.SENSOR_CHAT.getCode()); |
| | | |
| | | ChatCompletionRequest.Message mcpSystem = new ChatCompletionRequest.Message(); |
| | | mcpSystem.setRole("system"); |
| | | mcpSystem.setContent(aiPromptUtils.getWcsSensorPromptMcp()); |
| | | mcpSystem.setContent(promptTemplate.getContent()); |
| | | |
| | | ChatCompletionRequest.Message mcpUser = new ChatCompletionRequest.Message(); |
| | | mcpUser.setRole("user"); |
| | | mcpUser.setContent("【用户提问】\n" + (prompt == null ? "" : prompt)); |
| | | mcpUser.setContent(prompt == null ? "" : prompt); |
| | | |
| | | if (runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, 0.3, 2048, emitter, finalChatId)) { |
| | | return; |
| | | } |
| | | |
| | | messages = new ArrayList<>(); |
| | | ChatCompletionRequest.Message system = new ChatCompletionRequest.Message(); |
| | | system.setRole("system"); |
| | | system.setContent(aiPromptUtils.getWcsSensorPrompt()); |
| | | messages.add(system); |
| | | |
| | | ChatCompletionRequest.Message questionMsg = new ChatCompletionRequest.Message(); |
| | | questionMsg.setRole("user"); |
| | | questionMsg.setContent("【用户提问】\n" + (prompt == null ? "" : prompt)); |
| | | messages.add(questionMsg); |
| | | |
| | | llmChatService.chatStream(messages, 0.3, 2048, s -> { |
| | | try { |
| | | String safe = s == null ? "" : s.replace("\r", "").replace("\n", "\\n"); |
| | | if (!safe.isEmpty()) { |
| | | emitter.send(SseEmitter.event().data(safe)); |
| | | assistantBuffer.append(s); |
| | | } |
| | | } catch (Exception ignore) {} |
| | | }, () -> { |
| | | try { |
| | | if (finalChatId != null && !finalChatId.isEmpty()) { |
| | | ChatCompletionRequest.Message q = new ChatCompletionRequest.Message(); |
| | | q.setRole("user"); |
| | | q.setContent(finalPrompt == null ? "" : finalPrompt); |
| | | ChatCompletionRequest.Message a = new ChatCompletionRequest.Message(); |
| | | a.setRole("assistant"); |
| | | a.setContent(assistantBuffer.toString()); |
| | | redisUtil.lSet(finalHistoryKey, q); |
| | | redisUtil.lSet(finalHistoryKey, a); |
| | | redisUtil.expire(finalHistoryKey, CHAT_TTL_SECONDS); |
| | | Map<Object, Object> old = redisUtil.hmget(finalMetaKey); |
| | | Long createdAt = old != null && old.get("createdAt") != null ? |
| | | (old.get("createdAt") instanceof Number ? ((Number) old.get("createdAt")).longValue() : Long.valueOf(String.valueOf(old.get("createdAt")))) |
| | | : System.currentTimeMillis(); |
| | | Map<String, Object> meta = new java.util.HashMap<>(); |
| | | meta.put("chatId", finalChatId); |
| | | meta.put("title", buildTitleFromPrompt(finalPrompt)); |
| | | meta.put("createdAt", createdAt); |
| | | meta.put("updatedAt", System.currentTimeMillis()); |
| | | redisUtil.hmset(finalMetaKey, meta, CHAT_TTL_SECONDS); |
| | | } |
| | | emitter.complete(); |
| | | } catch (Exception ignore) {} |
| | | }, e -> { |
| | | try { |
| | | try { emitter.send(SseEmitter.event().data("【AI】运行已停止(异常)")); } catch (Exception ignore) {} |
| | | emitter.complete(); |
| | | } catch (Exception ignore) {} |
| | | }); |
| | | runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, promptTemplate, 0.3, 2048, emitter, finalChatId); |
| | | } |
| | | |
| | | public List<Map<String, Object>> listChats() { |
| | | java.util.Set<String> keys = redisUtil.scanKeys(RedisKeyType.AI_CHAT_META.key, 1000); |
| | | List<Map<String, Object>> resp = new ArrayList<>(); |
| | | if (keys != null) { |
| | | for (String key : keys) { |
| | | Map<Object, Object> m = redisUtil.hmget(key); |
| | | if (m != null && !m.isEmpty()) { |
| | | java.util.HashMap<String, Object> item = new java.util.HashMap<>(); |
| | | for (Map.Entry<Object, Object> e : m.entrySet()) { |
| | | item.put(String.valueOf(e.getKey()), e.getValue()); |
| | | } |
| | | String chatId = String.valueOf(item.get("chatId")); |
| | | String historyKey = RedisKeyType.AI_CHAT_HISTORY.key + chatId; |
| | | item.put("size", redisUtil.lGetListSize(historyKey)); |
| | | resp.add(item); |
| | | } |
| | | } |
| | | } |
| | | return resp; |
| | | return aiChatStoreService.listChats(); |
| | | } |
| | | |
| | | public boolean deleteChat(String chatId) { |
| | | if (chatId == null || chatId.isEmpty()) return false; |
| | | String historyKey = RedisKeyType.AI_CHAT_HISTORY.key + chatId; |
| | | String metaKey = RedisKeyType.AI_CHAT_META.key + chatId; |
| | | redisUtil.del(historyKey, metaKey); |
| | | return true; |
| | | return aiChatStoreService.deleteChat(chatId); |
| | | } |
| | | |
| | | public List<ChatCompletionRequest.Message> getChatHistory(String chatId) { |
| | | if (chatId == null || chatId.isEmpty()) return java.util.Collections.emptyList(); |
| | | String historyKey = RedisKeyType.AI_CHAT_HISTORY.key + chatId; |
| | | List<Object> stored = redisUtil.lGet(historyKey, 0, -1); |
| | | List<ChatCompletionRequest.Message> result = new ArrayList<>(); |
| | | if (stored != null) { |
| | | for (Object o : stored) { |
| | | ChatCompletionRequest.Message m = convertToMessage(o); |
| | | if (m != null) result.add(m); |
| | | } |
| | | } |
| | | return result; |
| | | } |
| | | |
| | | private ChatCompletionRequest.Message convertToMessage(Object o) { |
| | | if (o instanceof ChatCompletionRequest.Message) { |
| | | return (ChatCompletionRequest.Message) o; |
| | | } |
| | | if (o instanceof Map) { |
| | | Map<?, ?> map = (Map<?, ?>) o; |
| | | ChatCompletionRequest.Message m = new ChatCompletionRequest.Message(); |
| | | Object role = map.get("role"); |
| | | Object content = map.get("content"); |
| | | m.setRole(role == null ? null : String.valueOf(role)); |
| | | m.setContent(content == null ? null : String.valueOf(content)); |
| | | return m; |
| | | } |
| | | return null; |
| | | return aiChatStoreService.getChatHistory(chatId); |
| | | } |
| | | |
| | | private String buildTitleFromPrompt(String prompt) { |
| | |
| | | return p.length() > 20 ? p.substring(0, 20) : p; |
| | | } |
| | | |
| | | private boolean runMcpStreamingDiagnosis(List<ChatCompletionRequest.Message> baseMessages, |
| | | ChatCompletionRequest.Message systemPrompt, |
| | | ChatCompletionRequest.Message userQuestion, |
| | | Double temperature, |
| | | Integer maxTokens, |
| | | SseEmitter emitter, |
| | | String chatId) { |
| | | private void runMcpStreamingDiagnosis(List<ChatCompletionRequest.Message> baseMessages, |
| | | ChatCompletionRequest.Message systemPrompt, |
| | | ChatCompletionRequest.Message userQuestion, |
| | | AiPromptTemplate promptTemplate, |
| | | Double temperature, |
| | | Integer maxTokens, |
| | | SseEmitter emitter, |
| | | String chatId) { |
| | | try { |
| | | if (mcpController == null) return false; |
| | | List<Object> tools = buildOpenAiTools(); |
| | | if (tools.isEmpty()) return false; |
| | | if (mcpToolManager == null) { |
| | | throw new IllegalStateException("Spring AI MCP tool manager is unavailable"); |
| | | } |
| | | List<Object> tools = mcpToolManager.buildOpenAiTools(); |
| | | if (tools.isEmpty()) { |
| | | throw new IllegalStateException("No MCP tools registered"); |
| | | } |
| | | AgentUsageStats usageStats = new AgentUsageStats(); |
| | | |
| | | baseMessages.add(systemPrompt); |
| | | baseMessages.add(userQuestion); |
| | |
| | | sse(emitter, "\\n正在分析(第" + (i + 1) + "轮)...\\n"); |
| | | ChatCompletionResponse resp = llmChatService.chatCompletion(messages, temperature, maxTokens, tools); |
| | | if (resp == null || resp.getChoices() == null || resp.getChoices().isEmpty() || resp.getChoices().get(0).getMessage() == null) { |
| | | sse(emitter, "\\n分析出错,正在回退...\\n"); |
| | | return false; |
| | | throw new IllegalStateException("LLM returned empty response"); |
| | | } |
| | | usageStats.add(resp.getUsage()); |
| | | |
| | | ChatCompletionRequest.Message assistant = resp.getChoices().get(0).getMessage(); |
| | | messages.add(assistant); |
| | |
| | | } |
| | | Object output; |
| | | try { |
| | | output = mcpController.callTool(toolName, args); |
| | | output = mcpToolManager.callTool(toolName, args); |
| | | } catch (Exception e) { |
| | | java.util.LinkedHashMap<String, Object> err = new java.util.LinkedHashMap<String, Object>(); |
| | | err.put("tool", toolName); |
| | |
| | | } catch (Exception ignore) {} |
| | | }, () -> { |
| | | try { |
| | | emitTokenUsage(emitter, usageStats); |
| | | sse(emitter, "\\n\\n【AI】运行已停止(正常结束)\\n\\n"); |
| | | log.info("AI MCP diagnose stopped: final end"); |
| | | emitter.complete(); |
| | | |
| | | if (chatId != null) { |
| | | String historyKey = RedisKeyType.AI_CHAT_HISTORY.key + chatId; |
| | | String metaKey = RedisKeyType.AI_CHAT_META.key + chatId; |
| | | |
| | | ChatCompletionRequest.Message a = new ChatCompletionRequest.Message(); |
| | | a.setRole("assistant"); |
| | | a.setContent(assistantBuffer.toString()); |
| | | redisUtil.lSet(historyKey, userQuestion); |
| | | redisUtil.lSet(historyKey, a); |
| | | redisUtil.expire(historyKey, CHAT_TTL_SECONDS); |
| | | Map<Object, Object> old = redisUtil.hmget(metaKey); |
| | | Long createdAt = old != null && old.get("createdAt") != null ? |
| | | (old.get("createdAt") instanceof Number ? ((Number) old.get("createdAt")).longValue() : Long.valueOf(String.valueOf(old.get("createdAt")))) |
| | | : System.currentTimeMillis(); |
| | | Map<String, Object> meta = new java.util.HashMap<>(); |
| | | meta.put("chatId", chatId); |
| | | meta.put("title", buildTitleFromPrompt(userQuestion.getContent())); |
| | | meta.put("createdAt", createdAt); |
| | | meta.put("updatedAt", System.currentTimeMillis()); |
| | | redisUtil.hmset(metaKey, meta, CHAT_TTL_SECONDS); |
| | | aiChatStoreService.saveConversation(chatId, |
| | | buildTitleFromPrompt(userQuestion.getContent()), |
| | | userQuestion, |
| | | a, |
| | | promptTemplate, |
| | | usageStats.getPromptTokens(), |
| | | usageStats.getCompletionTokens(), |
| | | usageStats.getTotalTokens(), |
| | | usageStats.getLlmCallCount()); |
| | | } |
| | | } catch (Exception ignore) {} |
| | | }, e -> { |
| | | sse(emitter, "\\n\\n【AI】分析出错,正在回退...\\n\\n"); |
| | | }); |
| | | return true; |
| | | try { |
| | | emitTokenUsage(emitter, usageStats); |
| | | sse(emitter, "\\n\\n【AI】分析出错,运行已停止(异常)\\n\\n"); |
| | | log.error("AI MCP diagnose stopped: stream error", e); |
| | | emitter.complete(); |
| | | } catch (Exception ignore) {} |
| | | }, usageStats::add); |
| | | } catch (Exception e) { |
| | | try { |
| | | sse(emitter, "\\n\\n【AI】运行已停止(异常)\\n\\n"); |
| | | log.error("AI MCP diagnose stopped: error", e); |
| | | emitter.complete(); |
| | | } catch (Exception ignore) {} |
| | | return true; |
| | | } |
| | | } |
| | | |
| | |
| | | } |
| | | } |
| | | |
| | | private List<Object> buildOpenAiTools() { |
| | | if (mcpController == null) return java.util.Collections.emptyList(); |
| | | List<Map<String, Object>> mcpTools = mcpController.listTools(); |
| | | if (mcpTools == null || mcpTools.isEmpty()) return java.util.Collections.emptyList(); |
| | | |
| | | List<Object> tools = new ArrayList<>(); |
| | | for (Map<String, Object> t : mcpTools) { |
| | | if (t == null) continue; |
| | | Object name = t.get("name"); |
| | | if (name == null) continue; |
| | | Object inputSchema = t.get("inputSchema"); |
| | | java.util.LinkedHashMap<String, Object> function = new java.util.LinkedHashMap<String, Object>(); |
| | | function.put("name", String.valueOf(name)); |
| | | Object desc = t.get("description"); |
| | | if (desc != null) function.put("description", String.valueOf(desc)); |
| | | function.put("parameters", inputSchema == null ? new java.util.LinkedHashMap<String, Object>() : inputSchema); |
| | | |
| | | java.util.LinkedHashMap<String, Object> tool = new java.util.LinkedHashMap<String, Object>(); |
| | | tool.put("type", "function"); |
| | | tool.put("function", function); |
| | | tools.add(tool); |
| | | private void emitTokenUsage(SseEmitter emitter, AgentUsageStats usageStats) { |
| | | if (emitter == null || usageStats == null || usageStats.getTotalTokens() <= 0) { |
| | | return; |
| | | } |
| | | return tools; |
| | | try { |
| | | emitter.send(SseEmitter.event() |
| | | .name("token_usage") |
| | | .data(JSON.toJSONString(buildTokenUsagePayload(usageStats)))); |
| | | } catch (Exception e) { |
| | | log.warn("SSE token usage send failed", e); |
| | | } |
| | | } |
| | | |
| | | private Map<String, Object> buildTokenUsagePayload(AgentUsageStats usageStats) { |
| | | java.util.LinkedHashMap<String, Object> payload = new java.util.LinkedHashMap<>(); |
| | | payload.put("promptTokens", usageStats.getPromptTokens()); |
| | | payload.put("completionTokens", usageStats.getCompletionTokens()); |
| | | payload.put("totalTokens", usageStats.getTotalTokens()); |
| | | payload.put("llmCallCount", usageStats.getLlmCallCount()); |
| | | return payload; |
| | | } |
| | | |
| | | private void sendLargeText(SseEmitter emitter, String text) { |
| | |
| | | } |
| | | } |
| | | |
| | | private static class AgentUsageStats { |
| | | private long promptTokens; |
| | | private long completionTokens; |
| | | private long totalTokens; |
| | | private int llmCallCount; |
| | | |
| | | void add(ChatCompletionResponse.Usage usage) { |
| | | if (usage == null) { |
| | | return; |
| | | } |
| | | promptTokens += usage.getPromptTokens() == null ? 0L : usage.getPromptTokens(); |
| | | completionTokens += usage.getCompletionTokens() == null ? 0L : usage.getCompletionTokens(); |
| | | totalTokens += usage.getTotalTokens() == null ? 0L : usage.getTotalTokens(); |
| | | llmCallCount++; |
| | | } |
| | | |
| | | long getPromptTokens() { |
| | | return promptTokens; |
| | | } |
| | | |
| | | long getCompletionTokens() { |
| | | return completionTokens; |
| | | } |
| | | |
| | | long getTotalTokens() { |
| | | return totalTokens; |
| | | } |
| | | |
| | | int getLlmCallCount() { |
| | | return llmCallCount; |
| | | } |
| | | } |
| | | |
| | | private boolean isConclusionText(String content) { |
| | | if (content == null) return false; |
| | | String c = content; |