Junjie
2026-04-27 10bdc4b6e9701befd1a83bccd2998dcc96cb2c43
fix: enforce auto tune agent tool safety
4个文件已修改
412 ■■■■■ 已修改文件
src/main/java/com/zy/ai/domain/autotune/AutoTuneApplyResult.java 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
src/main/java/com/zy/ai/mcp/tool/AutoTuneMcpTools.java 114 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
src/main/java/com/zy/ai/service/impl/AutoTuneAgentServiceImpl.java 120 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
src/test/java/com/zy/ai/service/AutoTuneCoordinatorServiceImplTest.java 176 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
src/main/java/com/zy/ai/domain/autotune/AutoTuneApplyResult.java
@@ -18,6 +18,8 @@
    private String summary;
    private String dryRunToken;
    private Integer successCount;
    private Integer rejectCount;
src/main/java/com/zy/ai/mcp/tool/AutoTuneMcpTools.java
@@ -1,5 +1,6 @@
package com.zy.ai.mcp.tool;
import com.alibaba.fastjson.JSON;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.zy.ai.domain.autotune.AutoTuneApplyRequest;
import com.zy.ai.domain.autotune.AutoTuneApplyResult;
@@ -17,9 +18,14 @@
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
@Component
@RequiredArgsConstructor
@@ -27,11 +33,13 @@
    private static final int DEFAULT_RECENT_JOB_LIMIT = 5;
    private static final int MAX_RECENT_JOB_LIMIT = 20;
    private static final long DRY_RUN_TOKEN_TTL_MILLIS = 10L * 60L * 1000L;
    private final AutoTuneSnapshotService autoTuneSnapshotService;
    private final AutoTuneApplyService autoTuneApplyService;
    private final AiAutoTuneJobService aiAutoTuneJobService;
    private final AiAutoTuneChangeService aiAutoTuneChangeService;
    private final ConcurrentMap<String, DryRunPreview> dryRunPreviews = new ConcurrentHashMap<>();
    @Tool(name = "dispatch_get_auto_tune_snapshot", description = "获取WCS自动调参所需的调度快照、站点运行态、拓扑容量和当前可写参数")
    public AutoTuneSnapshot getAutoTuneSnapshot() {
@@ -63,14 +71,27 @@
            @ToolParam(description = "建议自动调参分析间隔分钟", required = false) Integer analysisIntervalMinutes,
            @ToolParam(description = "触发类型,例如 scheduler/manual/agent", required = false) String triggerType,
            @ToolParam(description = "是否仅试算,实际应用前必须先传 true", required = false) Boolean dryRun,
            @ToolParam(description = "dry-run 成功后返回的预览令牌。dryRun=false 时必须提供,且变更集必须完全一致", required = false) String dryRunToken,
            @ToolParam(description = "调参变更列表") List<AutoTuneChangeCommand> changes) {
        if (dryRun == null) {
            throw new IllegalArgumentException("dryRun is required. Use dryRun=true first to create a preview token.");
        }
        String fingerprint = buildChangeFingerprint(changes);
        if (Boolean.FALSE.equals(dryRun)) {
            requireMatchingDryRunToken(dryRunToken, fingerprint);
        }
        AutoTuneApplyRequest request = new AutoTuneApplyRequest();
        request.setReason(reason);
        request.setAnalysisIntervalMinutes(analysisIntervalMinutes);
        request.setTriggerType(triggerType);
        request.setDryRun(dryRun);
        request.setChanges(changes);
        return autoTuneApplyService.apply(request);
        AutoTuneApplyResult result = autoTuneApplyService.apply(request);
        if (Boolean.TRUE.equals(dryRun) && isSuccessful(result)) {
            result.setDryRunToken(createDryRunToken(fingerprint));
        }
        return result;
    }
    @Tool(name = "dispatch_revert_last_auto_tune_job", description = "回滚最近一次成功的自动调参任务")
