| | |
| | | import com.zy.ai.entity.ChatCompletionRequest; |
| | | import com.zy.ai.entity.ChatCompletionResponse; |
| | | import com.zy.ai.entity.WcsDiagnosisRequest; |
| | | import com.zy.ai.mcp.controller.McpController; |
| | | import com.zy.ai.mcp.service.SpringAiMcpToolManager; |
| | | import com.zy.ai.utils.AiPromptUtils; |
| | | import com.zy.ai.utils.AiUtils; |
| | | import com.zy.common.utils.RedisUtil; |
| | |
| | | private AiPromptUtils aiPromptUtils; |
| | | @Autowired |
| | | private AiUtils aiUtils; |
| | | @Autowired(required = false) |
| | | private McpController mcpController; |
| | | @Autowired |
| | | private SpringAiMcpToolManager mcpToolManager; |
| | | |
| | | public void diagnoseStream(WcsDiagnosisRequest request, SseEmitter emitter) { |
| | | List<ChatCompletionRequest.Message> messages = new ArrayList<>(); |
| | |
| | | mcpUser.setRole("user"); |
| | | mcpUser.setContent(aiUtils.buildDiagnosisUserContentMcp(request)); |
| | | |
| | | if (runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, 0.3, 2048, emitter, null)) { |
| | | return; |
| | | } |
| | | |
| | | messages = new ArrayList<>(); |
| | | ChatCompletionRequest.Message system = new ChatCompletionRequest.Message(); |
| | | system.setRole("system"); |
| | | system.setContent(aiPromptUtils.getAiDiagnosePrompt()); |
| | | messages.add(system); |
| | | |
| | | ChatCompletionRequest.Message user = new ChatCompletionRequest.Message(); |
| | | user.setRole("user"); |
| | | user.setContent(aiUtils.buildDiagnosisUserContent(request)); |
| | | messages.add(user); |
| | | |
| | | llmChatService.chatStream(messages, 0.3, 2048, s -> { |
| | | try { |
| | | String safe = s == null ? "" : s.replace("\r", "").replace("\n", "\\n"); |
| | | if (!safe.isEmpty()) { |
| | | emitter.send(SseEmitter.event().data(safe)); |
| | | } |
| | | } catch (Exception ignore) {} |
| | | }, () -> { |
| | | try { |
| | | log.info("AI diagnose stream stopped: normal end"); |
| | | emitter.complete(); |
| | | } catch (Exception ignore) {} |
| | | }, e -> { |
| | | try { |
| | | try { emitter.send(SseEmitter.event().data("【AI】运行已停止(异常)")); } catch (Exception ignore) {} |
| | | log.error("AI diagnose stream stopped: error", e); |
| | | emitter.complete(); |
| | | } catch (Exception ignore) {} |
| | | }); |
| | | runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, 0.3, 2048, emitter, null); |
| | | } |
| | | |
| | | public void askStream(WcsDiagnosisRequest request, |
| | | String prompt, |
| | | public void askStream(String prompt, |
| | | String chatId, |
| | | boolean reset, |
| | | SseEmitter emitter) { |
| | |
| | | } |
| | | } |
| | | |
| | | StringBuilder assistantBuffer = new StringBuilder(); |
| | | final String finalChatId = chatId; |
| | | final String finalHistoryKey = historyKey; |
| | | final String finalMetaKey = metaKey; |
| | | final String finalPrompt = prompt; |
| | | |
| | | ChatCompletionRequest.Message mcpSystem = new ChatCompletionRequest.Message(); |
| | | mcpSystem.setRole("system"); |
| | |
| | | mcpUser.setRole("user"); |
| | | mcpUser.setContent("【用户提问】\n" + (prompt == null ? "" : prompt)); |
| | | |
| | | if (runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, 0.3, 2048, emitter, finalChatId)) { |
| | | return; |
| | | } |
| | | |
| | | messages = new ArrayList<>(); |
| | | ChatCompletionRequest.Message system = new ChatCompletionRequest.Message(); |
| | | system.setRole("system"); |
| | | system.setContent(aiPromptUtils.getWcsSensorPrompt()); |
| | | messages.add(system); |
| | | |
| | | ChatCompletionRequest.Message questionMsg = new ChatCompletionRequest.Message(); |
| | | questionMsg.setRole("user"); |
| | | questionMsg.setContent("【用户提问】\n" + (prompt == null ? "" : prompt)); |
| | | messages.add(questionMsg); |
| | | |
| | | llmChatService.chatStream(messages, 0.3, 2048, s -> { |
| | | try { |
| | | String safe = s == null ? "" : s.replace("\r", "").replace("\n", "\\n"); |
| | | if (!safe.isEmpty()) { |
| | | emitter.send(SseEmitter.event().data(safe)); |
| | | assistantBuffer.append(s); |
| | | } |
| | | } catch (Exception ignore) {} |
| | | }, () -> { |
| | | try { |
| | | if (finalChatId != null && !finalChatId.isEmpty()) { |
| | | ChatCompletionRequest.Message q = new ChatCompletionRequest.Message(); |
| | | q.setRole("user"); |
| | | q.setContent(finalPrompt == null ? "" : finalPrompt); |
| | | ChatCompletionRequest.Message a = new ChatCompletionRequest.Message(); |
| | | a.setRole("assistant"); |
| | | a.setContent(assistantBuffer.toString()); |
| | | redisUtil.lSet(finalHistoryKey, q); |
| | | redisUtil.lSet(finalHistoryKey, a); |
| | | redisUtil.expire(finalHistoryKey, CHAT_TTL_SECONDS); |
| | | Map<Object, Object> old = redisUtil.hmget(finalMetaKey); |
| | | Long createdAt = old != null && old.get("createdAt") != null ? |
| | | (old.get("createdAt") instanceof Number ? ((Number) old.get("createdAt")).longValue() : Long.valueOf(String.valueOf(old.get("createdAt")))) |
| | | : System.currentTimeMillis(); |
| | | Map<String, Object> meta = new java.util.HashMap<>(); |
| | | meta.put("chatId", finalChatId); |
| | | meta.put("title", buildTitleFromPrompt(finalPrompt)); |
| | | meta.put("createdAt", createdAt); |
| | | meta.put("updatedAt", System.currentTimeMillis()); |
| | | redisUtil.hmset(finalMetaKey, meta, CHAT_TTL_SECONDS); |
| | | } |
| | | emitter.complete(); |
| | | } catch (Exception ignore) {} |
| | | }, e -> { |
| | | try { |
| | | try { emitter.send(SseEmitter.event().data("【AI】运行已停止(异常)")); } catch (Exception ignore) {} |
| | | emitter.complete(); |
| | | } catch (Exception ignore) {} |
| | | }); |
| | | runMcpStreamingDiagnosis(messages, mcpSystem, mcpUser, 0.3, 2048, emitter, finalChatId); |
| | | } |
| | | |
| | | public List<Map<String, Object>> listChats() { |
| | |
| | | return p.length() > 20 ? p.substring(0, 20) : p; |
| | | } |
| | | |
| | | private boolean runMcpStreamingDiagnosis(List<ChatCompletionRequest.Message> baseMessages, |
| | | ChatCompletionRequest.Message systemPrompt, |
| | | ChatCompletionRequest.Message userQuestion, |
| | | Double temperature, |
| | | Integer maxTokens, |
| | | SseEmitter emitter, |
| | | String chatId) { |
| | | private void runMcpStreamingDiagnosis(List<ChatCompletionRequest.Message> baseMessages, |
| | | ChatCompletionRequest.Message systemPrompt, |
| | | ChatCompletionRequest.Message userQuestion, |
| | | Double temperature, |
| | | Integer maxTokens, |
| | | SseEmitter emitter, |
| | | String chatId) { |
| | | try { |
| | | if (mcpController == null) return false; |
| | | List<Object> tools = buildOpenAiTools(); |
| | | if (tools.isEmpty()) return false; |
| | | if (mcpToolManager == null) { |
| | | throw new IllegalStateException("Spring AI MCP tool manager is unavailable"); |
| | | } |
| | | List<Object> tools = mcpToolManager.buildOpenAiTools(); |
| | | if (tools.isEmpty()) { |
| | | throw new IllegalStateException("No MCP tools registered"); |
| | | } |
| | | |
| | | baseMessages.add(systemPrompt); |
| | | baseMessages.add(userQuestion); |
| | |
| | | sse(emitter, "\\n正在分析(第" + (i + 1) + "轮)...\\n"); |
| | | ChatCompletionResponse resp = llmChatService.chatCompletion(messages, temperature, maxTokens, tools); |
| | | if (resp == null || resp.getChoices() == null || resp.getChoices().isEmpty() || resp.getChoices().get(0).getMessage() == null) { |
| | | sse(emitter, "\\n分析出错,正在回退...\\n"); |
| | | return false; |
| | | throw new IllegalStateException("LLM returned empty response"); |
| | | } |
| | | |
| | | ChatCompletionRequest.Message assistant = resp.getChoices().get(0).getMessage(); |
| | |
| | | } |
| | | Object output; |
| | | try { |
| | | output = mcpController.callTool(toolName, args); |
| | | output = mcpToolManager.callTool(toolName, args); |
| | | } catch (Exception e) { |
| | | java.util.LinkedHashMap<String, Object> err = new java.util.LinkedHashMap<String, Object>(); |
| | | err.put("tool", toolName); |
| | |
| | | } |
| | | } catch (Exception ignore) {} |
| | | }, e -> { |
| | | sse(emitter, "\\n\\n【AI】分析出错,正在回退...\\n\\n"); |
| | | try { |
| | | sse(emitter, "\\n\\n【AI】分析出错,运行已停止(异常)\\n\\n"); |
| | | log.error("AI MCP diagnose stopped: stream error", e); |
| | | emitter.complete(); |
| | | } catch (Exception ignore) {} |
| | | }); |
| | | return true; |
| | | } catch (Exception e) { |
| | | try { |
| | | sse(emitter, "\\n\\n【AI】运行已停止(异常)\\n\\n"); |
| | | log.error("AI MCP diagnose stopped: error", e); |
| | | emitter.complete(); |
| | | } catch (Exception ignore) {} |
| | | return true; |
| | | } |
| | | } |
| | | |
| | |
| | | } catch (Exception e) { |
| | | log.warn("SSE send failed", e); |
| | | } |
| | | } |
| | | |
| | | /** |
| | |  * Builds a function-calling "tools" list from the MCP tool descriptors returned by |
| | |  * {@code mcpController.listTools()}. |
| | |  * <p> |
| | |  * Each non-null descriptor that has a {@code "name"} entry becomes a map of the form |
| | |  * {@code {"type": "function", "function": {"name", "description", "parameters"}}}. |
| | |  * {@code "description"} is included only when the descriptor has one, and {@code "parameters"} |
| | |  * is the descriptor's {@code "inputSchema"}, or an empty map when no schema is present. |
| | |  * NOTE(review): the shape appears intended for an OpenAI-style {@code tools} request field — confirm. |
| | |  * NOTE(review): runMcpStreamingDiagnosis in this file also calls mcpToolManager.buildOpenAiTools(); |
| | |  * confirm whether this local helper is still needed. |
| | |  * |
| | |  * @return a new mutable list of tool maps, or an immutable empty list when {@code mcpController} |
| | |  *         is absent or reports no tools |
| | |  */ |
| | | private List<Object> buildOpenAiTools() { |
| | | // mcpController is injected with @Autowired(required = false), so it may be null. |
| | | if (mcpController == null) return java.util.Collections.emptyList(); |
| | | List<Map<String, Object>> mcpTools = mcpController.listTools(); |
| | | if (mcpTools == null || mcpTools.isEmpty()) return java.util.Collections.emptyList(); |
| | | |
| | | List<Object> tools = new ArrayList<>(); |
| | | for (Map<String, Object> t : mcpTools) { |
| | | // Skip null descriptors and descriptors that carry no "name"; they are silently dropped. |
| | | if (t == null) continue; |
| | | Object name = t.get("name"); |
| | | if (name == null) continue; |
| | | Object inputSchema = t.get("inputSchema"); |
| | | java.util.LinkedHashMap<String, Object> function = new java.util.LinkedHashMap<String, Object>(); |
| | | function.put("name", String.valueOf(name)); |
| | | // "description" is optional: only emitted when the descriptor provides one. |
| | | Object desc = t.get("description"); |
| | | if (desc != null) function.put("description", String.valueOf(desc)); |
| | | // Fall back to an empty parameters object when the tool declares no input schema. |
| | | function.put("parameters", inputSchema == null ? new java.util.LinkedHashMap<String, Object>() : inputSchema); |
| | | |
| | | java.util.LinkedHashMap<String, Object> tool = new java.util.LinkedHashMap<String, Object>(); |
| | | tool.put("type", "function"); |
| | | tool.put("function", function); |
| | | tools.add(tool); |
| | | } |
| | | return tools; |
| | | } |
| | | |
| | | private void sendLargeText(SseEmitter emitter, String text) { |