From 10bdc4b6e9701befd1a83bccd2998dcc96cb2c43 Mon Sep 17 00:00:00 2001
From: Junjie <fallin.jie@qq.com>
Date: 星期一, 27 四月 2026 12:07:38 +0800
Subject: [PATCH] fix: enforce auto tune agent tool safety

---
 src/test/java/com/zy/ai/service/AutoTuneCoordinatorServiceImplTest.java |  176 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 files changed, 173 insertions(+), 3 deletions(-)

diff --git a/src/test/java/com/zy/ai/service/AutoTuneCoordinatorServiceImplTest.java b/src/test/java/com/zy/ai/service/AutoTuneCoordinatorServiceImplTest.java
index 9110d12..0aa2f86 100644
--- a/src/test/java/com/zy/ai/service/AutoTuneCoordinatorServiceImplTest.java
+++ b/src/test/java/com/zy/ai/service/AutoTuneCoordinatorServiceImplTest.java
@@ -23,16 +23,21 @@
 
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
 import static org.junit.jupiter.api.Assertions.assertSame;
+import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyDouble;
 import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
@@ -112,6 +117,7 @@
     void applyToolDelegatesToApplyServiceWithDryRunAndChanges() {
         AutoTuneApplyResult expected = new AutoTuneApplyResult();
         expected.setDryRun(true);
+        expected.setSuccess(true);
         when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class))).thenReturn(expected);
 
         AutoTuneChangeCommand command = new AutoTuneChangeCommand();
@@ -120,9 +126,10 @@
         command.setNewValue("12");
         List<AutoTuneChangeCommand> changes = Collections.singletonList(command);
 
-        AutoTuneApplyResult result = tools.applyAutoTuneChanges("reduce congestion", 10, "agent", true, changes);
+        AutoTuneApplyResult result = tools.applyAutoTuneChanges("reduce congestion", 10, "agent", true, null, changes);
 
         assertSame(expected, result);
+        assertNotNull(result.getDryRunToken());
         ArgumentCaptor<AutoTuneApplyRequest> captor = ArgumentCaptor.forClass(AutoTuneApplyRequest.class);
         verify(autoTuneApplyService).apply(captor.capture());
         assertEquals("reduce congestion", captor.getValue().getReason());
@@ -130,6 +137,71 @@
         assertEquals("agent", captor.getValue().getTriggerType());
         assertEquals(Boolean.TRUE, captor.getValue().getDryRun());
         assertSame(changes, captor.getValue().getChanges());