@@ -133,4 +154,95 @@
        }
        return Math.min(limit, MAX_RECENT_JOB_LIMIT);
    }
    private void requireMatchingDryRunToken(String dryRunToken, String fingerprint) {
        cleanExpiredDryRunPreviews();
        if (isBlank(dryRunToken)) {
            throw new IllegalArgumentException("dryRunToken is required when dryRun=false. Run dryRun=true first.");
        }
        DryRunPreview preview = dryRunPreviews.remove(dryRunToken.trim());
        if (preview == null) {
            throw new IllegalArgumentException("dryRunToken is missing, expired, or already used.");
        }
        if (preview.isExpired()) {
            throw new IllegalArgumentException("dryRunToken is expired. Run dryRun=true again.");
        }
        if (!preview.getFingerprint().equals(fingerprint)) {
            throw new IllegalArgumentException("dryRunToken does not match the requested change set.");
        }
    }
    private String createDryRunToken(String fingerprint) {
        cleanExpiredDryRunPreviews();
        String token = UUID.randomUUID().toString();
        dryRunPreviews.put(token, new DryRunPreview(fingerprint, System.currentTimeMillis() + DRY_RUN_TOKEN_TTL_MILLIS));
        return token;
    }
    private void cleanExpiredDryRunPreviews() {
        for (Map.Entry<String, DryRunPreview> entry : dryRunPreviews.entrySet()) {
            if (entry.getValue() == null || entry.getValue().isExpired()) {
                dryRunPreviews.remove(entry.getKey());
            }
        }
    }
    private boolean isSuccessful(AutoTuneApplyResult result) {
        return result != null && Boolean.TRUE.equals(result.getSuccess());
    }
    private String buildChangeFingerprint(List<AutoTuneChangeCommand> changes) {
        List<Map<String, String>> normalizedChanges = new ArrayList<>();
        if (changes != null) {
            for (AutoTuneChangeCommand change : changes) {
                normalizedChanges.add(toNormalizedChange(change));
            }
        }
        normalizedChanges.sort(Comparator
                .comparing((Map<String, String> item) -> item.get("targetType"))
                .thenComparing(item -> item.get("targetId"))
                .thenComparing(item -> item.get("targetKey"))
                .thenComparing(item -> item.get("newValue")));
        return JSON.toJSONString(normalizedChanges);
    }
    private Map<String, String> toNormalizedChange(AutoTuneChangeCommand change) {
        LinkedHashMap<String, String> item = new LinkedHashMap<>();
        String targetType = normalizeLower(change == null ? null : change.getTargetType());
        item.put("targetType", targetType);
        item.put("targetId", "sys_config".equals(targetType) ? "" : normalizeText(change == null ? null : change.getTargetId()));
        item.put("targetKey", normalizeText(change == null ? null : change.getTargetKey()));
        item.put("newValue", normalizeText(change == null ? null : change.getNewValue()));
        return item;
    }
    private String normalizeLower(String value) {
        return normalizeText(value).toLowerCase(Locale.ROOT);
    }
    private String normalizeText(String value) {
        return value == null ? "" : value.trim();
    }
    private boolean isBlank(String value) {
        return value == null || value.trim().isEmpty();
    }
    private static class DryRunPreview {
        private final String fingerprint;
        private final long expireAtMillis;
        DryRunPreview(String fingerprint, long expireAtMillis) {
            this.fingerprint = fingerprint;
            this.expireAtMillis = expireAtMillis;
        }
        String getFingerprint() {
            return fingerprint;
        }
        boolean isExpired() {
            return System.currentTimeMillis() > expireAtMillis;
        }
    }
}
src/main/java/com/zy/ai/service/impl/AutoTuneAgentServiceImpl.java
@@ -15,8 +15,9 @@
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
@Slf4j
@Service
@@ -26,6 +27,16 @@
    private static final int MAX_TOOL_ROUNDS = 10;
    private static final double TEMPERATURE = 0.2D;
    private static final int MAX_TOKENS = 2048;
    private static final String TOOL_GET_SNAPSHOT = "wcs_local_dispatch_get_auto_tune_snapshot";
    private static final String TOOL_GET_RECENT_JOBS = "wcs_local_dispatch_get_recent_auto_tune_jobs";
    private static final String TOOL_APPLY_CHANGES = "wcs_local_dispatch_apply_auto_tune_changes";
    private static final String TOOL_REVERT_LAST_JOB = "wcs_local_dispatch_revert_last_auto_tune_job";
    private static final Set<String> ALLOWED_TOOL_NAMES = Set.of(
            TOOL_GET_SNAPSHOT,
            TOOL_GET_RECENT_JOBS,
            TOOL_APPLY_CHANGES,
            TOOL_REVERT_LAST_JOB
    );
    private final LlmChatService llmChatService;
    private final SpringAiMcpToolManager mcpToolManager;
