package com.zy.ai.service;

import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.conditions.Wrapper;
import com.zy.ai.domain.autotune.AutoTuneApplyRequest;
import com.zy.ai.domain.autotune.AutoTuneApplyResult;
import com.zy.ai.domain.autotune.AutoTuneChangeCommand;
import com.zy.ai.domain.autotune.AutoTuneSnapshot;
import com.zy.ai.entity.AiAutoTuneChange;
import com.zy.ai.entity.AiAutoTuneJob;
import com.zy.ai.entity.AiPromptTemplate;
import com.zy.ai.entity.ChatCompletionRequest;
import com.zy.ai.entity.ChatCompletionResponse;
import com.zy.ai.mcp.service.SpringAiMcpToolManager;
import com.zy.ai.mcp.tool.AutoTuneMcpTools;
import com.zy.ai.service.impl.AutoTuneAgentServiceImpl;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyDouble;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
 * Unit tests for the auto-tune MCP tool facade ({@link AutoTuneMcpTools}) and for the
 * LLM-driven tool-calling loop in {@link AutoTuneAgentServiceImpl}.
 *
 * <p>Every collaborator is a Mockito mock, so the tests only observe how the two classes
 * under test delegate to, guard, and aggregate results from those collaborators.
 */
@ExtendWith(MockitoExtension.class)
class AutoTuneCoordinatorServiceImplTest {

    /** Argument the agent tests expect to be passed to {@code resolvePublished}. */
    private static final String PROMPT_LOOKUP_KEY = "wcs_auto_tune_dispatch";

    private static final String SNAPSHOT_TOOL = "wcs_local_dispatch_get_auto_tune_snapshot";
    private static final String RECENT_JOBS_TOOL = "wcs_local_dispatch_get_recent_auto_tune_jobs";
    private static final String APPLY_TOOL = "wcs_local_dispatch_apply_auto_tune_changes";
    private static final String REVERT_TOOL = "wcs_local_dispatch_revert_last_auto_tune_job";

    /** Tool that the tool manager mock advertises but the agent is expected to refuse. */
    private static final String DISALLOWED_TOOL = "wcs_local_device_get_crn_status";

    private static final String TARGET_TYPE = "sys_config";
    private static final String TARGET_KEY = "conveyorStationTaskLimit";

    private AutoTuneMcpTools tools;

    @Mock
    private AutoTuneSnapshotService autoTuneSnapshotService;

    @Mock
    private AutoTuneApplyService autoTuneApplyService;

    @Mock
    private AiAutoTuneJobService aiAutoTuneJobService;

    @Mock
    private AiAutoTuneChangeService aiAutoTuneChangeService;

    @Mock
    private LlmChatService llmChatService;

    @Mock
    private SpringAiMcpToolManager mcpToolManager;

    @Mock
    private AiPromptTemplateService aiPromptTemplateService;

    /** Builds a fresh tool facade over the mocked services before each test. */
    @BeforeEach
    void setUp() {
        tools = new AutoTuneMcpTools(
                autoTuneSnapshotService,
                autoTuneApplyService,
                aiAutoTuneJobService,
                aiAutoTuneChangeService);
    }

    // ------------------------------------------------------------------
    // AutoTuneMcpTools
    // ------------------------------------------------------------------

    /** The snapshot tool returns exactly the instance produced by the snapshot service. */
    @Test
    void snapshotToolDelegatesToSnapshotService() {
        AutoTuneSnapshot snapshot = new AutoTuneSnapshot();
        when(autoTuneSnapshotService.buildSnapshot()).thenReturn(snapshot);

        AutoTuneSnapshot result = tools.getAutoTuneSnapshot();

        assertSame(snapshot, result);
        verify(autoTuneSnapshotService).buildSnapshot();
    }

