From e4e91b46d0ce781e7dc87dcdf0d2909b01911d4b Mon Sep 17 00:00:00 2001
From: Junjie <fallin.jie@qq.com>
Date: 星期一, 27 四月 2026 12:34:31 +0800
Subject: [PATCH] fix: harden auto tune scheduler throttling
---
src/main/java/com/zy/ai/service/impl/AutoTuneAgentServiceImpl.java | 138 ++++++++++++++++++++++++++++++++++++++-------
1 files changed, 115 insertions(+), 23 deletions(-)
diff --git a/src/main/java/com/zy/ai/service/impl/AutoTuneAgentServiceImpl.java b/src/main/java/com/zy/ai/service/impl/AutoTuneAgentServiceImpl.java
index b7a20f4..fe0ca9b 100644
--- a/src/main/java/com/zy/ai/service/impl/AutoTuneAgentServiceImpl.java
+++ b/src/main/java/com/zy/ai/service/impl/AutoTuneAgentServiceImpl.java
@@ -2,6 +2,7 @@
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
+import com.zy.ai.domain.autotune.AutoTuneTriggerType;
import com.zy.ai.entity.AiPromptTemplate;
import com.zy.ai.entity.ChatCompletionRequest;
import com.zy.ai.entity.ChatCompletionResponse;
@@ -15,8 +16,9 @@
import org.springframework.stereotype.Service;
import java.util.ArrayList;
-import java.util.LinkedHashMap;
import java.util.List;
+import java.util.Map;
+import java.util.Set;
@Slf4j
@Service
@@ -26,6 +28,16 @@
private static final int MAX_TOOL_ROUNDS = 10;
private static final double TEMPERATURE = 0.2D;
private static final int MAX_TOKENS = 2048;
+ private static final String TOOL_GET_SNAPSHOT = "wcs_local_dispatch_get_auto_tune_snapshot";
+ private static final String TOOL_GET_RECENT_JOBS = "wcs_local_dispatch_get_recent_auto_tune_jobs";
+ private static final String TOOL_APPLY_CHANGES = "wcs_local_dispatch_apply_auto_tune_changes";
+ private static final String TOOL_REVERT_LAST_JOB = "wcs_local_dispatch_revert_last_auto_tune_job";
+ private static final Set<String> ALLOWED_TOOL_NAMES = Set.of(
+ TOOL_GET_SNAPSHOT,
+ TOOL_GET_RECENT_JOBS,
+ TOOL_APPLY_CHANGES,
+ TOOL_REVERT_LAST_JOB
+ );
private final LlmChatService llmChatService;
private final SpringAiMcpToolManager mcpToolManager;
@@ -35,14 +47,14 @@
public AutoTuneAgentResult runAutoTune(String triggerType) {
String normalizedTriggerType = normalizeTriggerType(triggerType);
UsageCounter usageCounter = new UsageCounter();
- int toolCallCount = 0;
+ RunState runState = new RunState();
boolean maxRoundsReached = false;
StringBuilder summaryBuffer = new StringBuilder();
try {
- List<Object> tools = mcpToolManager.buildOpenAiTools();
+ List<Object> tools = filterAllowedTools(mcpToolManager.buildOpenAiTools());
if (tools == null || tools.isEmpty()) {
- throw new IllegalStateException("No MCP tools registered");
+ throw new IllegalStateException("No auto-tune MCP tools registered");
}
AiPromptTemplate promptTemplate = aiPromptTemplateService.resolvePublished(AiPromptScene.AUTO_TUNE_DISPATCH.getCode());
@@ -57,21 +69,22 @@
List<ChatCompletionRequest.ToolCall> toolCalls = assistantMessage.getTool_calls();
if (toolCalls == null || toolCalls.isEmpty()) {
- return buildResult(true, normalizedTriggerType, summaryBuffer, toolCallCount, usageCounter, false);
+ return buildResult(runState.isSuccessful(), normalizedTriggerType, summaryBuffer, runState,
+ usageCounter, false);
}
for (ChatCompletionRequest.ToolCall toolCall : toolCalls) {
- Object toolOutput = callMountedTool(toolCall);
- toolCallCount++;
+ Object toolOutput = callMountedTool(toolCall, runState, normalizedTriggerType);
messages.add(buildToolMessage(toolCall, toolOutput));
}
}
maxRoundsReached = true;
- return buildResult(true, normalizedTriggerType, summaryBuffer, toolCallCount, usageCounter, maxRoundsReached);
+ return buildResult(false, normalizedTriggerType, summaryBuffer, runState, usageCounter, maxRoundsReached);
} catch (Exception exception) {
log.error("Auto tune agent stopped with error", exception);
appendSummary(summaryBuffer, "鑷姩璋冨弬 Agent 鎵ц寮傚父: " + exception.getMessage());
- return buildResult(false, normalizedTriggerType, summaryBuffer, toolCallCount, usageCounter, maxRoundsReached);
+ runState.markToolError();
+ return buildResult(false, normalizedTriggerType, summaryBuffer, runState, usageCounter, maxRoundsReached);
}
}
@@ -87,7 +100,8 @@
userMessage.setRole("user");
userMessage.setContent("璇锋墽琛屼竴娆″悗鍙� WCS 鑷姩璋冨弬銆倀riggerType=" + triggerType
+ "銆傚繀椤诲厛璋冪敤 wcs_local_dispatch_get_auto_tune_snapshot 鑾峰彇浜嬪疄锛涘闇�鎻愪氦鍙樻洿锛�"
- + "蹇呴』鍏� dry-run锛屽啀鏍规嵁 dry-run 缁撴灉鍐冲畾鏄惁瀹為檯搴旂敤銆備笉瑕佽緭鍑鸿嚜鐢辨牸寮� JSON 渚涘灞傝В鏋愩��");
+ + "蹇呴』鍏� dry-run锛屽啀鏍规嵁 dry-run 缁撴灉鍐冲畾鏄惁瀹為檯搴旂敤锛涘疄闄呭簲鐢ㄦ椂蹇呴』甯︿笂 dry-run 杩斿洖鐨� dryRunToken銆�"
+ + "涓嶈杈撳嚭鑷敱鏍煎紡 JSON 渚涘灞傝В鏋愩��");
messages.add(userMessage);
return messages;
}
@@ -103,17 +117,31 @@
return message;
}
- private Object callMountedTool(ChatCompletionRequest.ToolCall toolCall) {
+ private Object callMountedTool(ChatCompletionRequest.ToolCall toolCall, RunState runState, String triggerType) {
String toolName = resolveToolName(toolCall);
- JSONObject arguments = parseArguments(toolCall);
- try {
- return mcpToolManager.callTool(toolName, arguments);
- } catch (Exception exception) {
- LinkedHashMap<String, Object> error = new LinkedHashMap<>();
- error.put("tool", toolName);
- error.put("error", exception.getMessage());
- return error;
+ if (!ALLOWED_TOOL_NAMES.contains(toolName)) {
+ throw new IllegalArgumentException("Disallowed auto-tune MCP tool: " + toolName);
}
+ JSONObject arguments = parseArguments(toolCall);
+ applySchedulerTriggerType(toolName, triggerType, arguments);
+ try {
+ Object output = mcpToolManager.callTool(toolName, arguments);
+ runState.markToolSuccess(toolName);
+ return output;
+ } catch (Exception exception) {
+ throw new IllegalStateException("Auto-tune MCP tool failed: " + toolName + ", " + exception.getMessage(),
+ exception);
+ }
+ }
+
+ private void applySchedulerTriggerType(String toolName, String triggerType, JSONObject arguments) {
+ if (!TOOL_APPLY_CHANGES.equals(toolName)) {
+ return;
+ }
+ if (!AutoTuneTriggerType.AUTO.getCode().equals(triggerType)) {
+ return;
+ }
+ arguments.put("triggerType", AutoTuneTriggerType.AUTO.getCode());
}
private ChatCompletionRequest.Message buildToolMessage(ChatCompletionRequest.ToolCall toolCall, Object toolOutput) {
@@ -150,13 +178,13 @@
private AutoTuneAgentResult buildResult(boolean success,
String triggerType,
StringBuilder summaryBuffer,
- int toolCallCount,
+ RunState runState,
UsageCounter usageCounter,
boolean maxRoundsReached) {
AutoTuneAgentResult result = new AutoTuneAgentResult();
result.setSuccess(success);
result.setTriggerType(triggerType);
- result.setToolCallCount(toolCallCount);
+ result.setToolCallCount(runState.getToolCallCount());
result.setLlmCallCount(usageCounter.getLlmCallCount());
result.setPromptTokens(usageCounter.getPromptTokens());
result.setCompletionTokens(usageCounter.getCompletionTokens());
@@ -164,14 +192,45 @@
result.setMaxRoundsReached(maxRoundsReached);
String summary = summaryBuffer == null ? "" : summaryBuffer.toString().trim();
- if (toolCallCount <= 0 && success) {
- summary = "鑷姩璋冨弬 Agent 鏈皟鐢ㄤ换浣� MCP 宸ュ叿锛屾湭鎵ц璋冨弬銆�" + (summary.isEmpty() ? "" : "\n" + summary);
+ if (runState.getToolCallCount() <= 0) {
+ summary = "鑷姩璋冨弬 Agent 鏈皟鐢ㄤ换浣曞厑璁哥殑 MCP 宸ュ叿锛屾湭鎵ц璋冨弬銆�" + (summary.isEmpty() ? "" : "\n" + summary);
+ } else if (!runState.isSnapshotCalled()) {
+ summary = summary + "\n鑷姩璋冨弬 Agent 鏈皟鐢ㄥ揩鐓у伐鍏凤紝缁撴灉涓嶅畬鏁淬��";
+ }
+ if (runState.hasToolError()) {
+ summary = summary + "\n鑷姩璋冨弬 Agent 瀛樺湪宸ュ叿璋冪敤閿欒锛屽凡鏍囪涓哄け璐ャ��";
}
if (maxRoundsReached) {
summary = summary + "\n鑷姩璋冨弬 Agent 杈惧埌鏈�澶у伐鍏疯皟鐢ㄨ疆娆★紝宸插仠姝€��";
}
result.setSummary(summary);
return result;
+ }
+
+ private List<Object> filterAllowedTools(List<Object> tools) {
+ List<Object> allowedTools = new ArrayList<>();
+ if (tools == null || tools.isEmpty()) {
+ return allowedTools;
+ }
+ for (Object tool : tools) {
+ String toolName = resolveOpenAiToolName(tool);
+ if (ALLOWED_TOOL_NAMES.contains(toolName)) {
+ allowedTools.add(tool);
+ }
+ }
+ return allowedTools;
+ }
+
+ private String resolveOpenAiToolName(Object tool) {
+ if (!(tool instanceof Map<?, ?> toolMap)) {
+ return null;
+ }
+ Object function = toolMap.get("function");
+ if (!(function instanceof Map<?, ?> functionMap)) {
+ return null;
+ }
+ Object name = functionMap.get("name");
+ return name == null ? null : String.valueOf(name);
}
private void appendSummary(StringBuilder summaryBuffer, String content) {
@@ -224,4 +283,37 @@
return llmCallCount;
}
}
+
+ private static class RunState {
+ private int toolCallCount;
+ private boolean snapshotCalled;
+ private boolean toolError;
+
+ void markToolSuccess(String toolName) {
+ toolCallCount++;
+ if (TOOL_GET_SNAPSHOT.equals(toolName)) {
+ snapshotCalled = true;
+ }
+ }
+
+ void markToolError() {
+ toolError = true;
+ }
+
+ boolean isSuccessful() {
+ return toolCallCount > 0 && snapshotCalled && !toolError;
+ }
+
+ int getToolCallCount() {
+ return toolCallCount;
+ }
+
+ boolean isSnapshotCalled() {
+ return snapshotCalled;
+ }
+
+ boolean hasToolError() {
+ return toolError;
+ }
+ }
}
--
Gitblit v1.9.1