@@ -35,14 +46,14 @@
    public AutoTuneAgentResult runAutoTune(String triggerType) {
        String normalizedTriggerType = normalizeTriggerType(triggerType);
        UsageCounter usageCounter = new UsageCounter();
        int toolCallCount = 0;
        RunState runState = new RunState();
        boolean maxRoundsReached = false;
        StringBuilder summaryBuffer = new StringBuilder();
        try {
            List<Object> tools = mcpToolManager.buildOpenAiTools();
            List<Object> tools = filterAllowedTools(mcpToolManager.buildOpenAiTools());
            if (tools == null || tools.isEmpty()) {
                throw new IllegalStateException("No MCP tools registered");
                throw new IllegalStateException("No auto-tune MCP tools registered");
            }
            AiPromptTemplate promptTemplate = aiPromptTemplateService.resolvePublished(AiPromptScene.AUTO_TUNE_DISPATCH.getCode());
@@ -57,21 +68,22 @@
                List<ChatCompletionRequest.ToolCall> toolCalls = assistantMessage.getTool_calls();
                if (toolCalls == null || toolCalls.isEmpty()) {
                    return buildResult(true, normalizedTriggerType, summaryBuffer, toolCallCount, usageCounter, false);
                    return buildResult(runState.isSuccessful(), normalizedTriggerType, summaryBuffer, runState,
                            usageCounter, false);
                }
                for (ChatCompletionRequest.ToolCall toolCall : toolCalls) {
                    Object toolOutput = callMountedTool(toolCall);
                    toolCallCount++;
                    Object toolOutput = callMountedTool(toolCall, runState);
                    messages.add(buildToolMessage(toolCall, toolOutput));
                }
            }
            maxRoundsReached = true;
            return buildResult(true, normalizedTriggerType, summaryBuffer, toolCallCount, usageCounter, maxRoundsReached);
            return buildResult(false, normalizedTriggerType, summaryBuffer, runState, usageCounter, maxRoundsReached);
        } catch (Exception exception) {
            log.error("Auto tune agent stopped with error", exception);
            appendSummary(summaryBuffer, "自动调参 Agent 执行异常: " + exception.getMessage());
            return buildResult(false, normalizedTriggerType, summaryBuffer, toolCallCount, usageCounter, maxRoundsReached);
            runState.markToolError();
            return buildResult(false, normalizedTriggerType, summaryBuffer, runState, usageCounter, maxRoundsReached);
        }
    }
@@ -87,7 +99,8 @@
        userMessage.setRole("user");
        userMessage.setContent("请执行一次后台 WCS 自动调参。triggerType=" + triggerType
                + "。必须先调用 wcs_local_dispatch_get_auto_tune_snapshot 获取事实;如需提交变更,"
                + "必须先 dry-run,再根据 dry-run 结果决定是否实际应用。不要输出自由格式 JSON 供外层解析。");
                + "必须先 dry-run,再根据 dry-run 结果决定是否实际应用;实际应用时必须带上 dry-run 返回的 dryRunToken。"
                + "不要输出自由格式 JSON 供外层解析。");
        messages.add(userMessage);
        return messages;
    }
