Junjie
2026-04-27 913c02b7949bbb1eae9340e7d3cd05a148487152
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
package com.zy.ai.service.impl;
 
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.zy.ai.entity.AiPromptTemplate;
import com.zy.ai.entity.ChatCompletionRequest;
import com.zy.ai.entity.ChatCompletionResponse;
import com.zy.ai.enums.AiPromptScene;
import com.zy.ai.mcp.service.SpringAiMcpToolManager;
import com.zy.ai.service.AiPromptTemplateService;
import com.zy.ai.service.AutoTuneAgentService;
import com.zy.ai.service.LlmChatService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
 
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
 
@Slf4j
@Service
@RequiredArgsConstructor
public class AutoTuneAgentServiceImpl implements AutoTuneAgentService {
 
    private static final int MAX_TOOL_ROUNDS = 10;
    private static final double TEMPERATURE = 0.2D;
    private static final int MAX_TOKENS = 2048;
 
    private final LlmChatService llmChatService;
    private final SpringAiMcpToolManager mcpToolManager;
    private final AiPromptTemplateService aiPromptTemplateService;
 
    @Override
    public AutoTuneAgentResult runAutoTune(String triggerType) {
        String normalizedTriggerType = normalizeTriggerType(triggerType);
        UsageCounter usageCounter = new UsageCounter();
        int toolCallCount = 0;
        boolean maxRoundsReached = false;
        StringBuilder summaryBuffer = new StringBuilder();
 
        try {
            List<Object> tools = mcpToolManager.buildOpenAiTools();
            if (tools == null || tools.isEmpty()) {
                throw new IllegalStateException("No MCP tools registered");
            }
 
            AiPromptTemplate promptTemplate = aiPromptTemplateService.resolvePublished(AiPromptScene.AUTO_TUNE_DISPATCH.getCode());
            List<ChatCompletionRequest.Message> messages = buildMessages(promptTemplate, normalizedTriggerType);
 
            for (int round = 0; round < MAX_TOOL_ROUNDS; round++) {
                ChatCompletionResponse response = llmChatService.chatCompletion(messages, TEMPERATURE, MAX_TOKENS, tools);
                ChatCompletionRequest.Message assistantMessage = extractAssistantMessage(response);
                usageCounter.add(response.getUsage());
                messages.add(assistantMessage);
                appendSummary(summaryBuffer, assistantMessage.getContent());
 
                List<ChatCompletionRequest.ToolCall> toolCalls = assistantMessage.getTool_calls();
                if (toolCalls == null || toolCalls.isEmpty()) {
                    return buildResult(true, normalizedTriggerType, summaryBuffer, toolCallCount, usageCounter, false);
                }
 
                for (ChatCompletionRequest.ToolCall toolCall : toolCalls) {
                    Object toolOutput = callMountedTool(toolCall);
                    toolCallCount++;
                    messages.add(buildToolMessage(toolCall, toolOutput));
                }
            }
            maxRoundsReached = true;
            return buildResult(true, normalizedTriggerType, summaryBuffer, toolCallCount, usageCounter, maxRoundsReached);
        } catch (Exception exception) {
            log.error("Auto tune agent stopped with error", exception);
            appendSummary(summaryBuffer, "自动调参 Agent 执行异常: " + exception.getMessage());
            return buildResult(false, normalizedTriggerType, summaryBuffer, toolCallCount, usageCounter, maxRoundsReached);
        }
    }
 
    private List<ChatCompletionRequest.Message> buildMessages(AiPromptTemplate promptTemplate, String triggerType) {
        List<ChatCompletionRequest.Message> messages = new ArrayList<>();
 
        ChatCompletionRequest.Message systemMessage = new ChatCompletionRequest.Message();
        systemMessage.setRole("system");
        systemMessage.setContent(promptTemplate == null ? "" : promptTemplate.getContent());
        messages.add(systemMessage);
 
        ChatCompletionRequest.Message userMessage = new ChatCompletionRequest.Message();
        userMessage.setRole("user");
        userMessage.setContent("请执行一次后台 WCS 自动调参。triggerType=" + triggerType
                + "。必须先调用 wcs_local_dispatch_get_auto_tune_snapshot 获取事实;如需提交变更,"
                + "必须先 dry-run,再根据 dry-run 结果决定是否实际应用。不要输出自由格式 JSON 供外层解析。");
        messages.add(userMessage);
        return messages;
    }
 
