Junjie
2 天以前 63b01db83d9aad8a15276b4236a9a22e4aeef065
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
package com.zy.ai.service.impl;
 
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.zy.ai.entity.AiPromptTemplate;
import com.zy.ai.entity.ChatCompletionRequest;
import com.zy.ai.entity.ChatCompletionResponse;
import com.zy.ai.enums.AiPromptScene;
import com.zy.ai.mcp.service.SpringAiMcpToolManager;
import com.zy.ai.service.AiPromptTemplateService;
import com.zy.ai.service.DataAnalysisAgentService;
import com.zy.ai.service.LlmChatService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
 
import java.time.DayOfWeek;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.util.*;
 
@Slf4j
@Service
@RequiredArgsConstructor
public class DataAnalysisAgentServiceImpl implements DataAnalysisAgentService {
 
    private static final int MAX_TOOL_ROUNDS = 10;
    private static final double TEMPERATURE = 0.3D;
    private static final int MAX_TOKENS = 4096;
    private static final String MCP_STATUS_SUCCESS = "success";
    private static final String MCP_STATUS_FAILED = "failed";
 
    private static final String TOOL_THROUGHPUT = "wcs_local_analysis_query_task_throughput";
    private static final String TOOL_FAULT_SUMMARY = "wcs_local_analysis_query_device_fault_summary";
    private static final String TOOL_UTILIZATION = "wcs_local_analysis_query_device_utilization";
    private static final String TOOL_ERROR_LOGS = "wcs_local_analysis_query_error_logs";
 
    private static final Set<String> ALLOWED_TOOL_NAMES = Set.of(
            TOOL_THROUGHPUT,
            TOOL_FAULT_SUMMARY,
            TOOL_UTILIZATION,
            TOOL_ERROR_LOGS
    );
 
    private final LlmChatService llmChatService;
    private final SpringAiMcpToolManager mcpToolManager;
    private final AiPromptTemplateService aiPromptTemplateService;
 
    @Override
    public DataAnalysisAgentResult runAnalysis(String periodType) {
        String normalizedPeriod = normalizePeriodType(periodType);
        DateRange dateRange = resolveDateRange(normalizedPeriod);
        UsageCounter usageCounter = new UsageCounter();
        List<McpCallResult> mcpCalls = new ArrayList<>();
        boolean maxRoundsReached = false;
        StringBuilder summaryBuffer = new StringBuilder();
        int toolCallCount = 0;
 
        try {
            List<Object> tools = filterAllowedTools(mcpToolManager.buildOpenAiTools());
            if (tools == null || tools.isEmpty()) {
                throw new IllegalStateException("No data analysis MCP tools registered");
            }
 
            AiPromptTemplate promptTemplate = aiPromptTemplateService.resolvePublished(AiPromptScene.DATA_ANALYSIS.getCode());
            List<ChatCompletionRequest.Message> messages = buildMessages(promptTemplate, normalizedPeriod, dateRange);
 
            for (int round = 0; round < MAX_TOOL_ROUNDS; round++) {
                ChatCompletionResponse response = llmChatService.chatCompletionOrThrow(messages, TEMPERATURE, MAX_TOKENS, tools);
                ChatCompletionRequest.Message assistantMessage = extractAssistantMessage(response);
                usageCounter.add(response.getUsage());
                messages.add(assistantMessage);
                appendSummary(summaryBuffer, assistantMessage.getContent());
 
                List<ChatCompletionRequest.ToolCall> toolCalls = assistantMessage.getTool_calls();
                if (toolCalls == null || toolCalls.isEmpty()) {
                    return buildResult(true, normalizedPeriod, summaryBuffer, toolCallCount, usageCounter, false, mcpCalls);
                }
 
                for (ChatCompletionRequest.ToolCall toolCall : toolCalls) {
                    McpCallResult mcpCall = callAnalysisTool(toolCall, mcpCalls);
                    toolCallCount++;
                    Object toolOutput = parseToolOutput(mcpCall);
                    messages.add(buildToolMessage(toolCall, toolOutput));
                }
            }
            maxRoundsReached = true;
            return buildResult(false, normalizedPeriod, summaryBuffer, toolCallCount, usageCounter, maxRoundsReached, mcpCalls);
        } catch (Exception exception) {
            log.error("Data analysis agent stopped with error", exception);
            appendSummary(summaryBuffer, "数据分析 Agent 执行异常: " + exception.getMessage());
            return buildResult(false, normalizedPeriod, summaryBuffer, toolCallCount, usageCounter, maxRoundsReached, mcpCalls);
        }
    }
 
