Junjie
2026-04-27 cd04aa8b887e82ec664e42f0bc353c079be1d2c5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
package com.zy.ai.service.impl;
 
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.zy.ai.domain.autotune.AutoTuneTriggerType;
import com.zy.ai.entity.AiPromptTemplate;
import com.zy.ai.entity.ChatCompletionRequest;
import com.zy.ai.entity.ChatCompletionResponse;
import com.zy.ai.enums.AiPromptScene;
import com.zy.ai.mcp.service.SpringAiMcpToolManager;
import com.zy.ai.service.AiPromptTemplateService;
import com.zy.ai.service.AutoTuneAgentService;
import com.zy.ai.service.LlmChatService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
 
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
 
@Slf4j
@Service
@RequiredArgsConstructor
public class AutoTuneAgentServiceImpl implements AutoTuneAgentService {
 
    private static final int MAX_TOOL_ROUNDS = 10;
    private static final double TEMPERATURE = 0.2D;
    private static final int MAX_TOKENS = 2048;
    private static final String TOOL_GET_SNAPSHOT = "wcs_local_dispatch_get_auto_tune_snapshot";
    private static final String TOOL_GET_RECENT_JOBS = "wcs_local_dispatch_get_recent_auto_tune_jobs";
    private static final String TOOL_APPLY_CHANGES = "wcs_local_dispatch_apply_auto_tune_changes";
    private static final String TOOL_REVERT_LAST_JOB = "wcs_local_dispatch_revert_last_auto_tune_job";
    private static final Set<String> ALLOWED_TOOL_NAMES = Set.of(
            TOOL_GET_SNAPSHOT,
            TOOL_GET_RECENT_JOBS,
            TOOL_APPLY_CHANGES,
            TOOL_REVERT_LAST_JOB
    );
 
    private final LlmChatService llmChatService;
    private final SpringAiMcpToolManager mcpToolManager;
    private final AiPromptTemplateService aiPromptTemplateService;
 
    @Override
    public AutoTuneAgentResult runAutoTune(String triggerType) {
        String normalizedTriggerType = normalizeTriggerType(triggerType);
        UsageCounter usageCounter = new UsageCounter();
        RunState runState = new RunState();
        boolean maxRoundsReached = false;
        StringBuilder summaryBuffer = new StringBuilder();
 
        try {
            List<Object> tools = filterAllowedTools(mcpToolManager.buildOpenAiTools());
            if (tools == null || tools.isEmpty()) {
                throw new IllegalStateException("No auto-tune MCP tools registered");
            }
 
            AiPromptTemplate promptTemplate = aiPromptTemplateService.resolvePublished(AiPromptScene.AUTO_TUNE_DISPATCH.getCode());
            List<ChatCompletionRequest.Message> messages = buildMessages(promptTemplate, normalizedTriggerType);
 
            for (int round = 0; round < MAX_TOOL_ROUNDS; round++) {
                ChatCompletionResponse response = llmChatService.chatCompletion(messages, TEMPERATURE, MAX_TOKENS, tools);
                ChatCompletionRequest.Message assistantMessage = extractAssistantMessage(response);
                usageCounter.add(response.getUsage());
                messages.add(assistantMessage);
                appendSummary(summaryBuffer, assistantMessage.getContent());
 
                List<ChatCompletionRequest.ToolCall> toolCalls = assistantMessage.getTool_calls();
                if (toolCalls == null || toolCalls.isEmpty()) {
                    return buildResult(runState.isSuccessful(), normalizedTriggerType, summaryBuffer, runState,
                            usageCounter, false);
                }
 
                for (ChatCompletionRequest.ToolCall toolCall : toolCalls) {
                    Object toolOutput = callMountedTool(toolCall, runState, normalizedTriggerType);
                    messages.add(buildToolMessage(toolCall, toolOutput));
                }
            }
            maxRoundsReached = true;
            return buildResult(false, normalizedTriggerType, summaryBuffer, runState, usageCounter, maxRoundsReached);
        } catch (Exception exception) {
            log.error("Auto tune agent stopped with error", exception);
            appendSummary(summaryBuffer, "自动调参 Agent 执行异常: " + exception.getMessage());
            runState.markToolError();
            return buildResult(false, normalizedTriggerType, summaryBuffer, runState, usageCounter, maxRoundsReached);
        }
    }
 
