From abe9409ecbe5ac752dd6f14100733d1b3739ac09 Mon Sep 17 00:00:00 2001
From: Junjie <fallin.jie@qq.com>
Date: 星期一, 27 四月 2026 12:14:36 +0800
Subject: [PATCH] test: cover auto tune agent safety paths
---
src/test/java/com/zy/ai/service/AutoTuneCoordinatorServiceImplTest.java | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--
1 files changed, 58 insertions(+), 2 deletions(-)
diff --git a/src/test/java/com/zy/ai/service/AutoTuneCoordinatorServiceImplTest.java b/src/test/java/com/zy/ai/service/AutoTuneCoordinatorServiceImplTest.java
index 0aa2f86..17d087f 100644
--- a/src/test/java/com/zy/ai/service/AutoTuneCoordinatorServiceImplTest.java
+++ b/src/test/java/com/zy/ai/service/AutoTuneCoordinatorServiceImplTest.java
@@ -20,12 +20,15 @@
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
+import org.springframework.test.util.ReflectionTestUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.LongSupplier;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
@@ -205,6 +208,26 @@
}
@Test
+ void applyToolRejectsExpiredDryRunToken() {
+ AtomicLong currentTimeMillis = new AtomicLong(1_000L);
+ ReflectionTestUtils.invokeMethod(tools, "setCurrentTimeMillisSupplier", (LongSupplier) currentTimeMillis::get);
+ AutoTuneApplyResult dryRunResult = new AutoTuneApplyResult();
+ dryRunResult.setDryRun(true);
+ dryRunResult.setSuccess(true);
+ when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class))).thenReturn(dryRunResult);
+ List<AutoTuneChangeCommand> changes = Collections.singletonList(
+ change("sys_config", null, "conveyorStationTaskLimit", "12"));
+
+ AutoTuneApplyResult preview = tools.applyAutoTuneChanges("preview", 10, "agent", true, null, changes);
+ currentTimeMillis.addAndGet(10L * 60L * 1000L + 1L);
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> tools.applyAutoTuneChanges("apply", 10, "agent", false, preview.getDryRunToken(), changes));
+
+ assertTrue(exception.getMessage().contains("expired"));
+ verify(autoTuneApplyService, times(1)).apply(any(AutoTuneApplyRequest.class));
+ }
+
+ @Test
void rollbackToolDelegatesToApplyServiceRollback() {
AutoTuneApplyResult expected = new AutoTuneApplyResult();
when(autoTuneApplyService.rollbackLastSuccessfulJob("bad result")).thenReturn(expected);
@@ -225,14 +248,26 @@
promptTemplate.setContent("system prompt");
when(aiPromptTemplateService.resolvePublished("wcs_auto_tune_dispatch")).thenReturn(promptTemplate);
when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
- when(mcpToolManager.callTool(any(), any(JSONObject.class))).thenReturn(Collections.singletonMap("ok", true));
+ when(mcpToolManager.callTool(any(), any(JSONObject.class))).thenAnswer(invocation -> {
+ String toolName = invocation.getArgument(0);
+ JSONObject arguments = invocation.getArgument(1);
+ if ("wcs_local_dispatch_apply_auto_tune_changes".equals(toolName)
+ && Boolean.TRUE.equals(arguments.getBoolean("dryRun"))) {
+ LinkedHashMap<String, Object> dryRunOutput = new LinkedHashMap<>();
+ dryRunOutput.put("success", true);
+ dryRunOutput.put("dryRun", true);
+ dryRunOutput.put("dryRunToken", "token-123");
+ return dryRunOutput;
+ }
+ return Collections.singletonMap("ok", true);
+ });
when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
.thenReturn(
response("read snapshot", toolCall("call_1", "wcs_local_dispatch_get_auto_tune_snapshot", "{}"), 10, 5),
response("dry-run first", toolCall("call_2", "wcs_local_dispatch_apply_auto_tune_changes",
"{\"dryRun\":true,\"changes\":[{\"targetType\":\"sys_config\",\"targetKey\":\"conveyorStationTaskLimit\",\"newValue\":\"12\"}]}"), 11, 6),
response("apply after dry-run", toolCall("call_3", "wcs_local_dispatch_apply_auto_tune_changes",
- "{\"dryRun\":false,\"changes\":[{\"targetType\":\"sys_config\",\"targetKey\":\"conveyorStationTaskLimit\",\"newValue\":\"12\"}]}"), 12, 7),
+ "{\"dryRun\":false,\"dryRunToken\":\"token-123\",\"changes\":[{\"targetType\":\"sys_config\",\"targetKey\":\"conveyorStationTaskLimit\",\"newValue\":\"12\"}]}"), 12, 7),
response("宸插畬鎴愯嚜鍔ㄨ皟鍙�", null, 13, 8)
);
@@ -254,6 +289,17 @@
assertEquals("wcs_local_dispatch_apply_auto_tune_changes", toolNameCaptor.getAllValues().get(1));
assertEquals(Boolean.TRUE, argumentCaptor.getAllValues().get(1).getBoolean("dryRun"));
assertEquals(Boolean.FALSE, argumentCaptor.getAllValues().get(2).getBoolean("dryRun"));
+ assertEquals("token-123", argumentCaptor.getAllValues().get(2).getString("dryRunToken"));
+
+ ArgumentCaptor<List<Object>> toolsCaptor = ArgumentCaptor.forClass(List.class);
+ verify(llmChatService, times(4)).chatCompletion(any(), anyDouble(), anyInt(), toolsCaptor.capture());
+ List<String> visibleToolNames = toolNames(toolsCaptor.getAllValues().get(0));
+ assertEquals(4, visibleToolNames.size());
+ assertTrue(visibleToolNames.contains("wcs_local_dispatch_get_auto_tune_snapshot"));
+ assertTrue(visibleToolNames.contains("wcs_local_dispatch_get_recent_auto_tune_jobs"));
+ assertTrue(visibleToolNames.contains("wcs_local_dispatch_apply_auto_tune_changes"));
+ assertTrue(visibleToolNames.contains("wcs_local_dispatch_revert_last_auto_tune_job"));
+ assertFalse(visibleToolNames.contains("wcs_local_device_get_crn_status"));
}
@Test
@@ -354,6 +400,16 @@
return tool;
}
+ private List<String> toolNames(List<Object> tools) {
+ List<String> names = new ArrayList<>();
+ for (Object tool : tools) {
+ Map<?, ?> toolMap = (Map<?, ?>) tool;
+ Map<?, ?> function = (Map<?, ?>) toolMap.get("function");
+ names.add(String.valueOf(function.get("name")));
+ }
+ return names;
+ }
+
private ChatCompletionResponse response(String content,
ChatCompletionRequest.ToolCall toolCall,
int promptTokens,
--
Gitblit v1.9.1