@@ -103,16 +116,19 @@
        return message;
    }
    private Object callMountedTool(ChatCompletionRequest.ToolCall toolCall) {
    private Object callMountedTool(ChatCompletionRequest.ToolCall toolCall, RunState runState) {
        String toolName = resolveToolName(toolCall);
        if (!ALLOWED_TOOL_NAMES.contains(toolName)) {
            throw new IllegalArgumentException("Disallowed auto-tune MCP tool: " + toolName);
        }
        JSONObject arguments = parseArguments(toolCall);
        try {
            return mcpToolManager.callTool(toolName, arguments);
            Object output = mcpToolManager.callTool(toolName, arguments);
            runState.markToolSuccess(toolName);
            return output;
        } catch (Exception exception) {
            LinkedHashMap<String, Object> error = new LinkedHashMap<>();
            error.put("tool", toolName);
            error.put("error", exception.getMessage());
            return error;
            throw new IllegalStateException("Auto-tune MCP tool failed: " + toolName + ", " + exception.getMessage(),
                    exception);
        }
    }
@@ -150,13 +166,13 @@
    private AutoTuneAgentResult buildResult(boolean success,
                                            String triggerType,
                                            StringBuilder summaryBuffer,
                                            int toolCallCount,
                                            RunState runState,
                                            UsageCounter usageCounter,
                                            boolean maxRoundsReached) {
        AutoTuneAgentResult result = new AutoTuneAgentResult();
        result.setSuccess(success);
        result.setTriggerType(triggerType);
        result.setToolCallCount(toolCallCount);
        result.setToolCallCount(runState.getToolCallCount());
        result.setLlmCallCount(usageCounter.getLlmCallCount());
        result.setPromptTokens(usageCounter.getPromptTokens());
        result.setCompletionTokens(usageCounter.getCompletionTokens());
@@ -164,14 +180,45 @@
        result.setMaxRoundsReached(maxRoundsReached);
        String summary = summaryBuffer == null ? "" : summaryBuffer.toString().trim();
        if (toolCallCount <= 0 && success) {
            summary = "自动调参 Agent 未调用任何 MCP 工具,未执行调参。" + (summary.isEmpty() ? "" : "\n" + summary);
        if (runState.getToolCallCount() <= 0) {
            summary = "自动调参 Agent 未调用任何允许的 MCP 工具,未执行调参。" + (summary.isEmpty() ? "" : "\n" + summary);
        } else if (!runState.isSnapshotCalled()) {
            summary = summary + "\n自动调参 Agent 未调用快照工具,结果不完整。";
        }
        if (runState.hasToolError()) {
            summary = summary + "\n自动调参 Agent 存在工具调用错误,已标记为失败。";
        }
        if (maxRoundsReached) {
            summary = summary + "\n自动调参 Agent 达到最大工具调用轮次,已停止。";
        }
        result.setSummary(summary);
        return result;
    }
    private List<Object> filterAllowedTools(List<Object> tools) {
        List<Object> allowedTools = new ArrayList<>();
        if (tools == null || tools.isEmpty()) {
            return allowedTools;
        }
        for (Object tool : tools) {
            String toolName = resolveOpenAiToolName(tool);
            if (ALLOWED_TOOL_NAMES.contains(toolName)) {
                allowedTools.add(tool);
            }
        }
        return allowedTools;
    }
    private String resolveOpenAiToolName(Object tool) {
        if (!(tool instanceof Map<?, ?> toolMap)) {
            return null;
        }
        Object function = toolMap.get("function");
        if (!(function instanceof Map<?, ?> functionMap)) {
            return null;
        }
        Object name = functionMap.get("name");
        return name == null ? null : String.valueOf(name);
    }
    private void appendSummary(StringBuilder summaryBuffer, String content) {
@@ -224,4 +271,37 @@
            return llmCallCount;
        }
    }
    private static class RunState {
        private int toolCallCount;
        private boolean snapshotCalled;
        private boolean toolError;
        void markToolSuccess(String toolName) {
            toolCallCount++;
            if (TOOL_GET_SNAPSHOT.equals(toolName)) {
                snapshotCalled = true;
            }
        }
        void markToolError() {
            toolError = true;
        }
        boolean isSuccessful() {
            return toolCallCount > 0 && snapshotCalled && !toolError;
        }
        int getToolCallCount() {
            return toolCallCount;
        }
        boolean isSnapshotCalled() {
            return snapshotCalled;
        }
        boolean hasToolError() {
            return toolError;
        }
    }
}
src/test/java/com/zy/ai/service/AutoTuneCoordinatorServiceImplTest.java
@@ -23,16 +23,21 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyDouble;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -112,6 +117,7 @@
    void applyToolDelegatesToApplyServiceWithDryRunAndChanges() {
        AutoTuneApplyResult expected = new AutoTuneApplyResult();
        expected.setDryRun(true);
        expected.setSuccess(true);
        when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class))).thenReturn(expected);
        AutoTuneChangeCommand command = new AutoTuneChangeCommand();
@@ -120,9 +126,10 @@
        command.setNewValue("12");
        List<AutoTuneChangeCommand> changes = Collections.singletonList(command);
        AutoTuneApplyResult result = tools.applyAutoTuneChanges("reduce congestion", 10, "agent", true, changes);
        AutoTuneApplyResult result = tools.applyAutoTuneChanges("reduce congestion", 10, "agent", true, null, changes);
        assertSame(expected, result);
        assertNotNull(result.getDryRunToken());
        ArgumentCaptor<AutoTuneApplyRequest> captor = ArgumentCaptor.forClass(AutoTuneApplyRequest.class);
        verify(autoTuneApplyService).apply(captor.capture());
        assertEquals("reduce congestion", captor.getValue().getReason());
