Junjie
2026-04-27 10bdc4b6e9701befd1a83bccd2998dcc96cb2c43
src/test/java/com/zy/ai/service/AutoTuneCoordinatorServiceImplTest.java
@@ -23,16 +23,21 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyDouble;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -112,6 +117,7 @@
    // Dry-run call through the MCP tool: the apply service's result is returned as-is with a
    // dryRunToken attached, and reason / triggerType / dryRun / the exact changes list are
    // forwarded on the captured AutoTuneApplyRequest.
    void applyToolDelegatesToApplyServiceWithDryRunAndChanges() {
        AutoTuneApplyResult expected = new AutoTuneApplyResult();
        expected.setDryRun(true);
        expected.setSuccess(true);
        when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class))).thenReturn(expected);
        AutoTuneChangeCommand command = new AutoTuneChangeCommand();
@@ -120,9 +126,10 @@
        command.setNewValue("12");
        List<AutoTuneChangeCommand> changes = Collections.singletonList(command);
        // NOTE(review): this 5-argument call and the 6-argument call on the next line both declare
        // `result` in the same scope, which does not compile. This line looks like the pre-change
        // version left over from the diff (the other tests all use the 6-argument form) -- remove it.
        AutoTuneApplyResult result = tools.applyAutoTuneChanges("reduce congestion", 10, "agent", true, changes);
        // dryRun=true with no dryRunToken: a preview request.
        AutoTuneApplyResult result = tools.applyAutoTuneChanges("reduce congestion", 10, "agent", true, null, changes);
        assertSame(expected, result);
        // The tool layer is expected to attach a token that a later real apply can present.
        assertNotNull(result.getDryRunToken());
        ArgumentCaptor<AutoTuneApplyRequest> captor = ArgumentCaptor.forClass(AutoTuneApplyRequest.class);
        verify(autoTuneApplyService).apply(captor.capture());
        assertEquals("reduce congestion", captor.getValue().getReason());
@@ -130,6 +137,71 @@
        assertEquals("agent", captor.getValue().getTriggerType());
        assertEquals(Boolean.TRUE, captor.getValue().getDryRun());
        // Identity check: the caller's list instance is passed through, not copied.
        assertSame(changes, captor.getValue().getChanges());
    }
    @Test
    void applyToolRejectsMissingDryRunBeforeServiceCall() {
        AutoTuneChangeCommand command = change("sys_config", null, "conveyorStationTaskLimit", "12");
        IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
                () -> tools.applyAutoTuneChanges("missing dryRun", 10, "agent", null, null,
                        Collections.singletonList(command)));
        assertTrue(exception.getMessage().contains("dryRun is required"));
        verify(autoTuneApplyService, never()).apply(any(AutoTuneApplyRequest.class));
    }
    @Test
    void applyToolRejectsDirectRealApplyWithoutDryRunToken() {
        AutoTuneChangeCommand command = change("sys_config", null, "conveyorStationTaskLimit", "12");
        IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
                () -> tools.applyAutoTuneChanges("direct apply", 10, "agent", false, null,
                        Collections.singletonList(command)));
        assertTrue(exception.getMessage().contains("dryRunToken is required"));
        verify(autoTuneApplyService, never()).apply(any(AutoTuneApplyRequest.class));
    }
    @Test
    void applyToolAllowsRealApplyOnlyWithMatchingDryRunToken() {
        AutoTuneApplyResult dryRunResult = new AutoTuneApplyResult();
        dryRunResult.setDryRun(true);
        dryRunResult.setSuccess(true);
        AutoTuneApplyResult applyResult = new AutoTuneApplyResult();
        applyResult.setDryRun(false);
        applyResult.setSuccess(true);
        when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class))).thenReturn(dryRunResult, applyResult);
        AutoTuneChangeCommand command = change(" sys_config ", "ignored", " conveyorStationTaskLimit ", " 12 ");
        List<AutoTuneChangeCommand> changes = Collections.singletonList(command);
        AutoTuneApplyResult preview = tools.applyAutoTuneChanges("preview", 10, "agent", true, null, changes);
        AutoTuneApplyResult applied = tools.applyAutoTuneChanges("apply", 10, "agent", false,
                preview.getDryRunToken(), changes);
        assertSame(applyResult, applied);
        ArgumentCaptor<AutoTuneApplyRequest> captor = ArgumentCaptor.forClass(AutoTuneApplyRequest.class);
        verify(autoTuneApplyService, times(2)).apply(captor.capture());
        assertEquals(Boolean.TRUE, captor.getAllValues().get(0).getDryRun());
        assertEquals(Boolean.FALSE, captor.getAllValues().get(1).getDryRun());
    }
    @Test
    void applyToolRejectsMismatchedDryRunToken() {
        AutoTuneApplyResult dryRunResult = new AutoTuneApplyResult();
        dryRunResult.setDryRun(true);
        dryRunResult.setSuccess(true);
        when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class))).thenReturn(dryRunResult);
        AutoTuneApplyResult preview = tools.applyAutoTuneChanges("preview", 10, "agent", true, null,
                Collections.singletonList(change("sys_config", null, "conveyorStationTaskLimit", "12")));
        IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
                () -> tools.applyAutoTuneChanges("apply", 10, "agent", false, preview.getDryRunToken(),
                        Collections.singletonList(change("sys_config", null, "conveyorStationTaskLimit", "13"))));
        assertTrue(exception.getMessage().contains("does not match"));
        verify(autoTuneApplyService, times(1)).apply(any(AutoTuneApplyRequest.class));
    }
    // NOTE(review): this test's signature and first lines are outside the visible diff hunk. From
    // the visible part, it runs the agent through three MCP tool calls:
    //   1. the snapshot tool;
    //   2. the apply tool with dryRun=true;
    //   3. a third call whose dryRun argument is false.
    @Test
