| | |
| | | |
| | | import java.util.ArrayList; |
| | | import java.util.Collections; |
| | | import java.util.LinkedHashMap; |
| | | import java.util.List; |
| | | import java.util.Map; |
| | | |
| | | import static org.junit.jupiter.api.Assertions.assertEquals; |
| | | import static org.junit.jupiter.api.Assertions.assertFalse; |
| | | import static org.junit.jupiter.api.Assertions.assertNotNull; |
| | | import static org.junit.jupiter.api.Assertions.assertSame; |
| | | import static org.junit.jupiter.api.Assertions.assertThrows; |
| | | import static org.junit.jupiter.api.Assertions.assertTrue; |
| | | import static org.mockito.ArgumentMatchers.any; |
| | | import static org.mockito.ArgumentMatchers.anyDouble; |
| | | import static org.mockito.ArgumentMatchers.anyInt; |
| | | import static org.mockito.Mockito.never; |
| | | import static org.mockito.Mockito.times; |
| | | import static org.mockito.Mockito.verify; |
| | | import static org.mockito.Mockito.when; |
| | | |
| | |
| | | void applyToolDelegatesToApplyServiceWithDryRunAndChanges() { |
| | | AutoTuneApplyResult expected = new AutoTuneApplyResult(); |
| | | expected.setDryRun(true); |
| | | expected.setSuccess(true); |
| | | when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class))).thenReturn(expected); |
| | | |
| | | AutoTuneChangeCommand command = new AutoTuneChangeCommand(); |
| | |
| | | command.setNewValue("12"); |
| | | List<AutoTuneChangeCommand> changes = Collections.singletonList(command); |
| | | |
| | | AutoTuneApplyResult result = tools.applyAutoTuneChanges("reduce congestion", 10, "agent", true, changes); |
| | | AutoTuneApplyResult result = tools.applyAutoTuneChanges("reduce congestion", 10, "agent", true, null, changes); |
| | | |
| | | assertSame(expected, result); |
| | | assertNotNull(result.getDryRunToken()); |
| | | ArgumentCaptor<AutoTuneApplyRequest> captor = ArgumentCaptor.forClass(AutoTuneApplyRequest.class); |
| | | verify(autoTuneApplyService).apply(captor.capture()); |
| | | assertEquals("reduce congestion", captor.getValue().getReason()); |
| | |
| | | assertEquals("agent", captor.getValue().getTriggerType()); |
| | | assertEquals(Boolean.TRUE, captor.getValue().getDryRun()); |
| | | assertSame(changes, captor.getValue().getChanges()); |
| | | } |
| | | |
| | | @Test |
| | | void applyToolRejectsMissingDryRunBeforeServiceCall() { |
| | | AutoTuneChangeCommand command = change("sys_config", null, "conveyorStationTaskLimit", "12"); |
| | | |
| | | IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, |
| | | () -> tools.applyAutoTuneChanges("missing dryRun", 10, "agent", null, null, |
| | | Collections.singletonList(command))); |
| | | |
| | | assertTrue(exception.getMessage().contains("dryRun is required")); |
| | | verify(autoTuneApplyService, never()).apply(any(AutoTuneApplyRequest.class)); |
| | | } |
| | | |
| | | @Test |
| | | void applyToolRejectsDirectRealApplyWithoutDryRunToken() { |
| | | AutoTuneChangeCommand command = change("sys_config", null, "conveyorStationTaskLimit", "12"); |
| | | |
| | | IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, |
| | | () -> tools.applyAutoTuneChanges("direct apply", 10, "agent", false, null, |
| | | Collections.singletonList(command))); |
| | | |
| | | assertTrue(exception.getMessage().contains("dryRunToken is required")); |
| | | verify(autoTuneApplyService, never()).apply(any(AutoTuneApplyRequest.class)); |
| | | } |
| | | |
| | | @Test |
| | | void applyToolAllowsRealApplyOnlyWithMatchingDryRunToken() { |
| | | AutoTuneApplyResult dryRunResult = new AutoTuneApplyResult(); |
| | | dryRunResult.setDryRun(true); |
| | | dryRunResult.setSuccess(true); |
| | | AutoTuneApplyResult applyResult = new AutoTuneApplyResult(); |
| | | applyResult.setDryRun(false); |
| | | applyResult.setSuccess(true); |
| | | when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class))).thenReturn(dryRunResult, applyResult); |
| | | AutoTuneChangeCommand command = change(" sys_config ", "ignored", " conveyorStationTaskLimit ", " 12 "); |
| | | List<AutoTuneChangeCommand> changes = Collections.singletonList(command); |
| | | |
| | | AutoTuneApplyResult preview = tools.applyAutoTuneChanges("preview", 10, "agent", true, null, changes); |
| | | AutoTuneApplyResult applied = tools.applyAutoTuneChanges("apply", 10, "agent", false, |
| | | preview.getDryRunToken(), changes); |
| | | |
| | | assertSame(applyResult, applied); |
| | | ArgumentCaptor<AutoTuneApplyRequest> captor = ArgumentCaptor.forClass(AutoTuneApplyRequest.class); |
| | | verify(autoTuneApplyService, times(2)).apply(captor.capture()); |
| | | assertEquals(Boolean.TRUE, captor.getAllValues().get(0).getDryRun()); |
| | | assertEquals(Boolean.FALSE, captor.getAllValues().get(1).getDryRun()); |
| | | } |
| | | |
| | | @Test |
| | | void applyToolRejectsMismatchedDryRunToken() { |
| | | AutoTuneApplyResult dryRunResult = new AutoTuneApplyResult(); |
| | | dryRunResult.setDryRun(true); |
| | | dryRunResult.setSuccess(true); |
| | | when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class))).thenReturn(dryRunResult); |
| | | |
| | | AutoTuneApplyResult preview = tools.applyAutoTuneChanges("preview", 10, "agent", true, null, |
| | | Collections.singletonList(change("sys_config", null, "conveyorStationTaskLimit", "12"))); |
| | | |
| | | IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, |
| | | () -> tools.applyAutoTuneChanges("apply", 10, "agent", false, preview.getDryRunToken(), |
| | | Collections.singletonList(change("sys_config", null, "conveyorStationTaskLimit", "13")))); |
| | | |
| | | assertTrue(exception.getMessage().contains("does not match")); |
| | | verify(autoTuneApplyService, times(1)).apply(any(AutoTuneApplyRequest.class)); |
| | | } |
| | | |
| | | @Test |
| | |
| | | AiPromptTemplate promptTemplate = new AiPromptTemplate(); |
| | | promptTemplate.setContent("system prompt"); |
| | | when(aiPromptTemplateService.resolvePublished("wcs_auto_tune_dispatch")).thenReturn(promptTemplate); |
| | | when(mcpToolManager.buildOpenAiTools()).thenReturn(Collections.singletonList(Collections.singletonMap("type", "function"))); |
| | | when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools()); |
| | | when(mcpToolManager.callTool(any(), any(JSONObject.class))).thenReturn(Collections.singletonMap("ok", true)); |
| | | when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any())) |
| | | .thenReturn( |
| | |
| | | |
| | | ArgumentCaptor<String> toolNameCaptor = ArgumentCaptor.forClass(String.class); |
| | | ArgumentCaptor<JSONObject> argumentCaptor = ArgumentCaptor.forClass(JSONObject.class); |
| | | verify(mcpToolManager, org.mockito.Mockito.times(3)).callTool(toolNameCaptor.capture(), argumentCaptor.capture()); |
| | | verify(mcpToolManager, times(3)).callTool(toolNameCaptor.capture(), argumentCaptor.capture()); |
| | | assertEquals("wcs_local_dispatch_get_auto_tune_snapshot", toolNameCaptor.getAllValues().get(0)); |
| | | assertEquals("wcs_local_dispatch_apply_auto_tune_changes", toolNameCaptor.getAllValues().get(1)); |
| | | assertEquals(Boolean.TRUE, argumentCaptor.getAllValues().get(1).getBoolean("dryRun")); |
| | | assertEquals(Boolean.FALSE, argumentCaptor.getAllValues().get(2).getBoolean("dryRun")); |
| | | } |
| | | |
| | | @Test |
| | | void agentFailsAndDoesNotExecuteDisallowedToolCall() { |
| | | AutoTuneAgentServiceImpl service = agentService(); |
| | | when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools()); |
| | | when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any())) |
| | | .thenReturn(response("bad tool", toolCall("call_1", "wcs_local_device_get_crn_status", "{}"), 10, 5)); |
| | | |
| | | AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler"); |
| | | |
| | | assertFalse(result.getSuccess()); |
| | | assertEquals(0, result.getToolCallCount()); |
| | | assertTrue(result.getSummary().contains("Disallowed auto-tune MCP tool")); |
| | | verify(mcpToolManager, never()).callTool(any(), any(JSONObject.class)); |
| | | } |
| | | |
| | | @Test |
| | | void agentFailsWhenAllowedToolThrows() { |
| | | AutoTuneAgentServiceImpl service = agentService(); |
| | | when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools()); |
| | | when(mcpToolManager.callTool(any(), any(JSONObject.class))).thenThrow(new RuntimeException("boom")); |
| | | when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any())) |
| | | .thenReturn(response("snapshot", toolCall("call_1", "wcs_local_dispatch_get_auto_tune_snapshot", "{}"), 10, 5)); |
| | | |
| | | AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler"); |
| | | |
| | | assertFalse(result.getSuccess()); |
| | | assertEquals(0, result.getToolCallCount()); |
| | | assertTrue(result.getSummary().contains("Auto-tune MCP tool failed")); |
| | | verify(mcpToolManager).callTool(any(), any(JSONObject.class)); |
| | | } |
| | | |
| | | @Test |
| | | void agentFailsWhenLlmReturnsNoToolCalls() { |
| | | AutoTuneAgentServiceImpl service = agentService(); |
| | | when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools()); |
| | | when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any())) |
| | | .thenReturn(response("no changes needed", null, 10, 5)); |
| | | |
| | | AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler"); |
| | | |
| | | assertFalse(result.getSuccess()); |
| | | assertEquals(0, result.getToolCallCount()); |
| | | assertTrue(result.getSummary().contains("未调用任何允许的 MCP 工具")); |
| | | } |
| | | |
| | | @Test |
| | | void agentFailsWhenMaxRoundsReached() { |
| | | AutoTuneAgentServiceImpl service = agentService(); |
| | | when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools()); |
| | | when(mcpToolManager.callTool(any(), any(JSONObject.class))).thenReturn(Collections.singletonMap("ok", true)); |
| | | when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any())) |
| | | .thenReturn(response("keep going", toolCall("call_1", "wcs_local_dispatch_get_auto_tune_snapshot", "{}"), 10, 5)); |
| | | |
| | | AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler"); |
| | | |
| | | assertFalse(result.getSuccess()); |
| | | assertEquals(10, result.getToolCallCount()); |
| | | assertTrue(result.getMaxRoundsReached()); |
| | | assertTrue(result.getSummary().contains("达到最大工具调用轮次")); |
| | | } |
| | | |
| | | private AutoTuneAgentServiceImpl agentService() { |
| | | AiPromptTemplate promptTemplate = new AiPromptTemplate(); |
| | | promptTemplate.setContent("system prompt"); |
| | | when(aiPromptTemplateService.resolvePublished("wcs_auto_tune_dispatch")).thenReturn(promptTemplate); |
| | | return new AutoTuneAgentServiceImpl(llmChatService, mcpToolManager, aiPromptTemplateService); |
| | | } |
| | | |
| | | private AutoTuneChangeCommand change(String targetType, String targetId, String targetKey, String newValue) { |
| | | AutoTuneChangeCommand command = new AutoTuneChangeCommand(); |
| | | command.setTargetType(targetType); |
| | | command.setTargetId(targetId); |
| | | command.setTargetKey(targetKey); |
| | | command.setNewValue(newValue); |
| | | return command; |
| | | } |
| | | |
| | | private List<Object> allowedOpenAiTools() { |
| | | List<Object> tools = new ArrayList<>(); |
| | | tools.add(openAiTool("wcs_local_dispatch_get_auto_tune_snapshot")); |
| | | tools.add(openAiTool("wcs_local_dispatch_get_recent_auto_tune_jobs")); |
| | | tools.add(openAiTool("wcs_local_dispatch_apply_auto_tune_changes")); |
| | | tools.add(openAiTool("wcs_local_dispatch_revert_last_auto_tune_job")); |
| | | tools.add(openAiTool("wcs_local_device_get_crn_status")); |
| | | return tools; |
| | | } |
| | | |
| | | private Map<String, Object> openAiTool(String name) { |
| | | LinkedHashMap<String, Object> function = new LinkedHashMap<>(); |
| | | function.put("name", name); |
| | | function.put("parameters", Collections.emptyMap()); |
| | | |
| | | LinkedHashMap<String, Object> tool = new LinkedHashMap<>(); |
| | | tool.put("type", "function"); |
| | | tool.put("function", function); |
| | | return tool; |
| | | } |
| | | |
| | | private ChatCompletionResponse response(String content, |
| | | ChatCompletionRequest.ToolCall toolCall, |
| | | int promptTokens, |