    private List<ChatCompletionRequest.Message> buildMessages(AiPromptTemplate promptTemplate, String triggerType) {
        List<ChatCompletionRequest.Message> messages = new ArrayList<>();
 
        ChatCompletionRequest.Message systemMessage = new ChatCompletionRequest.Message();
        systemMessage.setRole("system");
        systemMessage.setContent(promptTemplate == null ? "" : promptTemplate.getContent());
        messages.add(systemMessage);
 
        ChatCompletionRequest.Message userMessage = new ChatCompletionRequest.Message();
        userMessage.setRole("user");
        userMessage.setContent("请执行一次后台 WCS 自动调参。triggerType=" + triggerType
                + "。必须先调用 wcs_local_dispatch_get_auto_tune_snapshot 获取事实;如需提交变更,"
                + "必须先 dry-run,再根据 dry-run 结果决定是否实际应用;实际应用时必须带上 dry-run 返回的 dryRunToken。"
                + "不要输出自由格式 JSON 供外层解析。");
        messages.add(userMessage);
        return messages;
    }
 
    private ChatCompletionRequest.Message extractAssistantMessage(ChatCompletionResponse response) {
        if (response == null || response.getChoices() == null || response.getChoices().isEmpty()) {
            throw new IllegalStateException("LLM returned empty response");
        }
        ChatCompletionRequest.Message message = response.getChoices().get(0).getMessage();
        if (message == null) {
            throw new IllegalStateException("LLM returned empty message");
        }
        return message;
    }
 
    private Object callMountedTool(ChatCompletionRequest.ToolCall toolCall, RunState runState, String triggerType) {
        String toolName = resolveToolName(toolCall);
        if (!ALLOWED_TOOL_NAMES.contains(toolName)) {
            throw new IllegalArgumentException("Disallowed auto-tune MCP tool: " + toolName);
        }
        JSONObject arguments = parseArguments(toolCall);
        applySchedulerTriggerType(toolName, triggerType, arguments);
        try {
            Object output = mcpToolManager.callTool(toolName, arguments);
            runState.markToolSuccess(toolName);
            return output;
        } catch (Exception exception) {
            throw new IllegalStateException("Auto-tune MCP tool failed: " + toolName + ", " + exception.getMessage(),
                    exception);
        }
    }
 
    private void applySchedulerTriggerType(String toolName, String triggerType, JSONObject arguments) {
        if (!TOOL_APPLY_CHANGES.equals(toolName)) {
            return;
        }
        if (!AutoTuneTriggerType.AUTO.getCode().equals(triggerType)) {
            return;
        }
        arguments.put("triggerType", AutoTuneTriggerType.AUTO.getCode());
    }
 
    private ChatCompletionRequest.Message buildToolMessage(ChatCompletionRequest.ToolCall toolCall, Object toolOutput) {
        ChatCompletionRequest.Message toolMessage = new ChatCompletionRequest.Message();
        toolMessage.setRole("tool");
        toolMessage.setTool_call_id(toolCall == null ? null : toolCall.getId());
        toolMessage.setContent(JSON.toJSONString(toolOutput));
        return toolMessage;
    }
 
    private String resolveToolName(ChatCompletionRequest.ToolCall toolCall) {
        if (toolCall == null || toolCall.getFunction() == null || isBlank(toolCall.getFunction().getName())) {
            throw new IllegalArgumentException("missing tool name");
        }
        return toolCall.getFunction().getName();
    }
 
    private JSONObject parseArguments(ChatCompletionRequest.ToolCall toolCall) {
        String rawArguments = toolCall == null || toolCall.getFunction() == null
                ? null
                : toolCall.getFunction().getArguments();
        if (isBlank(rawArguments)) {
            return new JSONObject();
        }
        try {
            return JSON.parseObject(rawArguments);
        } catch (Exception exception) {
            JSONObject arguments = new JSONObject();
            arguments.put("_raw", rawArguments);
            return arguments;
        }
    }
 