    private McpCallResult callAnalysisTool(ChatCompletionRequest.ToolCall toolCall, List<McpCallResult> mcpCalls) {
        String toolName = resolveToolName(toolCall);
        if (!ALLOWED_TOOL_NAMES.contains(toolName)) {
            throw new IllegalArgumentException("Disallowed data analysis MCP tool: " + toolName);
        }
        JSONObject arguments = parseArguments(toolCall);
        long startTimeMillis = System.currentTimeMillis();
        McpCallResult mcpCall = new McpCallResult();
        mcpCall.setCallSeq(mcpCalls.size() + 1);
        mcpCall.setToolName(toolName);
        mcpCall.setRequestJson(JSON.toJSONString(arguments == null ? new JSONObject() : arguments));
        try {
            Object output = mcpToolManager.callTool(toolName, arguments);
            mcpCall.setDurationMs(Math.max(0L, System.currentTimeMillis() - startTimeMillis));
            mcpCall.setStatus(MCP_STATUS_SUCCESS);
            mcpCall.setResponseJson(JSON.toJSONString(output));
            mcpCalls.add(mcpCall);
            return mcpCall;
        } catch (Exception exception) {
            mcpCall.setDurationMs(Math.max(0L, System.currentTimeMillis() - startTimeMillis));
            mcpCall.setStatus(MCP_STATUS_FAILED);
            mcpCall.setErrorMessage(exception.getMessage());
            mcpCalls.add(mcpCall);
            throw new IllegalStateException("Data analysis MCP tool failed: " + toolName + ", " + exception.getMessage(), exception);
        }
    }
 
    private Object parseToolOutput(McpCallResult mcpCall) {
        if (MCP_STATUS_FAILED.equals(mcpCall.getStatus())) {
            JSONObject err = new JSONObject();
            err.put("error", mcpCall.getErrorMessage());
            return err;
        }
        if (mcpCall.getResponseJson() == null || mcpCall.getResponseJson().isEmpty()) {
            return new JSONObject();
        }
        try {
            return JSON.parse(mcpCall.getResponseJson());
        } catch (Exception e) {
            return mcpCall.getResponseJson();
        }
    }
 
    private List<ChatCompletionRequest.Message> buildMessages(AiPromptTemplate promptTemplate,
                                                              String periodType,
                                                              DateRange dateRange) {
        List<ChatCompletionRequest.Message> messages = new ArrayList<>();
 
        ChatCompletionRequest.Message systemMessage = new ChatCompletionRequest.Message();
        systemMessage.setRole("system");
        systemMessage.setContent(promptTemplate == null ? "" : promptTemplate.getContent());
        messages.add(systemMessage);
 
        ChatCompletionRequest.Message userMessage = new ChatCompletionRequest.Message();
        userMessage.setRole("user");
        userMessage.setContent("请分析" + periodLabel(periodType) + "的WCS运营数据。"
                + "时间范围:startTime=" + dateRange.start + ", endTime=" + dateRange.end
                + "。请依次调用所有分析工具获取数据,然后生成完整的分析报告。");
        messages.add(userMessage);
        return messages;
    }
 
    private String periodLabel(String periodType) {
        switch (periodType) {
            case "TODAY": return "今天";
            case "YESTERDAY": return "昨天";
            case "THIS_WEEK": return "本周";
            case "THIS_MONTH": return "本月";
            default: return periodType;
        }
    }
 
    private DateRange resolveDateRange(String periodType) {
        LocalDate today = LocalDate.now();
        switch (periodType) {
            case "TODAY":
                return new DateRange(today.atStartOfDay(), today.plusDays(1).atStartOfDay());
            case "YESTERDAY":
                return new DateRange(today.minusDays(1).atStartOfDay(), today.atStartOfDay());
            case "THIS_WEEK":
                LocalDate weekStart = today.with(DayOfWeek.MONDAY);
                return new DateRange(weekStart.atStartOfDay(), today.plusDays(1).atStartOfDay());
            case "THIS_MONTH":
                LocalDate monthStart = today.withDayOfMonth(1);
                return new DateRange(monthStart.atStartOfDay(), today.plusDays(1).atStartOfDay());
            default:
                throw new IllegalArgumentException("Unknown period: " + periodType);
        }
    }
 
    private ChatCompletionRequest.Message extractAssistantMessage(ChatCompletionResponse response) {
        if (response == null || response.getChoices() == null || response.getChoices().isEmpty()) {
            throw new IllegalStateException("LLM returned empty response");
        }
        ChatCompletionRequest.Message message = response.getChoices().get(0).getMessage();
        if (message == null) {
            throw new IllegalStateException("LLM returned empty message");
        }
        return message;
    }
 
    private ChatCompletionRequest.Message buildToolMessage(ChatCompletionRequest.ToolCall toolCall, Object toolOutput) {
        ChatCompletionRequest.Message toolMessage = new ChatCompletionRequest.Message();
        toolMessage.setRole("tool");
        toolMessage.setTool_call_id(toolCall == null ? null : toolCall.getId());
        toolMessage.setContent(JSON.toJSONString(toolOutput));
        return toolMessage;
    }
 