+    }
+
+    @Test
+    void applyToolRejectsMissingDryRunBeforeServiceCall() {
+        AutoTuneChangeCommand command = change("sys_config", null, "conveyorStationTaskLimit", "12");
+
+        IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+                () -> tools.applyAutoTuneChanges("missing dryRun", 10, "agent", null, null,
+                        Collections.singletonList(command)));
+
+        assertTrue(exception.getMessage().contains("dryRun is required"));
+        verify(autoTuneApplyService, never()).apply(any(AutoTuneApplyRequest.class));
+    }
+
+    @Test
+    void applyToolRejectsDirectRealApplyWithoutDryRunToken() {
+        AutoTuneChangeCommand command = change("sys_config", null, "conveyorStationTaskLimit", "12");
+
+        IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+                () -> tools.applyAutoTuneChanges("direct apply", 10, "agent", false, null,
+                        Collections.singletonList(command)));
+
+        assertTrue(exception.getMessage().contains("dryRunToken is required"));
+        verify(autoTuneApplyService, never()).apply(any(AutoTuneApplyRequest.class));
+    }
+
+    @Test
+    void applyToolAllowsRealApplyOnlyWithMatchingDryRunToken() {
+        AutoTuneApplyResult dryRunResult = new AutoTuneApplyResult();
+        dryRunResult.setDryRun(true);
+        dryRunResult.setSuccess(true);
+        AutoTuneApplyResult applyResult = new AutoTuneApplyResult();
+        applyResult.setDryRun(false);
+        applyResult.setSuccess(true);
+        when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class))).thenReturn(dryRunResult, applyResult);
+        AutoTuneChangeCommand command = change(" sys_config ", "ignored", " conveyorStationTaskLimit ", " 12 ");
+        List<AutoTuneChangeCommand> changes = Collections.singletonList(command);
+
+        AutoTuneApplyResult preview = tools.applyAutoTuneChanges("preview", 10, "agent", true, null, changes);
+        AutoTuneApplyResult applied = tools.applyAutoTuneChanges("apply", 10, "agent", false,
+                preview.getDryRunToken(), changes);
+
+        assertSame(applyResult, applied);
+        ArgumentCaptor<AutoTuneApplyRequest> captor = ArgumentCaptor.forClass(AutoTuneApplyRequest.class);
+        verify(autoTuneApplyService, times(2)).apply(captor.capture());
+        assertEquals(Boolean.TRUE, captor.getAllValues().get(0).getDryRun());
+        assertEquals(Boolean.FALSE, captor.getAllValues().get(1).getDryRun());
+    }
+
+    @Test
+    void applyToolRejectsMismatchedDryRunToken() {
+        AutoTuneApplyResult dryRunResult = new AutoTuneApplyResult();
+        dryRunResult.setDryRun(true);
+        dryRunResult.setSuccess(true);
+        when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class))).thenReturn(dryRunResult);
+
+        AutoTuneApplyResult preview = tools.applyAutoTuneChanges("preview", 10, "agent", true, null,
+                Collections.singletonList(change("sys_config", null, "conveyorStationTaskLimit", "12")));
+
+        IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+                () -> tools.applyAutoTuneChanges("apply", 10, "agent", false, preview.getDryRunToken(),
+                        Collections.singletonList(change("sys_config", null, "conveyorStationTaskLimit", "13"))));
+
+        assertTrue(exception.getMessage().contains("does not match"));
+        verify(autoTuneApplyService, times(1)).apply(any(AutoTuneApplyRequest.class));
     }
 
     @Test
@@ -152,7 +224,7 @@
         AiPromptTemplate promptTemplate = new AiPromptTemplate();
         promptTemplate.setContent("system prompt");
         when(aiPromptTemplateService.resolvePublished("wcs_auto_tune_dispatch")).thenReturn(promptTemplate);
-        when(mcpToolManager.buildOpenAiTools()).thenReturn(Collections.singletonList(Collections.singletonMap("type", "function")));
+        when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
         when(mcpToolManager.callTool(any(), any(JSONObject.class))).thenReturn(Collections.singletonMap("ok", true));
         when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
                 .thenReturn(
@@ -177,13 +249,111 @@
 
         ArgumentCaptor<String> toolNameCaptor = ArgumentCaptor.forClass(String.class);
         ArgumentCaptor<JSONObject> argumentCaptor = ArgumentCaptor.forClass(JSONObject.class);
-        verify(mcpToolManager, org.mockito.Mockito.times(3)).callTool(toolNameCaptor.capture(), argumentCaptor.capture());
+        verify(mcpToolManager, times(3)).callTool(toolNameCaptor.capture(), argumentCaptor.capture());
         assertEquals("wcs_local_dispatch_get_auto_tune_snapshot", toolNameCaptor.getAllValues().get(0));
         assertEquals("wcs_local_dispatch_apply_auto_tune_changes", toolNameCaptor.getAllValues().get(1));
         assertEquals(Boolean.TRUE, argumentCaptor.getAllValues().get(1).getBoolean("dryRun"));
         assertEquals(Boolean.FALSE, argumentCaptor.getAllValues().get(2).getBoolean("dryRun"));
     }
 
