package com.zy.ai.service.impl;
|
|
import com.alibaba.fastjson.JSON;
|
import com.alibaba.fastjson.JSONObject;
|
import com.zy.ai.entity.AiPromptTemplate;
|
import com.zy.ai.entity.ChatCompletionRequest;
|
import com.zy.ai.entity.ChatCompletionResponse;
|
import com.zy.ai.enums.AiPromptScene;
|
import com.zy.ai.mcp.service.SpringAiMcpToolManager;
|
import com.zy.ai.service.AiPromptTemplateService;
|
import com.zy.ai.service.DataAnalysisAgentService;
|
import com.zy.ai.service.LlmChatService;
|
import lombok.RequiredArgsConstructor;
|
import lombok.extern.slf4j.Slf4j;
|
import org.springframework.stereotype.Service;
|
|
import java.time.DayOfWeek;
|
import java.time.LocalDate;
|
import java.time.LocalDateTime;
|
import java.time.LocalTime;
|
import java.util.*;
|
|
@Slf4j
|
@Service
|
@RequiredArgsConstructor
|
public class DataAnalysisAgentServiceImpl implements DataAnalysisAgentService {
|
|
private static final int MAX_TOOL_ROUNDS = 10;
|
private static final double TEMPERATURE = 0.3D;
|
private static final int MAX_TOKENS = 4096;
|
private static final String MCP_STATUS_SUCCESS = "success";
|
private static final String MCP_STATUS_FAILED = "failed";
|
|
private static final String TOOL_THROUGHPUT = "wcs_local_analysis_query_task_throughput";
|
private static final String TOOL_FAULT_SUMMARY = "wcs_local_analysis_query_device_fault_summary";
|
private static final String TOOL_UTILIZATION = "wcs_local_analysis_query_device_utilization";
|
private static final String TOOL_ERROR_LOGS = "wcs_local_analysis_query_error_logs";
|
|
private static final Set<String> ALLOWED_TOOL_NAMES = Set.of(
|
TOOL_THROUGHPUT,
|
TOOL_FAULT_SUMMARY,
|
TOOL_UTILIZATION,
|
TOOL_ERROR_LOGS
|
);
|
|
private final LlmChatService llmChatService;
|
private final SpringAiMcpToolManager mcpToolManager;
|
private final AiPromptTemplateService aiPromptTemplateService;
|
|
@Override
|
public DataAnalysisAgentResult runAnalysis(String periodType) {
|
String normalizedPeriod = normalizePeriodType(periodType);
|
DateRange dateRange = resolveDateRange(normalizedPeriod);
|
UsageCounter usageCounter = new UsageCounter();
|
List<McpCallResult> mcpCalls = new ArrayList<>();
|
boolean maxRoundsReached = false;
|
StringBuilder summaryBuffer = new StringBuilder();
|
int toolCallCount = 0;
|
|
try {
|
List<Object> tools = filterAllowedTools(mcpToolManager.buildOpenAiTools());
|
if (tools == null || tools.isEmpty()) {
|
throw new IllegalStateException("No data analysis MCP tools registered");
|
}
|
|
AiPromptTemplate promptTemplate = aiPromptTemplateService.resolvePublished(AiPromptScene.DATA_ANALYSIS.getCode());
|
List<ChatCompletionRequest.Message> messages = buildMessages(promptTemplate, normalizedPeriod, dateRange);
|
|
for (int round = 0; round < MAX_TOOL_ROUNDS; round++) {
|
ChatCompletionResponse response = llmChatService.chatCompletionOrThrow(messages, TEMPERATURE, MAX_TOKENS, tools);
|
ChatCompletionRequest.Message assistantMessage = extractAssistantMessage(response);
|
usageCounter.add(response.getUsage());
|
messages.add(assistantMessage);
|
appendSummary(summaryBuffer, assistantMessage.getContent());
|
|
List<ChatCompletionRequest.ToolCall> toolCalls = assistantMessage.getTool_calls();
|
if (toolCalls == null || toolCalls.isEmpty()) {
|
return buildResult(true, normalizedPeriod, summaryBuffer, toolCallCount, usageCounter, false, mcpCalls);
|
}
|
|
for (ChatCompletionRequest.ToolCall toolCall : toolCalls) {
|
McpCallResult mcpCall = callAnalysisTool(toolCall, mcpCalls);
|
toolCallCount++;
|
Object toolOutput = parseToolOutput(mcpCall);
|
messages.add(buildToolMessage(toolCall, toolOutput));
|
}
|
}
|
maxRoundsReached = true;
|
return buildResult(false, normalizedPeriod, summaryBuffer, toolCallCount, usageCounter, maxRoundsReached, mcpCalls);
|
} catch (Exception exception) {
|
log.error("Data analysis agent stopped with error", exception);
|
appendSummary(summaryBuffer, "数据分析 Agent 执行异常: " + exception.getMessage());
|
return buildResult(false, normalizedPeriod, summaryBuffer, toolCallCount, usageCounter, maxRoundsReached, mcpCalls);
|
}
|
}
|
|
private McpCallResult callAnalysisTool(ChatCompletionRequest.ToolCall toolCall, List<McpCallResult> mcpCalls) {
|
String toolName = resolveToolName(toolCall);
|
if (!ALLOWED_TOOL_NAMES.contains(toolName)) {
|
throw new IllegalArgumentException("Disallowed data analysis MCP tool: " + toolName);
|
}
|
JSONObject arguments = parseArguments(toolCall);
|
long startTimeMillis = System.currentTimeMillis();
|
McpCallResult mcpCall = new McpCallResult();
|
mcpCall.setCallSeq(mcpCalls.size() + 1);
|
mcpCall.setToolName(toolName);
|
mcpCall.setRequestJson(JSON.toJSONString(arguments == null ? new JSONObject() : arguments));
|
try {
|
Object output = mcpToolManager.callTool(toolName, arguments);
|
mcpCall.setDurationMs(Math.max(0L, System.currentTimeMillis() - startTimeMillis));
|
mcpCall.setStatus(MCP_STATUS_SUCCESS);
|
mcpCall.setResponseJson(JSON.toJSONString(output));
|
mcpCalls.add(mcpCall);
|
return mcpCall;
|
} catch (Exception exception) {
|
mcpCall.setDurationMs(Math.max(0L, System.currentTimeMillis() - startTimeMillis));
|
mcpCall.setStatus(MCP_STATUS_FAILED);
|
mcpCall.setErrorMessage(exception.getMessage());
|
mcpCalls.add(mcpCall);
|
throw new IllegalStateException("Data analysis MCP tool failed: " + toolName + ", " + exception.getMessage(), exception);
|
}
|
}
|
|
private Object parseToolOutput(McpCallResult mcpCall) {
|
if (MCP_STATUS_FAILED.equals(mcpCall.getStatus())) {
|
JSONObject err = new JSONObject();
|
err.put("error", mcpCall.getErrorMessage());
|
return err;
|
}
|
if (mcpCall.getResponseJson() == null || mcpCall.getResponseJson().isEmpty()) {
|
return new JSONObject();
|
}
|
try {
|
return JSON.parse(mcpCall.getResponseJson());
|
} catch (Exception e) {
|
return mcpCall.getResponseJson();
|
}
|
}
|
|
private List<ChatCompletionRequest.Message> buildMessages(AiPromptTemplate promptTemplate,
|
String periodType,
|
DateRange dateRange) {
|
List<ChatCompletionRequest.Message> messages = new ArrayList<>();
|
|
ChatCompletionRequest.Message systemMessage = new ChatCompletionRequest.Message();
|
systemMessage.setRole("system");
|
systemMessage.setContent(promptTemplate == null ? "" : promptTemplate.getContent());
|
messages.add(systemMessage);
|
|
ChatCompletionRequest.Message userMessage = new ChatCompletionRequest.Message();
|
userMessage.setRole("user");
|
userMessage.setContent("请分析" + periodLabel(periodType) + "的WCS运营数据。"
|
+ "时间范围:startTime=" + dateRange.start + ", endTime=" + dateRange.end
|
+ "。请依次调用所有分析工具获取数据,然后生成完整的分析报告。");
|
messages.add(userMessage);
|
return messages;
|
}
|
|
private String periodLabel(String periodType) {
|
switch (periodType) {
|
case "TODAY": return "今天";
|
case "YESTERDAY": return "昨天";
|
case "THIS_WEEK": return "本周";
|
case "THIS_MONTH": return "本月";
|
default: return periodType;
|
}
|
}
|
|
private DateRange resolveDateRange(String periodType) {
|
LocalDate today = LocalDate.now();
|
switch (periodType) {
|
case "TODAY":
|
return new DateRange(today.atStartOfDay(), today.plusDays(1).atStartOfDay());
|
case "YESTERDAY":
|
return new DateRange(today.minusDays(1).atStartOfDay(), today.atStartOfDay());
|
case "THIS_WEEK":
|
LocalDate weekStart = today.with(DayOfWeek.MONDAY);
|
return new DateRange(weekStart.atStartOfDay(), today.plusDays(1).atStartOfDay());
|
case "THIS_MONTH":
|
LocalDate monthStart = today.withDayOfMonth(1);
|
return new DateRange(monthStart.atStartOfDay(), today.plusDays(1).atStartOfDay());
|
default:
|
throw new IllegalArgumentException("Unknown period: " + periodType);
|
}
|
}
|
|
private ChatCompletionRequest.Message extractAssistantMessage(ChatCompletionResponse response) {
|
if (response == null || response.getChoices() == null || response.getChoices().isEmpty()) {
|
throw new IllegalStateException("LLM returned empty response");
|
}
|
ChatCompletionRequest.Message message = response.getChoices().get(0).getMessage();
|
if (message == null) {
|
throw new IllegalStateException("LLM returned empty message");
|
}
|
return message;
|
}
|
|
private ChatCompletionRequest.Message buildToolMessage(ChatCompletionRequest.ToolCall toolCall, Object toolOutput) {
|
ChatCompletionRequest.Message toolMessage = new ChatCompletionRequest.Message();
|
toolMessage.setRole("tool");
|
toolMessage.setTool_call_id(toolCall == null ? null : toolCall.getId());
|
toolMessage.setContent(JSON.toJSONString(toolOutput));
|
return toolMessage;
|
}
|
|
private String resolveToolName(ChatCompletionRequest.ToolCall toolCall) {
|
if (toolCall == null || toolCall.getFunction() == null || toolCall.getFunction().getName() == null
|
|| toolCall.getFunction().getName().trim().isEmpty()) {
|
throw new IllegalArgumentException("missing tool name");
|
}
|
return toolCall.getFunction().getName();
|
}
|
|
private JSONObject parseArguments(ChatCompletionRequest.ToolCall toolCall) {
|
String rawArguments = toolCall == null || toolCall.getFunction() == null
|
? null
|
: toolCall.getFunction().getArguments();
|
if (rawArguments == null || rawArguments.trim().isEmpty()) {
|
return new JSONObject();
|
}
|
try {
|
return JSON.parseObject(rawArguments);
|
} catch (Exception exception) {
|
JSONObject arguments = new JSONObject();
|
arguments.put("_raw", rawArguments);
|
return arguments;
|
}
|
}
|
|
private List<Object> filterAllowedTools(List<Object> tools) {
|
List<Object> allowedTools = new ArrayList<>();
|
if (tools == null || tools.isEmpty()) {
|
return allowedTools;
|
}
|
for (Object tool : tools) {
|
String toolName = resolveOpenAiToolName(tool);
|
if (ALLOWED_TOOL_NAMES.contains(toolName)) {
|
allowedTools.add(tool);
|
}
|
}
|
return allowedTools;
|
}
|
|
private String resolveOpenAiToolName(Object tool) {
|
if (!(tool instanceof Map<?, ?> toolMap)) {
|
return null;
|
}
|
Object function = toolMap.get("function");
|
if (!(function instanceof Map<?, ?> functionMap)) {
|
return null;
|
}
|
Object name = functionMap.get("name");
|
return name == null ? null : String.valueOf(name);
|
}
|
|
private DataAnalysisAgentResult buildResult(boolean success,
|
String periodType,
|
StringBuilder summaryBuffer,
|
int toolCallCount,
|
UsageCounter usageCounter,
|
boolean maxRoundsReached,
|
List<McpCallResult> mcpCalls) {
|
DataAnalysisAgentResult result = new DataAnalysisAgentResult();
|
result.setSuccess(success);
|
result.setPeriodType(periodType);
|
result.setTriggerType("agent");
|
result.setToolCallCount(toolCallCount);
|
result.setLlmCallCount(usageCounter.getLlmCallCount());
|
result.setPromptTokens(usageCounter.getPromptTokens());
|
result.setCompletionTokens(usageCounter.getCompletionTokens());
|
result.setTotalTokens(usageCounter.getTotalTokens());
|
result.setMaxRoundsReached(maxRoundsReached);
|
result.setMcpCalls(mcpCalls != null ? new ArrayList<>(mcpCalls) : new ArrayList<>());
|
|
String summary = summaryBuffer == null ? "" : summaryBuffer.toString().trim();
|
if (toolCallCount <= 0) {
|
summary = "数据分析 Agent 未调用任何分析工具,未生成报告。" + (summary.isEmpty() ? "" : "\n" + summary);
|
}
|
if (maxRoundsReached) {
|
summary = summary + "\n数据分析 Agent 达到最大工具调用轮次,已停止。";
|
}
|
result.setSummary(summary);
|
return result;
|
}
|
|
private void appendSummary(StringBuilder summaryBuffer, String content) {
|
if (summaryBuffer == null || content == null || content.trim().isEmpty()) {
|
return;
|
}
|
if (summaryBuffer.length() > 0) {
|
summaryBuffer.append('\n');
|
}
|
summaryBuffer.append(content.trim());
|
}
|
|
private String normalizePeriodType(String periodType) {
|
if (periodType == null || periodType.trim().isEmpty()) {
|
return "YESTERDAY";
|
}
|
return periodType.trim().toUpperCase();
|
}
|
|
private static class DateRange {
|
final LocalDateTime start;
|
final LocalDateTime end;
|
|
DateRange(LocalDateTime start, LocalDateTime end) {
|
this.start = start;
|
this.end = end;
|
}
|
}
|
|
private static class UsageCounter {
|
private long promptTokens;
|
private long completionTokens;
|
private long totalTokens;
|
private int llmCallCount;
|
|
void add(ChatCompletionResponse.Usage usage) {
|
llmCallCount++;
|
if (usage == null) {
|
return;
|
}
|
promptTokens += usage.getPromptTokens() == null ? 0L : usage.getPromptTokens();
|
completionTokens += usage.getCompletionTokens() == null ? 0L : usage.getCompletionTokens();
|
totalTokens += usage.getTotalTokens() == null ? 0L : usage.getTotalTokens();
|
}
|
|
long getPromptTokens() { return promptTokens; }
|
long getCompletionTokens() { return completionTokens; }
|
long getTotalTokens() { return totalTokens; }
|
int getLlmCallCount() { return llmCallCount; }
|
}
|
}
|