package com.zy.ai.service.impl;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.zy.ai.entity.AiPromptTemplate;
import com.zy.ai.entity.ChatCompletionRequest;
import com.zy.ai.entity.ChatCompletionResponse;
import com.zy.ai.enums.AiPromptScene;
import com.zy.ai.mcp.service.SpringAiMcpToolManager;
import com.zy.ai.service.AiPromptTemplateService;
import com.zy.ai.service.DataAnalysisAgentService;
import com.zy.ai.service.LlmChatService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;

import java.time.DayOfWeek;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.util.*;

/**
 * Tool-calling LLM agent that produces a WCS operations analysis report for a reporting period.
 *
 * <p>The conversation starts with the published {@link AiPromptScene#DATA_ANALYSIS} prompt
 * template as the system message and a user request naming the period and its time range. The
 * agent then loops:
 * <ul>
 *   <li>each assistant reply is appended to the conversation and its text to the summary;</li>
 *   <li>any tool calls the reply requests are executed through the MCP tool manager, restricted
 *       to {@link #ALLOWED_TOOL_NAMES};</li>
 *   <li>tool outputs are sent back as {@code tool} messages.</li>
 * </ul>
 * The loop ends when the model replies without tool calls, or after {@link #MAX_TOOL_ROUNDS}
 * rounds.
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class DataAnalysisAgentServiceImpl implements DataAnalysisAgentService {

    /** Upper bound on LLM round trips in one analysis run. */
    private static final int MAX_TOOL_ROUNDS = 10;
    private static final double TEMPERATURE = 0.3D;
    private static final int MAX_TOKENS = 4096;

    private static final String MCP_STATUS_SUCCESS = "success";
    private static final String MCP_STATUS_FAILED = "failed";

    private static final String TOOL_THROUGHPUT = "wcs_local_analysis_query_task_throughput";
    private static final String TOOL_FAULT_SUMMARY = "wcs_local_analysis_query_device_fault_summary";
    private static final String TOOL_UTILIZATION = "wcs_local_analysis_query_device_utilization";
    private static final String TOOL_ERROR_LOGS = "wcs_local_analysis_query_error_logs";

    /**
     * Tools offered to the model (see {@link #filterAllowedTools}) and the only ones it may
     * invoke (see {@link #callAnalysisTool}).
     */
    private static final Set<String> ALLOWED_TOOL_NAMES = Set.of(
            TOOL_THROUGHPUT,
            TOOL_FAULT_SUMMARY,
            TOOL_UTILIZATION,
            TOOL_ERROR_LOGS
    );

    private final LlmChatService llmChatService;
    private final SpringAiMcpToolManager mcpToolManager;
    private final AiPromptTemplateService aiPromptTemplateService;

    /**
     * Runs one analysis for the given period.
     *
     * @param periodType one of {@code TODAY}, {@code YESTERDAY}, {@code THIS_WEEK},
     *     {@code THIS_MONTH}. Matching is case-insensitive and surrounding whitespace is
     *     ignored. A blank or {@code null} value means {@code YESTERDAY}.
     * @return the run result. {@code success} is true only when the model finished on its own,
     *     i.e. its last reply contained no tool calls, within {@link #MAX_TOOL_ROUNDS} rounds.
     *     Errors raised inside the agent loop are logged and reported in the summary rather
     *     than thrown.
     * @throws IllegalArgumentException if {@code periodType} is not a recognised period
     */
    @Override
    public DataAnalysisAgentResult runAnalysis(String periodType) {
        String normalizedPeriod = normalizePeriodType(periodType);
        // Resolved outside the try block: an unknown period is a caller error and propagates.
        DateRange dateRange = resolveDateRange(normalizedPeriod);
        UsageCounter usageCounter = new UsageCounter();
        List<McpCallResult> mcpCalls = new ArrayList<>();
        StringBuilder summaryBuffer = new StringBuilder();
        int toolCallCount = 0;
        try {
            List<Object> tools = filterAllowedTools(mcpToolManager.buildOpenAiTools());
            if (tools.isEmpty()) {
                throw new IllegalStateException("No data analysis MCP tools registered");
            }
            AiPromptTemplate promptTemplate =
                    aiPromptTemplateService.resolvePublished(AiPromptScene.DATA_ANALYSIS.getCode());
            List<ChatCompletionRequest.Message> messages =
                    buildMessages(promptTemplate, normalizedPeriod, dateRange);

            for (int round = 0; round < MAX_TOOL_ROUNDS; round++) {
                ChatCompletionResponse response =
                        llmChatService.chatCompletionOrThrow(messages, TEMPERATURE, MAX_TOKENS, tools);
                // Record usage before validating the reply, so tokens spent on an empty reply
                // are still counted.
                if (response != null) {
                    usageCounter.add(response.getUsage());
                }
                ChatCompletionRequest.Message assistantMessage = extractAssistantMessage(response);
                messages.add(assistantMessage);
                appendSummary(summaryBuffer, assistantMessage.getContent());

                List<ChatCompletionRequest.ToolCall> toolCalls = assistantMessage.getTool_calls();
                if (toolCalls == null || toolCalls.isEmpty()) {
                    return buildResult(true, normalizedPeriod, summaryBuffer, toolCallCount,
                            usageCounter, false, mcpCalls);
                }
                // Every tool call gets a tool message (result or error), which keeps the
                // conversation well-formed for the next round.
                for (ChatCompletionRequest.ToolCall toolCall : toolCalls) {
                    McpCallResult mcpCall = callAnalysisTool(toolCall, mcpCalls);
                    toolCallCount++;
                    messages.add(buildToolMessage(toolCall, parseToolOutput(mcpCall)));
                }
            }
            return buildResult(false, normalizedPeriod, summaryBuffer, toolCallCount,
                    usageCounter, true, mcpCalls);
        } catch (Exception exception) {
            log.error("Data analysis agent stopped with error", exception);
            appendSummary(summaryBuffer, "数据分析 Agent 执行异常: " + describe(exception));
            return buildResult(false, normalizedPeriod, summaryBuffer, toolCallCount,
                    usageCounter, false, mcpCalls);
        }
    }

    /**
     * Executes one model-requested tool call and appends its record to {@code mcpCalls}.
     *
     * <p>A failure inside the MCP tool itself is not fatal. It is recorded with status
     * {@link #MCP_STATUS_FAILED} and returned. {@link #parseToolOutput} then reports the error
     * back to the model, and the remaining tool calls still run.
     *
     * @param toolCall the tool call taken from the assistant message
     * @param mcpCalls the running call log; its size determines the new call's sequence number
     * @return the recorded call, with status, duration and either response or error message
     * @throws IllegalArgumentException if the call has no tool name, or names a tool outside
     *     {@link #ALLOWED_TOOL_NAMES}. Such calls are neither executed nor recorded.
     */
    private McpCallResult callAnalysisTool(ChatCompletionRequest.ToolCall toolCall,
                                           List<McpCallResult> mcpCalls) {
        String toolName = resolveToolName(toolCall);
        if (!ALLOWED_TOOL_NAMES.contains(toolName)) {
            throw new IllegalArgumentException("Disallowed data analysis MCP tool: " + toolName);
        }
        JSONObject arguments = parseArguments(toolCall);

        McpCallResult mcpCall = new McpCallResult();
        mcpCall.setCallSeq(mcpCalls.size() + 1);
        mcpCall.setToolName(toolName);
        mcpCall.setRequestJson(JSON.toJSONString(arguments));

        long startTimeMillis = System.currentTimeMillis();
        try {
            Object output = mcpToolManager.callTool(toolName, arguments);
            mcpCall.setStatus(MCP_STATUS_SUCCESS);
            mcpCall.setResponseJson(JSON.toJSONString(output));
        } catch (Exception exception) {
            log.warn("Data analysis MCP tool failed: {}", toolName, exception);
            mcpCall.setStatus(MCP_STATUS_FAILED);
            mcpCall.setErrorMessage(describe(exception));
        } finally {
            // Clamp to zero in case the wall clock moves backwards during the call.
            mcpCall.setDurationMs(Math.max(0L, System.currentTimeMillis() - startTimeMillis));
        }
        mcpCalls.add(mcpCall);
        return mcpCall;
    }

    /**
     * Converts a recorded tool call into the payload sent back to the model.
     *
     * @return one of:
     *     <ul>
     *       <li>{@code {"error": ...}} for a failed call;</li>
     *       <li>an empty object when there is no response;</li>
     *       <li>the parsed JSON response;</li>
     *       <li>the raw response string when it is not valid JSON.</li>
     *     </ul>
     */
    private Object parseToolOutput(McpCallResult mcpCall) {
        if (MCP_STATUS_FAILED.equals(mcpCall.getStatus())) {
            JSONObject err = new JSONObject();
            err.put("error", mcpCall.getErrorMessage());
            return err;
        }
        String responseJson = mcpCall.getResponseJson();
        if (responseJson == null || responseJson.isEmpty()) {
            return new JSONObject();
        }
        try {
            return JSON.parse(responseJson);
        } catch (Exception e) {
            // Not JSON: hand the text to the model unchanged.
            return responseJson;
        }
    }

    /**
     * Builds the initial conversation: the system prompt, then a user message with the period
     * label and its startTime/endTime.
     *
     * @param promptTemplate the published template; {@code null} yields an empty system prompt
     */
    private List<ChatCompletionRequest.Message> buildMessages(AiPromptTemplate promptTemplate,
                                                              String periodType,
                                                              DateRange dateRange) {
        List<ChatCompletionRequest.Message> messages = new ArrayList<>();

        ChatCompletionRequest.Message systemMessage = new ChatCompletionRequest.Message();
        systemMessage.setRole("system");
        systemMessage.setContent(promptTemplate == null ? "" : promptTemplate.getContent());
        messages.add(systemMessage);

        ChatCompletionRequest.Message userMessage = new ChatCompletionRequest.Message();
        userMessage.setRole("user");
        userMessage.setContent("请分析" + periodLabel(periodType) + "的WCS运营数据。"
                + "时间范围:startTime=" + dateRange.start() + ", endTime=" + dateRange.end()
                + "。请依次调用所有分析工具获取数据,然后生成完整的分析报告。");
        messages.add(userMessage);
        return messages;
    }

    /** Chinese display label for a normalized period; unknown periods are returned unchanged. */
    private String periodLabel(String periodType) {
        return switch (periodType) {
            case "TODAY" -> "今天";
            case "YESTERDAY" -> "昨天";
            case "THIS_WEEK" -> "本周";
            case "THIS_MONTH" -> "本月";
            default -> periodType;
        };
    }

    /**
     * Computes the time window for a normalized period, based on {@link LocalDate#now()}.
     *
     * <p>Every window starts at a midnight. It ends at the midnight that begins the day after
     * its last day, so current-day periods run up to tomorrow's midnight. Weeks start on Monday.
     * NOTE(review): uses the JVM default time zone, and assumes the analysis tools treat
     * {@code end} as exclusive — confirm.
     *
     * @throws IllegalArgumentException for an unrecognised period
     */
    private DateRange resolveDateRange(String periodType) {
        LocalDate today = LocalDate.now();
        LocalDateTime tomorrowStart = today.plusDays(1).atStartOfDay();
        return switch (periodType) {
            case "TODAY" -> new DateRange(today.atStartOfDay(), tomorrowStart);
            case "YESTERDAY" -> new DateRange(today.minusDays(1).atStartOfDay(), today.atStartOfDay());
            case "THIS_WEEK" -> new DateRange(today.with(DayOfWeek.MONDAY).atStartOfDay(), tomorrowStart);
            case "THIS_MONTH" -> new DateRange(today.withDayOfMonth(1).atStartOfDay(), tomorrowStart);
            default -> throw new IllegalArgumentException("Unknown period: " + periodType);
        };
    }

    /**
     * Returns the first choice's message.
     *
     * @throws IllegalStateException if the response, its choices, or the message is missing
     */
    private ChatCompletionRequest.Message extractAssistantMessage(ChatCompletionResponse response) {
        if (response == null || response.getChoices() == null || response.getChoices().isEmpty()) {
            throw new IllegalStateException("LLM returned empty response");
        }
        ChatCompletionRequest.Message message = response.getChoices().get(0).getMessage();
        if (message == null) {
            throw new IllegalStateException("LLM returned empty message");
        }
        return message;
    }

    /** Wraps a tool output as a {@code tool} message answering the given tool call id. */
    private ChatCompletionRequest.Message buildToolMessage(ChatCompletionRequest.ToolCall toolCall,
                                                           Object toolOutput) {
        ChatCompletionRequest.Message toolMessage = new ChatCompletionRequest.Message();
        toolMessage.setRole("tool");
        toolMessage.setTool_call_id(toolCall == null ? null : toolCall.getId());
        toolMessage.setContent(JSON.toJSONString(toolOutput));
        return toolMessage;
    }

    /**
     * Returns the trimmed function name of a tool call.
     *
     * @throws IllegalArgumentException if the call or its function name is missing or blank
     */
    private String resolveToolName(ChatCompletionRequest.ToolCall toolCall) {
        String name = toolCall == null || toolCall.getFunction() == null
                ? null
                : toolCall.getFunction().getName();
        if (name == null || name.trim().isEmpty()) {
            throw new IllegalArgumentException("missing tool name");
        }
        // Trim so the allow-list check and the MCP call see the same name that was validated.
        return name.trim();
    }

    /**
     * Parses the tool call's JSON arguments. The result is never {@code null}:
     * <ul>
     *   <li>missing, blank or JSON {@code null} arguments give an empty object;</li>
     *   <li>arguments that are not a JSON object are wrapped as {@code {"_raw": <text>}}.</li>
     * </ul>
     */
    private JSONObject parseArguments(ChatCompletionRequest.ToolCall toolCall) {
        String rawArguments = toolCall == null || toolCall.getFunction() == null
                ? null
                : toolCall.getFunction().getArguments();
        if (rawArguments == null || rawArguments.trim().isEmpty()) {
            return new JSONObject();
        }
        try {
            JSONObject parsed = JSON.parseObject(rawArguments);
            return parsed == null ? new JSONObject() : parsed;
        } catch (Exception exception) {
            JSONObject arguments = new JSONObject();
            arguments.put("_raw", rawArguments);
            return arguments;
        }
    }

    /**
     * Keeps only the OpenAI-format tool definitions whose {@code function.name} is in
     * {@link #ALLOWED_TOOL_NAMES}.
     *
     * @param tools tool definitions, possibly {@code null}
     * @return a new, never-null list
     */
    private List<Object> filterAllowedTools(List<?> tools) {
        List<Object> allowedTools = new ArrayList<>();
        if (tools == null) {
            return allowedTools;
        }
        for (Object tool : tools) {
            String toolName = resolveOpenAiToolName(tool);
            if (toolName != null && ALLOWED_TOOL_NAMES.contains(toolName)) {
                allowedTools.add(tool);
            }
        }
        return allowedTools;
    }

    /** Reads {@code function.name} from a map-shaped tool definition; {@code null} if absent. */
    private String resolveOpenAiToolName(Object tool) {
        if (!(tool instanceof Map<?, ?> toolMap)) {
            return null;
        }
        Object function = toolMap.get("function");
        if (!(function instanceof Map<?, ?> functionMap)) {
            return null;
        }
        Object name = functionMap.get("name");
        return name == null ? null : String.valueOf(name);
    }

    /**
     * Assembles the run result.
     *
     * <p>The summary is the accumulated assistant text, with two additions:
     * <ul>
     *   <li>a "no tool called" notice is prepended when {@code toolCallCount <= 0};</li>
     *   <li>a "max rounds reached" notice is appended when {@code maxRoundsReached}.</li>
     * </ul>
     * NOTE(review): the "no tool called" notice says no report was generated, yet
     * {@code success} is left as passed in — confirm this is intended.
     *
     * @param toolCallCount number of tool calls executed, successful or failed
     */
    private DataAnalysisAgentResult buildResult(boolean success, String periodType,
                                                StringBuilder summaryBuffer, int toolCallCount,
                                                UsageCounter usageCounter, boolean maxRoundsReached,
                                                List<McpCallResult> mcpCalls) {
        DataAnalysisAgentResult result = new DataAnalysisAgentResult();
        result.setSuccess(success);
        result.setPeriodType(periodType);
        result.setTriggerType("agent");
        result.setToolCallCount(toolCallCount);
        result.setLlmCallCount(usageCounter.getLlmCallCount());
        result.setPromptTokens(usageCounter.getPromptTokens());
        result.setCompletionTokens(usageCounter.getCompletionTokens());
        result.setTotalTokens(usageCounter.getTotalTokens());
        result.setMaxRoundsReached(maxRoundsReached);
        // Defensive copy so later changes to the working list don't leak into the result.
        result.setMcpCalls(mcpCalls != null ? new ArrayList<>(mcpCalls) : new ArrayList<>());

        String summary = summaryBuffer == null ? "" : summaryBuffer.toString().trim();
        if (toolCallCount <= 0) {
            summary = "数据分析 Agent 未调用任何分析工具,未生成报告。" + (summary.isEmpty() ? "" : "\n" + summary);
        }
        if (maxRoundsReached) {
            String notice = "数据分析 Agent 达到最大工具调用轮次,已停止。";
            // Avoid a leading newline when the model produced no text at all.
            summary = summary.isEmpty() ? notice : summary + "\n" + notice;
        }
        result.setSummary(summary);
        return result;
    }

    /** Appends trimmed, non-blank {@code content} to the buffer, newline-separated. */
    private void appendSummary(StringBuilder summaryBuffer, String content) {
        if (summaryBuffer == null || content == null || content.trim().isEmpty()) {
            return;
        }
        if (summaryBuffer.length() > 0) {
            summaryBuffer.append('\n');
        }
        summaryBuffer.append(content.trim());
    }

    /**
     * Trims and upper-cases the period. Blank or {@code null} defaults to {@code YESTERDAY}.
     * {@link Locale#ROOT} is used so the result does not depend on the JVM locale; in a
     * Turkish locale, for example, {@code "this_week"} would otherwise become
     * {@code "THİS_WEEK"}.
     */
    private String normalizePeriodType(String periodType) {
        if (periodType == null || periodType.trim().isEmpty()) {
            return "YESTERDAY";
        }
        return periodType.trim().toUpperCase(Locale.ROOT);
    }

    /** Exception message for reports, falling back to the class name when it has none. */
    private static String describe(Exception exception) {
        String message = exception.getMessage();
        return message == null || message.trim().isEmpty()
                ? exception.getClass().getSimpleName()
                : message;
    }

    /** Time window passed to the model as {@code startTime}/{@code endTime}. */
    private record DateRange(LocalDateTime start, LocalDateTime end) {
    }

    /** Accumulates LLM call count and token usage across the rounds of one run. */
    private static final class UsageCounter {
        private long promptTokens;
        private long completionTokens;
        private long totalTokens;
        private int llmCallCount;

        /** Counts one LLM call; {@code null} usage or {@code null} fields add zero tokens. */
        void add(ChatCompletionResponse.Usage usage) {
            llmCallCount++;
            if (usage == null) {
                return;
            }
            promptTokens += usage.getPromptTokens() == null ? 0L : usage.getPromptTokens();
            completionTokens += usage.getCompletionTokens() == null ? 0L : usage.getCompletionTokens();
            totalTokens += usage.getTotalTokens() == null ? 0L : usage.getTotalTokens();
        }

        long getPromptTokens() {
            return promptTokens;
        }

        long getCompletionTokens() {
            return completionTokens;
        }

        long getTotalTokens() {
            return totalTokens;
        }

        int getLlmCallCount() {
            return llmCallCount;
        }
    }
}