    /**
     * Recent jobs come back as compact maps that carry their changes but no
     * {@code reasoningDigest} entry, and the job query is bounded: 99 jobs are requested,
     * yet the captured query wrapper is expected to contain {@code limit 20}.
     */
    @Test
    @SuppressWarnings("unchecked") // Wrapper.class is raw; matchers/captors cannot carry <T>.
    void recentJobsReturnsBoundedCompactSummariesWithChanges() {
        AiAutoTuneJob job = new AiAutoTuneJob();
        job.setId(7L);
        job.setTriggerType("agent");
        job.setStatus("success");
        job.setSummary("applied");
        job.setSuccessCount(1);
        job.setRejectCount(0);

        AiAutoTuneChange change = new AiAutoTuneChange();
        change.setJobId(7L);
        change.setTargetType(TARGET_TYPE);
        change.setTargetKey(TARGET_KEY);
        change.setRequestedValue("12");
        change.setResultStatus("success");

        when(aiAutoTuneJobService.list(any(Wrapper.class)))
                .thenReturn(Collections.singletonList(job));
        when(aiAutoTuneChangeService.list(any(Wrapper.class)))
                .thenReturn(Collections.singletonList(change));

        List<Map<String, Object>> result = tools.getRecentAutoTuneJobs(99);

        assertEquals(1, result.size());
        Map<String, Object> summary = result.get(0);
        assertEquals(7L, summary.get("id"));
        assertFalse(summary.containsKey("reasoningDigest"));
        List<?> changes = (List<?>) summary.get("changes");
        assertEquals(1, changes.size());

        ArgumentCaptor<Wrapper<AiAutoTuneJob>> wrapperCaptor =
                ArgumentCaptor.forClass(Wrapper.class);
        verify(aiAutoTuneJobService).list(wrapperCaptor.capture());
        assertTrue(wrapperCaptor.getValue().getSqlSegment().contains("limit 20"));
    }