    private String resolveToolName(ChatCompletionRequest.ToolCall toolCall) {
        if (toolCall == null || toolCall.getFunction() == null || toolCall.getFunction().getName() == null
                || toolCall.getFunction().getName().trim().isEmpty()) {
            throw new IllegalArgumentException("missing tool name");
        }
        return toolCall.getFunction().getName();
    }
 
    private JSONObject parseArguments(ChatCompletionRequest.ToolCall toolCall) {
        String rawArguments = toolCall == null || toolCall.getFunction() == null
                ? null
                : toolCall.getFunction().getArguments();
        if (rawArguments == null || rawArguments.trim().isEmpty()) {
            return new JSONObject();
        }
        try {
            return JSON.parseObject(rawArguments);
        } catch (Exception exception) {
            JSONObject arguments = new JSONObject();
            arguments.put("_raw", rawArguments);
            return arguments;
        }
    }
 
    private List<Object> filterAllowedTools(List<Object> tools) {
        List<Object> allowedTools = new ArrayList<>();
        if (tools == null || tools.isEmpty()) {
            return allowedTools;
        }
        for (Object tool : tools) {
            String toolName = resolveOpenAiToolName(tool);
            if (ALLOWED_TOOL_NAMES.contains(toolName)) {
                allowedTools.add(tool);
            }
        }
        return allowedTools;
    }
 
    private String resolveOpenAiToolName(Object tool) {
        if (!(tool instanceof Map<?, ?> toolMap)) {
            return null;
        }
        Object function = toolMap.get("function");
        if (!(function instanceof Map<?, ?> functionMap)) {
            return null;
        }
        Object name = functionMap.get("name");
        return name == null ? null : String.valueOf(name);
    }
 
    private DataAnalysisAgentResult buildResult(boolean success,
                                                String periodType,
                                                StringBuilder summaryBuffer,
                                                int toolCallCount,
                                                UsageCounter usageCounter,
                                                boolean maxRoundsReached,
                                                List<McpCallResult> mcpCalls) {
        DataAnalysisAgentResult result = new DataAnalysisAgentResult();
        result.setSuccess(success);
        result.setPeriodType(periodType);
        result.setTriggerType("agent");
        result.setToolCallCount(toolCallCount);
        result.setLlmCallCount(usageCounter.getLlmCallCount());
        result.setPromptTokens(usageCounter.getPromptTokens());
        result.setCompletionTokens(usageCounter.getCompletionTokens());
        result.setTotalTokens(usageCounter.getTotalTokens());
        result.setMaxRoundsReached(maxRoundsReached);
        result.setMcpCalls(mcpCalls != null ? new ArrayList<>(mcpCalls) : new ArrayList<>());
 
        String summary = summaryBuffer == null ? "" : summaryBuffer.toString().trim();
        if (toolCallCount <= 0) {
            summary = "数据分析 Agent 未调用任何分析工具,未生成报告。" + (summary.isEmpty() ? "" : "\n" + summary);
        }
        if (maxRoundsReached) {
            summary = summary + "\n数据分析 Agent 达到最大工具调用轮次,已停止。";
        }
        result.setSummary(summary);
        return result;
    }
 
    private void appendSummary(StringBuilder summaryBuffer, String content) {
        if (summaryBuffer == null || content == null || content.trim().isEmpty()) {
            return;
        }
        if (summaryBuffer.length() > 0) {
            summaryBuffer.append('\n');
        }
        summaryBuffer.append(content.trim());
    }
 
    private String normalizePeriodType(String periodType) {
        if (periodType == null || periodType.trim().isEmpty()) {
            return "YESTERDAY";
        }
        return periodType.trim().toUpperCase();
    }
 
    private static class DateRange {
        final LocalDateTime start;
        final LocalDateTime end;
 
        DateRange(LocalDateTime start, LocalDateTime end) {
            this.start = start;
            this.end = end;
        }
    }
 
    private static class UsageCounter {
        private long promptTokens;
        private long completionTokens;
        private long totalTokens;
        private int llmCallCount;
 
        void add(ChatCompletionResponse.Usage usage) {
            llmCallCount++;
            if (usage == null) {
                return;
            }
            promptTokens += usage.getPromptTokens() == null ? 0L : usage.getPromptTokens();
            completionTokens += usage.getCompletionTokens() == null ? 0L : usage.getCompletionTokens();
            totalTokens += usage.getTotalTokens() == null ? 0L : usage.getTotalTokens();
        }
 
        long getPromptTokens() { return promptTokens; }
        long getCompletionTokens() { return completionTokens; }
        long getTotalTokens() { return totalTokens; }
        int getLlmCallCount() { return llmCallCount; }
    }
}