From 10bdc4b6e9701befd1a83bccd2998dcc96cb2c43 Mon Sep 17 00:00:00 2001
From: Junjie <fallin.jie@qq.com>
Date: 星期一, 27 四月 2026 12:07:38 +0800
Subject: [PATCH] fix: enforce auto tune agent tool safety
---
src/test/java/com/zy/ai/service/AutoTuneCoordinatorServiceImplTest.java | 176 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
1 files changed, 173 insertions(+), 3 deletions(-)
diff --git a/src/test/java/com/zy/ai/service/AutoTuneCoordinatorServiceImplTest.java b/src/test/java/com/zy/ai/service/AutoTuneCoordinatorServiceImplTest.java
index 9110d12..0aa2f86 100644
--- a/src/test/java/com/zy/ai/service/AutoTuneCoordinatorServiceImplTest.java
+++ b/src/test/java/com/zy/ai/service/AutoTuneCoordinatorServiceImplTest.java
@@ -23,16 +23,21 @@
import java.util.ArrayList;
import java.util.Collections;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertSame;
+import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyDouble;
import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -112,6 +117,7 @@
void applyToolDelegatesToApplyServiceWithDryRunAndChanges() {
AutoTuneApplyResult expected = new AutoTuneApplyResult();
expected.setDryRun(true);
+ expected.setSuccess(true);
when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class))).thenReturn(expected);
AutoTuneChangeCommand command = new AutoTuneChangeCommand();
@@ -120,9 +126,10 @@
command.setNewValue("12");
List<AutoTuneChangeCommand> changes = Collections.singletonList(command);
- AutoTuneApplyResult result = tools.applyAutoTuneChanges("reduce congestion", 10, "agent", true, changes);
+ AutoTuneApplyResult result = tools.applyAutoTuneChanges("reduce congestion", 10, "agent", true, null, changes);
assertSame(expected, result);
+ assertNotNull(result.getDryRunToken());
ArgumentCaptor<AutoTuneApplyRequest> captor = ArgumentCaptor.forClass(AutoTuneApplyRequest.class);
verify(autoTuneApplyService).apply(captor.capture());
assertEquals("reduce congestion", captor.getValue().getReason());
@@ -130,6 +137,71 @@
assertEquals("agent", captor.getValue().getTriggerType());
assertEquals(Boolean.TRUE, captor.getValue().getDryRun());
assertSame(changes, captor.getValue().getChanges());
+ }
+
+ /**
+  * A call that passes {@code null} for the dryRun flag must be rejected with an
+  * {@link IllegalArgumentException}, and the apply service must never be invoked.
+  */
+ @Test
+ void applyToolRejectsMissingDryRunBeforeServiceCall() {
+     List<AutoTuneChangeCommand> changes = Collections.singletonList(
+             change("sys_config", null, "conveyorStationTaskLimit", "12"));
+
+     IllegalArgumentException rejection = assertThrows(
+             IllegalArgumentException.class,
+             () -> tools.applyAutoTuneChanges("missing dryRun", 10, "agent", null, null, changes));
+
+     String message = rejection.getMessage();
+     assertTrue(message.contains("dryRun is required"));
+     verify(autoTuneApplyService, never()).apply(any(AutoTuneApplyRequest.class));
+ }
+
+ /**
+  * A real apply ({@code dryRun == false}) submitted with a {@code null} dry-run token
+  * must be rejected with an {@link IllegalArgumentException}, and the apply service
+  * must never be invoked.
+  */
+ @Test
+ void applyToolRejectsDirectRealApplyWithoutDryRunToken() {
+     List<AutoTuneChangeCommand> changes = Collections.singletonList(
+             change("sys_config", null, "conveyorStationTaskLimit", "12"));
+
+     IllegalArgumentException rejection = assertThrows(
+             IllegalArgumentException.class,
+             () -> tools.applyAutoTuneChanges("direct apply", 10, "agent", false, null, changes));
+
+     String message = rejection.getMessage();
+     assertTrue(message.contains("dryRunToken is required"));
+     verify(autoTuneApplyService, never()).apply(any(AutoTuneApplyRequest.class));
+ }
+
+ /**
+  * Runs a dry-run preview and then a real apply that presents the token returned by the
+  * preview, using the same change list for both, and checks that both requests reach the
+  * apply service (first with dryRun TRUE, then with dryRun FALSE).
+  */
+ @Test
+ void applyToolAllowsRealApplyOnlyWithMatchingDryRunToken() {
+     AutoTuneApplyResult previewStub = new AutoTuneApplyResult();
+     previewStub.setDryRun(true);
+     previewStub.setSuccess(true);
+     AutoTuneApplyResult applyStub = new AutoTuneApplyResult();
+     applyStub.setDryRun(false);
+     applyStub.setSuccess(true);
+     // The first service call answers with the preview stub, the second with the apply stub.
+     when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class))).thenReturn(previewStub, applyStub);
+     // The change values deliberately carry surrounding whitespace and a non-null targetId.
+     List<AutoTuneChangeCommand> changes = Collections.singletonList(
+             change(" sys_config ", "ignored", " conveyorStationTaskLimit ", " 12 "));
+
+     AutoTuneApplyResult preview = tools.applyAutoTuneChanges("preview", 10, "agent", true, null, changes);
+     AutoTuneApplyResult applied = tools.applyAutoTuneChanges(
+             "apply", 10, "agent", false, preview.getDryRunToken(), changes);
+
+     assertSame(applyStub, applied);
+     ArgumentCaptor<AutoTuneApplyRequest> requests = ArgumentCaptor.forClass(AutoTuneApplyRequest.class);
+     verify(autoTuneApplyService, times(2)).apply(requests.capture());
+     List<AutoTuneApplyRequest> forwarded = requests.getAllValues();
+     assertEquals(Boolean.TRUE, forwarded.get(0).getDryRun());
+     assertEquals(Boolean.FALSE, forwarded.get(1).getDryRun());
+ }
+
+ /**
+  * A token obtained from a preview of one change set (newValue "12") must not authorise
+  * a real apply of a different change set (newValue "13"): the call is rejected with an
+  * {@link IllegalArgumentException} and only the preview reaches the apply service.
+  */
+ @Test
+ void applyToolRejectsMismatchedDryRunToken() {
+     AutoTuneApplyResult previewStub = new AutoTuneApplyResult();
+     previewStub.setDryRun(true);
+     previewStub.setSuccess(true);
+     when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class))).thenReturn(previewStub);
+     List<AutoTuneChangeCommand> previewed = Collections.singletonList(
+             change("sys_config", null, "conveyorStationTaskLimit", "12"));
+     List<AutoTuneChangeCommand> altered = Collections.singletonList(
+             change("sys_config", null, "conveyorStationTaskLimit", "13"));
+
+     AutoTuneApplyResult preview = tools.applyAutoTuneChanges("preview", 10, "agent", true, null, previewed);
+
+     IllegalArgumentException rejection = assertThrows(
+             IllegalArgumentException.class,
+             () -> tools.applyAutoTuneChanges("apply", 10, "agent", false, preview.getDryRunToken(), altered));
+
+     String message = rejection.getMessage();
+     assertTrue(message.contains("does not match"));
+     // Exactly one service call is expected: the preview. The rejected apply adds none.
+     verify(autoTuneApplyService, times(1)).apply(any(AutoTuneApplyRequest.class));
}
@Test
@@ -152,7 +224,7 @@
AiPromptTemplate promptTemplate = new AiPromptTemplate();
promptTemplate.setContent("system prompt");
when(aiPromptTemplateService.resolvePublished("wcs_auto_tune_dispatch")).thenReturn(promptTemplate);
- when(mcpToolManager.buildOpenAiTools()).thenReturn(Collections.singletonList(Collections.singletonMap("type", "function")));
+ when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
when(mcpToolManager.callTool(any(), any(JSONObject.class))).thenReturn(Collections.singletonMap("ok", true));
when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
.thenReturn(
@@ -177,13 +249,111 @@
ArgumentCaptor<String> toolNameCaptor = ArgumentCaptor.forClass(String.class);
ArgumentCaptor<JSONObject> argumentCaptor = ArgumentCaptor.forClass(JSONObject.class);
- verify(mcpToolManager, org.mockito.Mockito.times(3)).callTool(toolNameCaptor.capture(), argumentCaptor.capture());
+ verify(mcpToolManager, times(3)).callTool(toolNameCaptor.capture(), argumentCaptor.capture());
assertEquals("wcs_local_dispatch_get_auto_tune_snapshot", toolNameCaptor.getAllValues().get(0));
assertEquals("wcs_local_dispatch_apply_auto_tune_changes", toolNameCaptor.getAllValues().get(1));
assertEquals(Boolean.TRUE, argumentCaptor.getAllValues().get(1).getBoolean("dryRun"));
assertEquals(Boolean.FALSE, argumentCaptor.getAllValues().get(2).getBoolean("dryRun"));
}
+ /**
+  * When the LLM requests {@code wcs_local_device_get_crn_status}, the run must report
+  * failure with a "Disallowed auto-tune MCP tool" summary, count zero tool calls, and
+  * never invoke any tool on the MCP tool manager.
+  */
+ @Test
+ void agentFailsAndDoesNotExecuteDisallowedToolCall() {
+     AutoTuneAgentServiceImpl agent = agentService();
+     when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
+     ChatCompletionResponse llmReply = response(
+             "bad tool", toolCall("call_1", "wcs_local_device_get_crn_status", "{}"), 10, 5);
+     when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any())).thenReturn(llmReply);
+
+     AutoTuneAgentService.AutoTuneAgentResult outcome = agent.runAutoTune("scheduler");
+
+     assertFalse(outcome.getSuccess());
+     assertEquals(0, outcome.getToolCallCount());
+     String summary = outcome.getSummary();
+     assertTrue(summary.contains("Disallowed auto-tune MCP tool"));
+     verify(mcpToolManager, never()).callTool(any(), any(JSONObject.class));
+ }
+
+ /**
+  * When the requested snapshot tool throws from {@code callTool}, the run must report
+  * failure with an "Auto-tune MCP tool failed" summary. The tool is attempted exactly
+  * once, yet the reported tool call count is expected to stay at zero.
+  */
+ @Test
+ void agentFailsWhenAllowedToolThrows() {
+     AutoTuneAgentServiceImpl agent = agentService();
+     when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
+     when(mcpToolManager.callTool(any(), any(JSONObject.class))).thenThrow(new RuntimeException("boom"));
+     ChatCompletionResponse llmReply = response(
+             "snapshot", toolCall("call_1", "wcs_local_dispatch_get_auto_tune_snapshot", "{}"), 10, 5);
+     when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any())).thenReturn(llmReply);
+
+     AutoTuneAgentService.AutoTuneAgentResult outcome = agent.runAutoTune("scheduler");
+
+     assertFalse(outcome.getSuccess());
+     assertEquals(0, outcome.getToolCallCount());
+     String summary = outcome.getSummary();
+     assertTrue(summary.contains("Auto-tune MCP tool failed"));
+     verify(mcpToolManager).callTool(any(), any(JSONObject.class));
+ }
+
+ /**
+  * An LLM reply that carries no tool call must make the run report failure with zero
+  * tool calls and a summary stating that no allowed MCP tool was invoked.
+  */
+ @Test
+ void agentFailsWhenLlmReturnsNoToolCalls() {
+     AutoTuneAgentServiceImpl service = agentService();
+     when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
+     when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
+             .thenReturn(response("no changes needed", null, 10, 5));
+
+     AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");
+
+     assertFalse(result.getSuccess());
+     assertEquals(0, result.getToolCallCount());
+     // Expected text means "no allowed MCP tool was called". The literal is UTF-8 Chinese;
+     // this file must be read and compiled as UTF-8 or the comparison breaks.
+     assertTrue(result.getSummary().contains("未调用任何允许的 MCP 工具"));
+ }
+
+ /**
+  * When every LLM reply keeps requesting the snapshot tool, the run must report failure,
+  * flag that the maximum number of rounds was reached, and count 10 tool calls.
+  */
+ @Test
+ void agentFailsWhenMaxRoundsReached() {
+     AutoTuneAgentServiceImpl service = agentService();
+     when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
+     when(mcpToolManager.callTool(any(), any(JSONObject.class))).thenReturn(Collections.singletonMap("ok", true));
+     // The same tool-calling reply is returned on every round, so the agent never finishes on its own.
+     when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
+             .thenReturn(response("keep going", toolCall("call_1", "wcs_local_dispatch_get_auto_tune_snapshot", "{}"), 10, 5));
+
+     AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");
+
+     assertFalse(result.getSuccess());
+     // NOTE(review): 10 presumably mirrors the agent's round limit - confirm against AutoTuneAgentServiceImpl.
+     assertEquals(10, result.getToolCallCount());
+     assertTrue(result.getMaxRoundsReached());
+     // Expected text means "reached the maximum number of tool-call rounds". The literal is
+     // UTF-8 Chinese; this file must be read and compiled as UTF-8 or the comparison breaks.
+     assertTrue(result.getSummary().contains("达到最大工具调用轮次"));
+ }
+
+ /**
+  * Stubs the published {@code wcs_auto_tune_dispatch} prompt template with fixed content
+  * and returns a new agent wired to the mocked LLM, MCP tool manager and template service.
+  */
+ private AutoTuneAgentServiceImpl agentService() {
+     AiPromptTemplate template = new AiPromptTemplate();
+     template.setContent("system prompt");
+     when(aiPromptTemplateService.resolvePublished("wcs_auto_tune_dispatch")).thenReturn(template);
+
+     AutoTuneAgentServiceImpl agent =
+             new AutoTuneAgentServiceImpl(llmChatService, mcpToolManager, aiPromptTemplateService);
+     return agent;
+ }
+
+ /** Builds an {@link AutoTuneChangeCommand} populated with the four given values. */
+ private AutoTuneChangeCommand change(String targetType, String targetId, String targetKey, String newValue) {
+     AutoTuneChangeCommand built = new AutoTuneChangeCommand();
+     built.setTargetType(targetType);
+     built.setTargetId(targetId);
+     built.setTargetKey(targetKey);
+     built.setNewValue(newValue);
+     return built;
+ }
+
+ /**
+  * Builds the tool list that the mocked MCP tool manager reports from {@code buildOpenAiTools()}.
+  * The last entry, {@code wcs_local_device_get_crn_status}, is the tool that
+  * {@code agentFailsAndDoesNotExecuteDisallowedToolCall} expects the agent to refuse, so
+  * despite this method's name not every entry is treated as allowed by the agent.
+  */
+ private List<Object> allowedOpenAiTools() {
+     String[] toolNames = {
+             "wcs_local_dispatch_get_auto_tune_snapshot",
+             "wcs_local_dispatch_get_recent_auto_tune_jobs",
+             "wcs_local_dispatch_apply_auto_tune_changes",
+             "wcs_local_dispatch_revert_last_auto_tune_job",
+             "wcs_local_device_get_crn_status"
+     };
+     List<Object> definitions = new ArrayList<>();
+     for (String toolName : toolNames) {
+         definitions.add(openAiTool(toolName));
+     }
+     return definitions;
+ }
+
+ /**
+  * Builds one tool descriptor of the form
+  * {@code {type: "function", function: {name: <name>, parameters: {}}}}.
+  * Insertion-ordered maps are used so keys keep the order shown above.
+  */
+ private Map<String, Object> openAiTool(String name) {
+     Map<String, Object> functionSpec = new LinkedHashMap<>();
+     functionSpec.put("name", name);
+     functionSpec.put("parameters", Collections.emptyMap());
+
+     Map<String, Object> toolSpec = new LinkedHashMap<>();
+     toolSpec.put("type", "function");
+     toolSpec.put("function", functionSpec);
+     return toolSpec;
+ }
+
private ChatCompletionResponse response(String content,
ChatCompletionRequest.ToolCall toolCall,
int promptTokens,
--
Gitblit v1.9.1