    private ChatCompletionRequest.Message extractAssistantMessage(ChatCompletionResponse response) {
        if (response == null || response.getChoices() == null || response.getChoices().isEmpty()) {
            throw new IllegalStateException("LLM returned empty response");
        }
        ChatCompletionRequest.Message message = response.getChoices().get(0).getMessage();
        if (message == null) {
            throw new IllegalStateException("LLM returned empty message");
        }
        return message;
    }
 
    private Object callMountedTool(ChatCompletionRequest.ToolCall toolCall) {
        String toolName = resolveToolName(toolCall);
        JSONObject arguments = parseArguments(toolCall);
        try {
            return mcpToolManager.callTool(toolName, arguments);
        } catch (Exception exception) {
            LinkedHashMap<String, Object> error = new LinkedHashMap<>();
            error.put("tool", toolName);
            error.put("error", exception.getMessage());
            return error;
        }
    }
 
    private ChatCompletionRequest.Message buildToolMessage(ChatCompletionRequest.ToolCall toolCall, Object toolOutput) {
        ChatCompletionRequest.Message toolMessage = new ChatCompletionRequest.Message();
        toolMessage.setRole("tool");
        toolMessage.setTool_call_id(toolCall == null ? null : toolCall.getId());
        toolMessage.setContent(JSON.toJSONString(toolOutput));
        return toolMessage;
    }
 
    private String resolveToolName(ChatCompletionRequest.ToolCall toolCall) {
        if (toolCall == null || toolCall.getFunction() == null || isBlank(toolCall.getFunction().getName())) {
            throw new IllegalArgumentException("missing tool name");
        }
        return toolCall.getFunction().getName();
    }
 
    private JSONObject parseArguments(ChatCompletionRequest.ToolCall toolCall) {
        String rawArguments = toolCall == null || toolCall.getFunction() == null
                ? null
                : toolCall.getFunction().getArguments();
        if (isBlank(rawArguments)) {
            return new JSONObject();
        }
        try {
            return JSON.parseObject(rawArguments);
        } catch (Exception exception) {
            JSONObject arguments = new JSONObject();
            arguments.put("_raw", rawArguments);
            return arguments;
        }
    }
 
    private AutoTuneAgentResult buildResult(boolean success,
                                            String triggerType,
                                            StringBuilder summaryBuffer,
                                            int toolCallCount,
                                            UsageCounter usageCounter,
                                            boolean maxRoundsReached) {
        AutoTuneAgentResult result = new AutoTuneAgentResult();
        result.setSuccess(success);
        result.setTriggerType(triggerType);
        result.setToolCallCount(toolCallCount);
        result.setLlmCallCount(usageCounter.getLlmCallCount());
        result.setPromptTokens(usageCounter.getPromptTokens());
        result.setCompletionTokens(usageCounter.getCompletionTokens());
        result.setTotalTokens(usageCounter.getTotalTokens());
        result.setMaxRoundsReached(maxRoundsReached);
 
        String summary = summaryBuffer == null ? "" : summaryBuffer.toString().trim();
        if (toolCallCount <= 0 && success) {
            summary = "自动调参 Agent 未调用任何 MCP 工具,未执行调参。" + (summary.isEmpty() ? "" : "\n" + summary);
        }
        if (maxRoundsReached) {
            summary = summary + "\n自动调参 Agent 达到最大工具调用轮次,已停止。";
        }
        result.setSummary(summary);
        return result;
    }
 
    private void appendSummary(StringBuilder summaryBuffer, String content) {
        if (summaryBuffer == null || isBlank(content)) {
            return;
        }
        if (summaryBuffer.length() > 0) {
            summaryBuffer.append('\n');
        }
        summaryBuffer.append(content.trim());
    }
 
    private String normalizeTriggerType(String triggerType) {
        return isBlank(triggerType) ? "agent" : triggerType.trim();
    }
 
    private boolean isBlank(String value) {
        return value == null || value.trim().isEmpty();
    }
 
    private static class UsageCounter {
        private long promptTokens;
        private long completionTokens;
        private long totalTokens;
        private int llmCallCount;
 
        void add(ChatCompletionResponse.Usage usage) {
            llmCallCount++;
            if (usage == null) {
                return;
            }
            promptTokens += usage.getPromptTokens() == null ? 0L : usage.getPromptTokens();
            completionTokens += usage.getCompletionTokens() == null ? 0L : usage.getCompletionTokens();
            totalTokens += usage.getTotalTokens() == null ? 0L : usage.getTotalTokens();
        }
 
        long getPromptTokens() {
            return promptTokens;
        }
 
        long getCompletionTokens() {
            return completionTokens;
        }
 
        long getTotalTokens() {
            return totalTokens;
        }
 
        int getLlmCallCount() {
            return llmCallCount;
        }
    }
}