@@ -152,7 +224,7 @@
        AiPromptTemplate promptTemplate = new AiPromptTemplate();
        promptTemplate.setContent("system prompt");
        when(aiPromptTemplateService.resolvePublished("wcs_auto_tune_dispatch")).thenReturn(promptTemplate);
        // NOTE(review): buildOpenAiTools() is stubbed twice on the next two lines. Mockito answers
        // with the last stubbing, so the single-map stub is dead. It looks like the pre-change line
        // kept from the diff -- remove it.
        when(mcpToolManager.buildOpenAiTools()).thenReturn(Collections.singletonList(Collections.singletonMap("type", "function")));
        when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
        // Every tool call succeeds with a trivial payload.
        when(mcpToolManager.callTool(any(), any(JSONObject.class))).thenReturn(Collections.singletonMap("ok", true));
        when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
                .thenReturn(
@@ -177,13 +249,111 @@
        ArgumentCaptor<String> toolNameCaptor = ArgumentCaptor.forClass(String.class);
        ArgumentCaptor<JSONObject> argumentCaptor = ArgumentCaptor.forClass(JSONObject.class);
        // NOTE(review): the next two verify() lines are duplicates (fully-qualified vs. statically
        // imported times). Each capture() records the three invocations again, so the captors
        // accumulate the arguments twice. The index 0-2 assertions below still read the first
        // pass, but the fully-qualified line looks like a diff leftover and should be removed.
        verify(mcpToolManager, org.mockito.Mockito.times(3)).callTool(toolNameCaptor.capture(), argumentCaptor.capture());
        verify(mcpToolManager, times(3)).callTool(toolNameCaptor.capture(), argumentCaptor.capture());
        assertEquals("wcs_local_dispatch_get_auto_tune_snapshot", toolNameCaptor.getAllValues().get(0));
        assertEquals("wcs_local_dispatch_apply_auto_tune_changes", toolNameCaptor.getAllValues().get(1));
        // Preview first (dryRun=true), then the follow-up call with dryRun=false.
        assertEquals(Boolean.TRUE, argumentCaptor.getAllValues().get(1).getBoolean("dryRun"));
        assertEquals(Boolean.FALSE, argumentCaptor.getAllValues().get(2).getBoolean("dryRun"));
    }
    @Test
    void agentFailsAndDoesNotExecuteDisallowedToolCall() {
        AutoTuneAgentServiceImpl service = agentService();
        when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
        when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
                .thenReturn(response("bad tool", toolCall("call_1", "wcs_local_device_get_crn_status", "{}"), 10, 5));
        AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");
        assertFalse(result.getSuccess());
        assertEquals(0, result.getToolCallCount());
        assertTrue(result.getSummary().contains("Disallowed auto-tune MCP tool"));
        verify(mcpToolManager, never()).callTool(any(), any(JSONObject.class));
    }
    @Test
    void agentFailsWhenAllowedToolThrows() {
        AutoTuneAgentServiceImpl service = agentService();
        when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
        when(mcpToolManager.callTool(any(), any(JSONObject.class))).thenThrow(new RuntimeException("boom"));
        when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
                .thenReturn(response("snapshot", toolCall("call_1", "wcs_local_dispatch_get_auto_tune_snapshot", "{}"), 10, 5));
        AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");
        assertFalse(result.getSuccess());
        assertEquals(0, result.getToolCallCount());
        assertTrue(result.getSummary().contains("Auto-tune MCP tool failed"));
        verify(mcpToolManager).callTool(any(), any(JSONObject.class));
    }
    @Test
    void agentFailsWhenLlmReturnsNoToolCalls() {
        AutoTuneAgentServiceImpl service = agentService();
        when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
        when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
                .thenReturn(response("no changes needed", null, 10, 5));
        AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");
        assertFalse(result.getSuccess());
        assertEquals(0, result.getToolCallCount());
        assertTrue(result.getSummary().contains("未调用任何允许的 MCP 工具"));
    }
    @Test
    void agentFailsWhenMaxRoundsReached() {
        AutoTuneAgentServiceImpl service = agentService();
        when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
        when(mcpToolManager.callTool(any(), any(JSONObject.class))).thenReturn(Collections.singletonMap("ok", true));
        when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
                .thenReturn(response("keep going", toolCall("call_1", "wcs_local_dispatch_get_auto_tune_snapshot", "{}"), 10, 5));
        AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");
        assertFalse(result.getSuccess());
        assertEquals(10, result.getToolCallCount());
        assertTrue(result.getMaxRoundsReached());
        assertTrue(result.getSummary().contains("达到最大工具调用轮次"));
    }
    private AutoTuneAgentServiceImpl agentService() {
        AiPromptTemplate promptTemplate = new AiPromptTemplate();
        promptTemplate.setContent("system prompt");
        when(aiPromptTemplateService.resolvePublished("wcs_auto_tune_dispatch")).thenReturn(promptTemplate);
        return new AutoTuneAgentServiceImpl(llmChatService, mcpToolManager, aiPromptTemplateService);
    }
    private AutoTuneChangeCommand change(String targetType, String targetId, String targetKey, String newValue) {
        AutoTuneChangeCommand command = new AutoTuneChangeCommand();
        command.setTargetType(targetType);
        command.setTargetId(targetId);
        command.setTargetKey(targetKey);
        command.setNewValue(newValue);
        return command;
    }
    private List<Object> allowedOpenAiTools() {
        List<Object> tools = new ArrayList<>();
        tools.add(openAiTool("wcs_local_dispatch_get_auto_tune_snapshot"));
        tools.add(openAiTool("wcs_local_dispatch_get_recent_auto_tune_jobs"));
        tools.add(openAiTool("wcs_local_dispatch_apply_auto_tune_changes"));
        tools.add(openAiTool("wcs_local_dispatch_revert_last_auto_tune_job"));
        tools.add(openAiTool("wcs_local_device_get_crn_status"));
        return tools;
    }
    private Map<String, Object> openAiTool(String name) {
        LinkedHashMap<String, Object> function = new LinkedHashMap<>();
        function.put("name", name);
        function.put("parameters", Collections.emptyMap());
        LinkedHashMap<String, Object> tool = new LinkedHashMap<>();
        tool.put("type", "function");
        tool.put("function", function);
        return tool;
    }
    private ChatCompletionResponse response(String content,
                                            ChatCompletionRequest.ToolCall toolCall,
                                            int promptTokens,