    /**
     * A dry-run call forwards reason, interval, trigger type, dry-run flag and the very same
     * change list to the apply service, and the returned result carries a dry-run token.
     */
    @Test
    void applyToolDelegatesToApplyServiceWithDryRunAndChanges() {
        AutoTuneApplyResult expected = successfulResult(true);
        when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class))).thenReturn(expected);
        List<AutoTuneChangeCommand> changes = singleChange("12");

        AutoTuneApplyResult result =
                tools.applyAutoTuneChanges("reduce congestion", 10, "agent", true, null, changes);

        assertSame(expected, result);
        assertNotNull(result.getDryRunToken());

        ArgumentCaptor<AutoTuneApplyRequest> captor =
                ArgumentCaptor.forClass(AutoTuneApplyRequest.class);
        verify(autoTuneApplyService).apply(captor.capture());
        AutoTuneApplyRequest request = captor.getValue();
        assertEquals("reduce congestion", request.getReason());
        assertEquals(10, request.getAnalysisIntervalMinutes());
        assertEquals("agent", request.getTriggerType());
        assertEquals(Boolean.TRUE, request.getDryRun());
        assertSame(changes, request.getChanges());
    }

    /** A {@code null} dry-run flag is rejected before the apply service is touched. */
    @Test
    void applyToolRejectsMissingDryRunBeforeServiceCall() {
        List<AutoTuneChangeCommand> changes = singleChange("12");

        IllegalArgumentException exception = assertThrows(
                IllegalArgumentException.class,
                () -> tools.applyAutoTuneChanges(
                        "missing dryRun", 10, "agent", null, null, changes));

        assertTrue(exception.getMessage().contains("dryRun is required"));
        verify(autoTuneApplyService, never()).apply(any(AutoTuneApplyRequest.class));
    }

    /** A real apply ({@code dryRun=false}) without a dry-run token is rejected up front. */
    @Test
    void applyToolRejectsDirectRealApplyWithoutDryRunToken() {
        List<AutoTuneChangeCommand> changes = singleChange("12");

        IllegalArgumentException exception = assertThrows(
                IllegalArgumentException.class,
                () -> tools.applyAutoTuneChanges(
                        "direct apply", 10, "agent", false, null, changes));

        assertTrue(exception.getMessage().contains("dryRunToken is required"));
        verify(autoTuneApplyService, never()).apply(any(AutoTuneApplyRequest.class));
    }

    /**
     * A real apply succeeds when it presents the token issued by a preceding dry run of the
     * same changes; the apply service then sees one dry-run and one real request, in order.
     * The command deliberately uses padded values and a non-null target id.
     */
    @Test
    void applyToolAllowsRealApplyOnlyWithMatchingDryRunToken() {
        AutoTuneApplyResult dryRunResult = successfulResult(true);
        AutoTuneApplyResult applyResult = successfulResult(false);
        when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class)))
                .thenReturn(dryRunResult, applyResult);
        List<AutoTuneChangeCommand> changes = Collections.singletonList(
                change(" sys_config ", "ignored", " conveyorStationTaskLimit ", " 12 "));

        AutoTuneApplyResult preview =
                tools.applyAutoTuneChanges("preview", 10, "agent", true, null, changes);
        AutoTuneApplyResult applied = tools.applyAutoTuneChanges(
                "apply", 10, "agent", false, preview.getDryRunToken(), changes);

        assertSame(applyResult, applied);
        ArgumentCaptor<AutoTuneApplyRequest> captor =
                ArgumentCaptor.forClass(AutoTuneApplyRequest.class);
        verify(autoTuneApplyService, times(2)).apply(captor.capture());
        List<AutoTuneApplyRequest> requests = captor.getAllValues();
        assertEquals(Boolean.TRUE, requests.get(0).getDryRun());
        assertEquals(Boolean.FALSE, requests.get(1).getDryRun());
    }

    /**
     * A token issued for one change set is refused for a different one (new value 12 vs 13),
     * so only the dry run ever reaches the apply service.
     */
    @Test
    void applyToolRejectsMismatchedDryRunToken() {
        when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class)))
                .thenReturn(successfulResult(true));

        AutoTuneApplyResult preview = tools.applyAutoTuneChanges(
                "preview", 10, "agent", true, null, singleChange("12"));

        IllegalArgumentException exception = assertThrows(
                IllegalArgumentException.class,
                () -> tools.applyAutoTuneChanges(
                        "apply", 10, "agent", false, preview.getDryRunToken(),
                        singleChange("13")));

        assertTrue(exception.getMessage().contains("does not match"));
        verify(autoTuneApplyService, times(1)).apply(any(AutoTuneApplyRequest.class));
    }

    /** The revert tool passes its reason straight to the apply service's rollback. */
    @Test
    void rollbackToolDelegatesToApplyServiceRollback() {
        AutoTuneApplyResult expected = new AutoTuneApplyResult();
        when(autoTuneApplyService.rollbackLastSuccessfulJob("bad result")).thenReturn(expected);

        AutoTuneApplyResult result = tools.revertLastAutoTuneJob("bad result");

        assertSame(expected, result);
        verify(autoTuneApplyService).rollbackLastSuccessfulJob("bad result");
    }

    // ------------------------------------------------------------------
    // AutoTuneAgentServiceImpl
    // ------------------------------------------------------------------

    /**
     * Happy path: the scripted LLM asks for a snapshot, a dry-run apply and a real apply, then
     * answers without a tool call. The agent must execute the three tools in that order and
     * report summed token usage across all four LLM calls.
     */
    @Test
    void agentExecutesSnapshotDryRunAndRealApplyToolSequence() {
        AutoTuneAgentServiceImpl service = agentService();
        when(mcpToolManager.buildOpenAiTools()).thenReturn(advertisedOpenAiTools());
        when(mcpToolManager.callTool(any(), any(JSONObject.class)))
                .thenReturn(Collections.singletonMap("ok", true));
        when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
                .thenReturn(
                        response("read snapshot",
                                toolCall("call_1", SNAPSHOT_TOOL, "{}"), 10, 5),
                        response("dry-run first",
                                toolCall("call_2", APPLY_TOOL, applyArguments(true)), 11, 6),
                        response("apply after dry-run",
                                toolCall("call_3", APPLY_TOOL, applyArguments(false)), 12, 7),
                        response("已完成自动调参", null, 13, 8));

        AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");

        assertTrue(result.getSuccess());
        assertEquals("scheduler", result.getTriggerType());
        assertEquals(3, result.getToolCallCount());
        assertEquals(4, result.getLlmCallCount());
        // Usage is the sum over the four scripted responses: 10+11+12+13 and 5+6+7+8.
        assertEquals(46L, result.getPromptTokens());
        assertEquals(26L, result.getCompletionTokens());
        assertEquals(72L, result.getTotalTokens());
        assertTrue(result.getSummary().contains("已完成自动调参"));

        ArgumentCaptor<String> toolNameCaptor = ArgumentCaptor.forClass(String.class);
        ArgumentCaptor<JSONObject> argumentCaptor = ArgumentCaptor.forClass(JSONObject.class);
        verify(mcpToolManager, times(3))
                .callTool(toolNameCaptor.capture(), argumentCaptor.capture());
        List<String> toolNames = toolNameCaptor.getAllValues();
        List<JSONObject> arguments = argumentCaptor.getAllValues();
        assertEquals(SNAPSHOT_TOOL, toolNames.get(0));
        assertEquals(APPLY_TOOL, toolNames.get(1));
        assertEquals(APPLY_TOOL, toolNames.get(2));
        assertEquals(Boolean.TRUE, arguments.get(1).getBoolean("dryRun"));
        assertEquals(Boolean.FALSE, arguments.get(2).getBoolean("dryRun"));
    }

    /** A tool call naming the disallowed tool fails the run without invoking any tool. */
    @Test
    void agentFailsAndDoesNotExecuteDisallowedToolCall() {
        AutoTuneAgentServiceImpl service = agentService();
        when(mcpToolManager.buildOpenAiTools()).thenReturn(advertisedOpenAiTools());
        when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
                .thenReturn(response("bad tool",
                        toolCall("call_1", DISALLOWED_TOOL, "{}"), 10, 5));

        AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");

        assertFalse(result.getSuccess());
        assertEquals(0, result.getToolCallCount());
        assertTrue(result.getSummary().contains("Disallowed auto-tune MCP tool"));
        verify(mcpToolManager, never()).callTool(any(), any(JSONObject.class));
    }

    /** An exception thrown by a permitted tool fails the run; the call is not counted. */
    @Test
    void agentFailsWhenAllowedToolThrows() {
        AutoTuneAgentServiceImpl service = agentService();
        when(mcpToolManager.buildOpenAiTools()).thenReturn(advertisedOpenAiTools());
        when(mcpToolManager.callTool(any(), any(JSONObject.class)))
                .thenThrow(new RuntimeException("boom"));
        when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
                .thenReturn(response("snapshot",
                        toolCall("call_1", SNAPSHOT_TOOL, "{}"), 10, 5));

        AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");

        assertFalse(result.getSuccess());
        assertEquals(0, result.getToolCallCount());
        assertTrue(result.getSummary().contains("Auto-tune MCP tool failed"));
        verify(mcpToolManager).callTool(any(), any(JSONObject.class));
    }

    /** An LLM answer that never requests a tool is reported as a failed run. */
    @Test
    void agentFailsWhenLlmReturnsNoToolCalls() {
        AutoTuneAgentServiceImpl service = agentService();
        when(mcpToolManager.buildOpenAiTools()).thenReturn(advertisedOpenAiTools());
        when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
                .thenReturn(response("no changes needed", null, 10, 5));

        AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");

        assertFalse(result.getSuccess());
        assertEquals(0, result.getToolCallCount());
        assertTrue(result.getSummary().contains("未调用任何允许的 MCP 工具"));
    }

    /**
     * When every LLM response requests another tool call, the run is expected to stop after
     * 10 tool calls and be flagged as having reached the maximum number of rounds.
     */
    @Test
    void agentFailsWhenMaxRoundsReached() {
        AutoTuneAgentServiceImpl service = agentService();
        when(mcpToolManager.buildOpenAiTools()).thenReturn(advertisedOpenAiTools());
        when(mcpToolManager.callTool(any(), any(JSONObject.class)))
                .thenReturn(Collections.singletonMap("ok", true));
        when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
                .thenReturn(response("keep going",
                        toolCall("call_1", SNAPSHOT_TOOL, "{}"), 10, 5));

        AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");

        assertFalse(result.getSuccess());
        assertEquals(10, result.getToolCallCount());
        assertTrue(result.getMaxRoundsReached());
        assertTrue(result.getSummary().contains("达到最大工具调用轮次"));
    }

    // ------------------------------------------------------------------
    // Fixtures
    // ------------------------------------------------------------------

    /**
     * Stubs the published prompt template and returns an agent wired to the mocked LLM,
     * tool manager and prompt template service.
     */
    private AutoTuneAgentServiceImpl agentService() {
        AiPromptTemplate promptTemplate = new AiPromptTemplate();
        promptTemplate.setContent("system prompt");
        when(aiPromptTemplateService.resolvePublished(PROMPT_LOOKUP_KEY))
                .thenReturn(promptTemplate);
        return new AutoTuneAgentServiceImpl(
                llmChatService, mcpToolManager, aiPromptTemplateService);
    }

    /** Returns a successful apply result with the given dry-run flag. */
    private static AutoTuneApplyResult successfulResult(boolean dryRun) {
        AutoTuneApplyResult result = new AutoTuneApplyResult();
        result.setDryRun(dryRun);
        result.setSuccess(true);
        return result;
    }

    /** Returns a one-element list changing the standard test key to {@code newValue}. */
    private static List<AutoTuneChangeCommand> singleChange(String newValue) {
        return Collections.singletonList(change(TARGET_TYPE, null, TARGET_KEY, newValue));
    }

    /** Builds a change command with the given fields set verbatim (no trimming). */
    private static AutoTuneChangeCommand change(
            String targetType, String targetId, String targetKey, String newValue) {
        AutoTuneChangeCommand command = new AutoTuneChangeCommand();
        command.setTargetType(targetType);
        command.setTargetId(targetId);
        command.setTargetKey(targetKey);
        command.setNewValue(newValue);
        return command;
    }

    /** JSON arguments for an apply tool call that sets the standard test key to 12. */
    private static String applyArguments(boolean dryRun) {
        return "{\"dryRun\":" + dryRun
                + ",\"changes\":[{\"targetType\":\"sys_config\","
                + "\"targetKey\":\"conveyorStationTaskLimit\",\"newValue\":\"12\"}]}";
    }

    /**
     * Tool definitions the mocked tool manager hands to the agent: the four dispatch tools
     * plus {@link #DISALLOWED_TOOL}, which lets the tests check that the agent refuses it.
     */
    // NOTE(review): element type Object is assumed to match the return type of
    // SpringAiMcpToolManager.buildOpenAiTools() - confirm against that signature.
    private static List<Object> advertisedOpenAiTools() {
        List<Object> advertised = new ArrayList<>();
        advertised.add(openAiTool(SNAPSHOT_TOOL));
        advertised.add(openAiTool(RECENT_JOBS_TOOL));
        advertised.add(openAiTool(APPLY_TOOL));
        advertised.add(openAiTool(REVERT_TOOL));
        advertised.add(openAiTool(DISALLOWED_TOOL));
        return advertised;
    }

    /** Builds a {@code type=function} tool entry with the given name and empty parameters. */
    private static Map<String, Object> openAiTool(String name) {
        Map<String, Object> function = new LinkedHashMap<>();
        function.put("name", name);
        function.put("parameters", Collections.emptyMap());

        Map<String, Object> tool = new LinkedHashMap<>();
        tool.put("type", "function");
        tool.put("function", function);
        return tool;
    }

    /**
     * Builds a single-choice assistant response with the given usage figures.
     *
     * @param content assistant message text
     * @param toolCall tool call to attach, or {@code null} for a plain answer
     * @param promptTokens prompt token count to report
     * @param completionTokens completion token count to report; total is the sum of both
     */
    private static ChatCompletionResponse response(
            String content,
            ChatCompletionRequest.ToolCall toolCall,
            int promptTokens,
            int completionTokens) {
        ChatCompletionRequest.Message message = new ChatCompletionRequest.Message();
        message.setRole("assistant");
        message.setContent(content);
        if (toolCall != null) {
            message.setTool_calls(Collections.singletonList(toolCall));
        }

        ChatCompletionResponse.Choice choice = new ChatCompletionResponse.Choice();
        choice.setIndex(0);
        choice.setMessage(message);

        ChatCompletionResponse.Usage usage = new ChatCompletionResponse.Usage();
        usage.setPromptTokens(promptTokens);
        usage.setCompletionTokens(completionTokens);
        usage.setTotalTokens(promptTokens + completionTokens);

        List<ChatCompletionResponse.Choice> choices = new ArrayList<>();
        choices.add(choice);

        ChatCompletionResponse response = new ChatCompletionResponse();
        response.setChoices(choices);
        response.setUsage(usage);
        return response;
    }

    /** Builds a {@code type=function} tool call with the given id, name and JSON arguments. */
    private static ChatCompletionRequest.ToolCall toolCall(
            String id, String name, String arguments) {
        ChatCompletionRequest.Function function = new ChatCompletionRequest.Function();
        function.setName(name);
        function.setArguments(arguments);

        ChatCompletionRequest.ToolCall toolCall = new ChatCompletionRequest.ToolCall();
        toolCall.setId(id);
        toolCall.setType("function");
        toolCall.setFunction(function);
        return toolCall;
    }
}