    private AutoTuneAgentResult buildResult(boolean success,
                                            String triggerType,
                                            StringBuilder summaryBuffer,
                                            RunState runState,
                                            UsageCounter usageCounter,
                                            boolean maxRoundsReached) {
        AutoTuneAgentResult result = new AutoTuneAgentResult();
        result.setSuccess(success);
        result.setTriggerType(triggerType);
        result.setToolCallCount(runState.getToolCallCount());
        result.setLlmCallCount(usageCounter.getLlmCallCount());
        result.setPromptTokens(usageCounter.getPromptTokens());
        result.setCompletionTokens(usageCounter.getCompletionTokens());
        result.setTotalTokens(usageCounter.getTotalTokens());
        result.setMaxRoundsReached(maxRoundsReached);
 
        String summary = summaryBuffer == null ? "" : summaryBuffer.toString().trim();
        if (runState.getToolCallCount() <= 0) {
            summary = "自动调参 Agent 未调用任何允许的 MCP 工具,未执行调参。" + (summary.isEmpty() ? "" : "\n" + summary);
        } else if (!runState.isSnapshotCalled()) {
            summary = summary + "\n自动调参 Agent 未调用快照工具,结果不完整。";
        }
        if (runState.hasToolError()) {
            summary = summary + "\n自动调参 Agent 存在工具调用错误,已标记为失败。";
        }
        if (maxRoundsReached) {
            summary = summary + "\n自动调参 Agent 达到最大工具调用轮次,已停止。";
        }
        result.setSummary(summary);
        return result;
    }
 
    private List<Object> filterAllowedTools(List<Object> tools) {
        List<Object> allowedTools = new ArrayList<>();
        if (tools == null || tools.isEmpty()) {
            return allowedTools;
        }
        for (Object tool : tools) {
            String toolName = resolveOpenAiToolName(tool);
            if (ALLOWED_TOOL_NAMES.contains(toolName)) {
                allowedTools.add(tool);
            }
        }
        return allowedTools;
    }
 
    private String resolveOpenAiToolName(Object tool) {
        if (!(tool instanceof Map<?, ?> toolMap)) {
            return null;
        }
        Object function = toolMap.get("function");
        if (!(function instanceof Map<?, ?> functionMap)) {
            return null;
        }
        Object name = functionMap.get("name");
        return name == null ? null : String.valueOf(name);
    }
 
    private void appendSummary(StringBuilder summaryBuffer, String content) {
        if (summaryBuffer == null || isBlank(content)) {
            return;
        }
        if (summaryBuffer.length() > 0) {
            summaryBuffer.append('\n');
        }
        summaryBuffer.append(content.trim());
    }
 
    private String normalizeTriggerType(String triggerType) {
        return isBlank(triggerType) ? "agent" : triggerType.trim();
    }
 
    private boolean isBlank(String value) {
        return value == null || value.trim().isEmpty();
    }
 
    private static class UsageCounter {
        private long promptTokens;
        private long completionTokens;
        private long totalTokens;
        private int llmCallCount;
 
        void add(ChatCompletionResponse.Usage usage) {
            llmCallCount++;
            if (usage == null) {
                return;
            }
            promptTokens += usage.getPromptTokens() == null ? 0L : usage.getPromptTokens();
            completionTokens += usage.getCompletionTokens() == null ? 0L : usage.getCompletionTokens();
            totalTokens += usage.getTotalTokens() == null ? 0L : usage.getTotalTokens();
        }
 
        long getPromptTokens() {
            return promptTokens;
        }
 
        long getCompletionTokens() {
            return completionTokens;
        }
 
        long getTotalTokens() {
            return totalTokens;
        }
 
        int getLlmCallCount() {
            return llmCallCount;
        }
    }
 
    private static class RunState {
        private int toolCallCount;
        private boolean snapshotCalled;
        private boolean toolError;
 
        void markToolSuccess(String toolName) {
            toolCallCount++;
            if (TOOL_GET_SNAPSHOT.equals(toolName)) {
                snapshotCalled = true;
            }
        }
 
        void markToolError() {
            toolError = true;
        }
 
        boolean isSuccessful() {
            return toolCallCount > 0 && snapshotCalled && !toolError;
        }
 
        int getToolCallCount() {
            return toolCallCount;
        }
 
        boolean isSnapshotCalled() {
            return snapshotCalled;
        }
 
        boolean hasToolError() {
            return toolError;
        }
    }
}