+    @Test
+    void agentFailsAndDoesNotExecuteDisallowedToolCall() {
+        AutoTuneAgentServiceImpl service = agentService();
+        when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
+        when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
+                .thenReturn(response("bad tool", toolCall("call_1", "wcs_local_device_get_crn_status", "{}"), 10, 5));
+
+        AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");
+
+        assertFalse(result.getSuccess());
+        assertEquals(0, result.getToolCallCount());
+        assertTrue(result.getSummary().contains("Disallowed auto-tune MCP tool"));
+        verify(mcpToolManager, never()).callTool(any(), any(JSONObject.class));
+    }
+
+    @Test
+    void agentFailsWhenAllowedToolThrows() {
+        AutoTuneAgentServiceImpl service = agentService();
+        when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
+        when(mcpToolManager.callTool(any(), any(JSONObject.class))).thenThrow(new RuntimeException("boom"));
+        when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
+                .thenReturn(response("snapshot", toolCall("call_1", "wcs_local_dispatch_get_auto_tune_snapshot", "{}"), 10, 5));
+
+        AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");
+
+        assertFalse(result.getSuccess());
+        assertEquals(0, result.getToolCallCount());
+        assertTrue(result.getSummary().contains("Auto-tune MCP tool failed"));
+        verify(mcpToolManager).callTool(any(), any(JSONObject.class));
+    }
+
+    @Test
+    void agentFailsWhenLlmReturnsNoToolCalls() {
+        AutoTuneAgentServiceImpl service = agentService();
+        when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
+        when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
+                .thenReturn(response("no changes needed", null, 10, 5));
+
+        AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");
+
+        assertFalse(result.getSuccess());
+        assertEquals(0, result.getToolCallCount());
+        assertTrue(result.getSummary().contains("鏈皟鐢ㄤ换浣曞厑璁哥殑 MCP 宸ュ叿"));
+    }
+
+    @Test
+    void agentFailsWhenMaxRoundsReached() {
+        AutoTuneAgentServiceImpl service = agentService();
+        when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
+        when(mcpToolManager.callTool(any(), any(JSONObject.class))).thenReturn(Collections.singletonMap("ok", true));
+        when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
+                .thenReturn(response("keep going", toolCall("call_1", "wcs_local_dispatch_get_auto_tune_snapshot", "{}"), 10, 5));
+
+        AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");
+
+        assertFalse(result.getSuccess());
+        assertEquals(10, result.getToolCallCount());
+        assertTrue(result.getMaxRoundsReached());
+        assertTrue(result.getSummary().contains("杈惧埌鏈�澶у伐鍏疯皟鐢ㄨ疆娆�"));
+    }
+
+    private AutoTuneAgentServiceImpl agentService() {
+        AiPromptTemplate promptTemplate = new AiPromptTemplate();
+        promptTemplate.setContent("system prompt");
+        when(aiPromptTemplateService.resolvePublished("wcs_auto_tune_dispatch")).thenReturn(promptTemplate);
+        return new AutoTuneAgentServiceImpl(llmChatService, mcpToolManager, aiPromptTemplateService);
+    }
+
+    private AutoTuneChangeCommand change(String targetType, String targetId, String targetKey, String newValue) {
+        AutoTuneChangeCommand command = new AutoTuneChangeCommand();
+        command.setTargetType(targetType);
+        command.setTargetId(targetId);
+        command.setTargetKey(targetKey);
+        command.setNewValue(newValue);
+        return command;
+    }
+
+    private List<Object> allowedOpenAiTools() {
+        List<Object> tools = new ArrayList<>();
+        tools.add(openAiTool("wcs_local_dispatch_get_auto_tune_snapshot"));
+        tools.add(openAiTool("wcs_local_dispatch_get_recent_auto_tune_jobs"));
+        tools.add(openAiTool("wcs_local_dispatch_apply_auto_tune_changes"));
+        tools.add(openAiTool("wcs_local_dispatch_revert_last_auto_tune_job"));
+        tools.add(openAiTool("wcs_local_device_get_crn_status"));
+        return tools;
+    }
+
+    private Map<String, Object> openAiTool(String name) {
+        LinkedHashMap<String, Object> function = new LinkedHashMap<>();
+        function.put("name", name);
+        function.put("parameters", Collections.emptyMap());
+
+        LinkedHashMap<String, Object> tool = new LinkedHashMap<>();
+        tool.put("type", "function");
+        tool.put("function", function);
+        return tool;
+    }
+
     private ChatCompletionResponse response(String content,
                                             ChatCompletionRequest.ToolCall toolCall,
                                             int promptTokens,

--
Gitblit v1.9.1