@@ -130,6 +137,71 @@
        assertEquals("agent", captor.getValue().getTriggerType());
        assertEquals(Boolean.TRUE, captor.getValue().getDryRun());
        assertSame(changes, captor.getValue().getChanges());
    }
    @Test
    void applyToolRejectsMissingDryRunBeforeServiceCall() {
        AutoTuneChangeCommand command = change("sys_config", null, "conveyorStationTaskLimit", "12");
        IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
                () -> tools.applyAutoTuneChanges("missing dryRun", 10, "agent", null, null,
                        Collections.singletonList(command)));
        assertTrue(exception.getMessage().contains("dryRun is required"));
        verify(autoTuneApplyService, never()).apply(any(AutoTuneApplyRequest.class));
    }
    @Test
    void applyToolRejectsDirectRealApplyWithoutDryRunToken() {
        AutoTuneChangeCommand command = change("sys_config", null, "conveyorStationTaskLimit", "12");
        IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
                () -> tools.applyAutoTuneChanges("direct apply", 10, "agent", false, null,
                        Collections.singletonList(command)));
        assertTrue(exception.getMessage().contains("dryRunToken is required"));
        verify(autoTuneApplyService, never()).apply(any(AutoTuneApplyRequest.class));
    }
    @Test
    void applyToolAllowsRealApplyOnlyWithMatchingDryRunToken() {
        AutoTuneApplyResult dryRunResult = new AutoTuneApplyResult();
        dryRunResult.setDryRun(true);
        dryRunResult.setSuccess(true);
        AutoTuneApplyResult applyResult = new AutoTuneApplyResult();
        applyResult.setDryRun(false);
        applyResult.setSuccess(true);
        when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class))).thenReturn(dryRunResult, applyResult);
        AutoTuneChangeCommand command = change(" sys_config ", "ignored", " conveyorStationTaskLimit ", " 12 ");
        List<AutoTuneChangeCommand> changes = Collections.singletonList(command);
        AutoTuneApplyResult preview = tools.applyAutoTuneChanges("preview", 10, "agent", true, null, changes);
        AutoTuneApplyResult applied = tools.applyAutoTuneChanges("apply", 10, "agent", false,
                preview.getDryRunToken(), changes);
        assertSame(applyResult, applied);
        ArgumentCaptor<AutoTuneApplyRequest> captor = ArgumentCaptor.forClass(AutoTuneApplyRequest.class);
        verify(autoTuneApplyService, times(2)).apply(captor.capture());
        assertEquals(Boolean.TRUE, captor.getAllValues().get(0).getDryRun());
        assertEquals(Boolean.FALSE, captor.getAllValues().get(1).getDryRun());
    }
    @Test
    void applyToolRejectsMismatchedDryRunToken() {
        AutoTuneApplyResult dryRunResult = new AutoTuneApplyResult();
        dryRunResult.setDryRun(true);
        dryRunResult.setSuccess(true);
        when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class))).thenReturn(dryRunResult);
        AutoTuneApplyResult preview = tools.applyAutoTuneChanges("preview", 10, "agent", true, null,
                Collections.singletonList(change("sys_config", null, "conveyorStationTaskLimit", "12")));
        IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
                () -> tools.applyAutoTuneChanges("apply", 10, "agent", false, preview.getDryRunToken(),
                        Collections.singletonList(change("sys_config", null, "conveyorStationTaskLimit", "13"))));
        assertTrue(exception.getMessage().contains("does not match"));
        verify(autoTuneApplyService, times(1)).apply(any(AutoTuneApplyRequest.class));
    }
    @Test
@@ -152,7 +224,7 @@
        AiPromptTemplate promptTemplate = new AiPromptTemplate();
        promptTemplate.setContent("system prompt");
        when(aiPromptTemplateService.resolvePublished("wcs_auto_tune_dispatch")).thenReturn(promptTemplate);
        when(mcpToolManager.buildOpenAiTools()).thenReturn(Collections.singletonList(Collections.singletonMap("type", "function")));
        when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
        when(mcpToolManager.callTool(any(), any(JSONObject.class))).thenReturn(Collections.singletonMap("ok", true));
        when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
                .thenReturn(
@@ -177,13 +249,111 @@
        ArgumentCaptor<String> toolNameCaptor = ArgumentCaptor.forClass(String.class);
        ArgumentCaptor<JSONObject> argumentCaptor = ArgumentCaptor.forClass(JSONObject.class);
        verify(mcpToolManager, org.mockito.Mockito.times(3)).callTool(toolNameCaptor.capture(), argumentCaptor.capture());
        verify(mcpToolManager, times(3)).callTool(toolNameCaptor.capture(), argumentCaptor.capture());
        assertEquals("wcs_local_dispatch_get_auto_tune_snapshot", toolNameCaptor.getAllValues().get(0));
        assertEquals("wcs_local_dispatch_apply_auto_tune_changes", toolNameCaptor.getAllValues().get(1));
        assertEquals(Boolean.TRUE, argumentCaptor.getAllValues().get(1).getBoolean("dryRun"));
        assertEquals(Boolean.FALSE, argumentCaptor.getAllValues().get(2).getBoolean("dryRun"));
    }
    @Test
    void agentFailsAndDoesNotExecuteDisallowedToolCall() {
        AutoTuneAgentServiceImpl service = agentService();
        when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
        when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
                .thenReturn(response("bad tool", toolCall("call_1", "wcs_local_device_get_crn_status", "{}"), 10, 5));
        AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");
        assertFalse(result.getSuccess());
        assertEquals(0, result.getToolCallCount());
        assertTrue(result.getSummary().contains("Disallowed auto-tune MCP tool"));
        verify(mcpToolManager, never()).callTool(any(), any(JSONObject.class));
    }
    @Test
    void agentFailsWhenAllowedToolThrows() {
        AutoTuneAgentServiceImpl service = agentService();
        when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
        when(mcpToolManager.callTool(any(), any(JSONObject.class))).thenThrow(new RuntimeException("boom"));
        when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
                .thenReturn(response("snapshot", toolCall("call_1", "wcs_local_dispatch_get_auto_tune_snapshot", "{}"), 10, 5));
        AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");
        assertFalse(result.getSuccess());
        assertEquals(0, result.getToolCallCount());
        assertTrue(result.getSummary().contains("Auto-tune MCP tool failed"));
        verify(mcpToolManager).callTool(any(), any(JSONObject.class));
    }
    @Test
    void agentFailsWhenLlmReturnsNoToolCalls() {
        AutoTuneAgentServiceImpl service = agentService();
        when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
        when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
                .thenReturn(response("no changes needed", null, 10, 5));
        AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");
        assertFalse(result.getSuccess());
        assertEquals(0, result.getToolCallCount());
        assertTrue(result.getSummary().contains("未调用任何允许的 MCP 工具"));
    }
    @Test
    void agentFailsWhenMaxRoundsReached() {
        AutoTuneAgentServiceImpl service = agentService();
        when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
        when(mcpToolManager.callTool(any(), any(JSONObject.class))).thenReturn(Collections.singletonMap("ok", true));
        when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
                .thenReturn(response("keep going", toolCall("call_1", "wcs_local_dispatch_get_auto_tune_snapshot", "{}"), 10, 5));
        AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");
        assertFalse(result.getSuccess());
        assertEquals(10, result.getToolCallCount());
        assertTrue(result.getMaxRoundsReached());
        assertTrue(result.getSummary().contains("达到最大工具调用轮次"));
    }
    private AutoTuneAgentServiceImpl agentService() {
        AiPromptTemplate promptTemplate = new AiPromptTemplate();
        promptTemplate.setContent("system prompt");
        when(aiPromptTemplateService.resolvePublished("wcs_auto_tune_dispatch")).thenReturn(promptTemplate);
        return new AutoTuneAgentServiceImpl(llmChatService, mcpToolManager, aiPromptTemplateService);
    }
    private AutoTuneChangeCommand change(String targetType, String targetId, String targetKey, String newValue) {
        AutoTuneChangeCommand command = new AutoTuneChangeCommand();
        command.setTargetType(targetType);
        command.setTargetId(targetId);
        command.setTargetKey(targetKey);
        command.setNewValue(newValue);
        return command;
    }
    private List<Object> allowedOpenAiTools() {
        List<Object> tools = new ArrayList<>();
        tools.add(openAiTool("wcs_local_dispatch_get_auto_tune_snapshot"));
        tools.add(openAiTool("wcs_local_dispatch_get_recent_auto_tune_jobs"));
        tools.add(openAiTool("wcs_local_dispatch_apply_auto_tune_changes"));
        tools.add(openAiTool("wcs_local_dispatch_revert_last_auto_tune_job"));
        tools.add(openAiTool("wcs_local_device_get_crn_status"));
        return tools;
    }
    private Map<String, Object> openAiTool(String name) {
        LinkedHashMap<String, Object> function = new LinkedHashMap<>();
        function.put("name", name);
        function.put("parameters", Collections.emptyMap());
        LinkedHashMap<String, Object> tool = new LinkedHashMap<>();
        tool.put("type", "function");
        tool.put("function", function);
        return tool;
    }
    private ChatCompletionResponse response(String content,
                                            ChatCompletionRequest.ToolCall toolCall,
                                            int promptTokens,