package com.zy.ai.service;
|
|
import com.alibaba.fastjson.JSONObject;
|
import com.baomidou.mybatisplus.core.conditions.Wrapper;
|
import com.zy.ai.domain.autotune.AutoTuneApplyRequest;
|
import com.zy.ai.domain.autotune.AutoTuneApplyResult;
|
import com.zy.ai.domain.autotune.AutoTuneChangeCommand;
|
import com.zy.ai.domain.autotune.AutoTuneSnapshot;
|
import com.zy.ai.domain.autotune.AutoTuneTriggerType;
|
import com.zy.ai.entity.AiAutoTuneChange;
|
import com.zy.ai.entity.AiAutoTuneJob;
|
import com.zy.ai.entity.AiPromptTemplate;
|
import com.zy.ai.entity.ChatCompletionRequest;
|
import com.zy.ai.entity.ChatCompletionResponse;
|
import com.zy.ai.mcp.service.SpringAiMcpToolManager;
|
import com.zy.ai.mcp.tool.AutoTuneMcpTools;
|
import com.zy.ai.service.impl.AutoTuneAgentServiceImpl;
|
import com.zy.ai.service.impl.AutoTuneCoordinatorServiceImpl;
|
import com.zy.asrs.service.WrkMastService;
|
import com.zy.common.utils.RedisUtil;
|
import com.zy.core.enums.RedisKeyType;
|
import com.zy.system.service.ConfigService;
|
import com.zy.system.service.OperateLogService;
|
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.extension.ExtendWith;
|
import org.mockito.ArgumentCaptor;
|
import org.mockito.Mock;
|
import org.mockito.junit.jupiter.MockitoExtension;
|
import org.springframework.test.util.ReflectionTestUtils;
|
|
import java.util.ArrayList;
|
import java.util.Collections;
|
import java.util.Date;
|
import java.util.LinkedHashMap;
|
import java.util.List;
|
import java.util.Map;
|
import java.util.concurrent.atomic.AtomicLong;
|
import java.util.function.LongSupplier;
|
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
import static org.junit.jupiter.api.Assertions.assertSame;
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
import static org.mockito.ArgumentMatchers.any;
|
import static org.mockito.ArgumentMatchers.anyDouble;
|
import static org.mockito.ArgumentMatchers.anyInt;
|
import static org.mockito.ArgumentMatchers.anyLong;
|
import static org.mockito.ArgumentMatchers.anyString;
|
import static org.mockito.ArgumentMatchers.eq;
|
import static org.mockito.Mockito.doThrow;
|
import static org.mockito.Mockito.never;
|
import static org.mockito.Mockito.times;
|
import static org.mockito.Mockito.verify;
|
import static org.mockito.Mockito.when;
|
|
@ExtendWith(MockitoExtension.class)
|
class AutoTuneCoordinatorServiceImplTest {
|
|
private AutoTuneMcpTools tools;
|
|
@Mock
|
private AutoTuneSnapshotService autoTuneSnapshotService;
|
@Mock
|
private AutoTuneApplyService autoTuneApplyService;
|
@Mock
|
private AiAutoTuneJobService aiAutoTuneJobService;
|
@Mock
|
private AiAutoTuneChangeService aiAutoTuneChangeService;
|
@Mock
|
private LlmChatService llmChatService;
|
@Mock
|
private SpringAiMcpToolManager mcpToolManager;
|
@Mock
|
private AiPromptTemplateService aiPromptTemplateService;
|
@Mock
|
private ConfigService configService;
|
@Mock
|
private WrkMastService wrkMastService;
|
@Mock
|
private AutoTuneAgentService autoTuneAgentService;
|
@Mock
|
private RedisUtil redisUtil;
|
@Mock
|
private OperateLogService operateLogService;
|
|
@BeforeEach
|
void setUp() {
|
tools = new AutoTuneMcpTools(
|
autoTuneSnapshotService,
|
autoTuneApplyService,
|
aiAutoTuneJobService,
|
aiAutoTuneChangeService);
|
}
|
|
@Test
|
void snapshotToolDelegatesToSnapshotService() {
|
AutoTuneSnapshot snapshot = new AutoTuneSnapshot();
|
when(autoTuneSnapshotService.buildSnapshot()).thenReturn(snapshot);
|
|
AutoTuneSnapshot result = tools.getAutoTuneSnapshot();
|
|
assertSame(snapshot, result);
|
verify(autoTuneSnapshotService).buildSnapshot();
|
}
|
|
@Test
|
void recentJobsReturnsBoundedCompactSummariesWithChanges() {
|
AiAutoTuneJob job = new AiAutoTuneJob();
|
job.setId(7L);
|
job.setTriggerType("agent");
|
job.setStatus("success");
|
job.setSummary("applied");
|
job.setSuccessCount(1);
|
job.setRejectCount(0);
|
AiAutoTuneChange change = new AiAutoTuneChange();
|
change.setJobId(7L);
|
change.setTargetType("sys_config");
|
change.setTargetKey("conveyorStationTaskLimit");
|
change.setRequestedValue("12");
|
change.setResultStatus("success");
|
|
when(aiAutoTuneJobService.list(any(Wrapper.class))).thenReturn(Collections.singletonList(job));
|
when(aiAutoTuneChangeService.list(any(Wrapper.class))).thenReturn(Collections.singletonList(change));
|
|
List<Map<String, Object>> result = tools.getRecentAutoTuneJobs(99);
|
|
assertEquals(1, result.size());
|
assertEquals(7L, result.get(0).get("id"));
|
assertFalse(result.get(0).containsKey("reasoningDigest"));
|
List<?> changes = (List<?>) result.get(0).get("changes");
|
assertEquals(1, changes.size());
|
|
ArgumentCaptor<Wrapper<AiAutoTuneJob>> wrapperCaptor = ArgumentCaptor.forClass(Wrapper.class);
|
verify(aiAutoTuneJobService).list(wrapperCaptor.capture());
|
assertTrue(wrapperCaptor.getValue().getSqlSegment().contains("limit 20"));
|
}
|
|
@Test
|
void applyToolDelegatesToApplyServiceWithDryRunAndChanges() {
|
AutoTuneApplyResult expected = new AutoTuneApplyResult();
|
expected.setDryRun(true);
|
expected.setSuccess(true);
|
when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class))).thenReturn(expected);
|
|
AutoTuneChangeCommand command = new AutoTuneChangeCommand();
|
command.setTargetType("sys_config");
|
command.setTargetKey("conveyorStationTaskLimit");
|
command.setNewValue("12");
|
List<AutoTuneChangeCommand> changes = Collections.singletonList(command);
|
|
AutoTuneApplyResult result = tools.applyAutoTuneChanges("reduce congestion", 10, "agent", true, null, changes);
|
|
assertSame(expected, result);
|
assertNotNull(result.getDryRunToken());
|
ArgumentCaptor<AutoTuneApplyRequest> captor = ArgumentCaptor.forClass(AutoTuneApplyRequest.class);
|
verify(autoTuneApplyService).apply(captor.capture());
|
assertEquals("reduce congestion", captor.getValue().getReason());
|
assertEquals(10, captor.getValue().getAnalysisIntervalMinutes());
|
assertEquals("agent", captor.getValue().getTriggerType());
|
assertEquals(Boolean.TRUE, captor.getValue().getDryRun());
|
assertSame(changes, captor.getValue().getChanges());
|
}
|
|
@Test
|
void applyToolRejectsMissingDryRunBeforeServiceCall() {
|
AutoTuneChangeCommand command = change("sys_config", null, "conveyorStationTaskLimit", "12");
|
|
IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
|
() -> tools.applyAutoTuneChanges("missing dryRun", 10, "agent", null, null,
|
Collections.singletonList(command)));
|
|
assertTrue(exception.getMessage().contains("dryRun is required"));
|
verify(autoTuneApplyService, never()).apply(any(AutoTuneApplyRequest.class));
|
}
|
|
@Test
|
void applyToolRejectsDirectRealApplyWithoutDryRunToken() {
|
AutoTuneChangeCommand command = change("sys_config", null, "conveyorStationTaskLimit", "12");
|
|
IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
|
() -> tools.applyAutoTuneChanges("direct apply", 10, "agent", false, null,
|
Collections.singletonList(command)));
|
|
assertTrue(exception.getMessage().contains("dryRunToken is required"));
|
verify(autoTuneApplyService, never()).apply(any(AutoTuneApplyRequest.class));
|
}
|
|
@Test
|
void applyToolAllowsRealApplyOnlyWithMatchingDryRunToken() {
|
AutoTuneApplyResult dryRunResult = new AutoTuneApplyResult();
|
dryRunResult.setDryRun(true);
|
dryRunResult.setSuccess(true);
|
AutoTuneApplyResult applyResult = new AutoTuneApplyResult();
|
applyResult.setDryRun(false);
|
applyResult.setSuccess(true);
|
when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class))).thenReturn(dryRunResult, applyResult);
|
AutoTuneChangeCommand command = change(" sys_config ", "ignored", " conveyorStationTaskLimit ", " 12 ");
|
List<AutoTuneChangeCommand> changes = Collections.singletonList(command);
|
|
AutoTuneApplyResult preview = tools.applyAutoTuneChanges("preview", 10, "agent", true, null, changes);
|
AutoTuneApplyResult applied = tools.applyAutoTuneChanges("apply", 10, "agent", false,
|
preview.getDryRunToken(), changes);
|
|
assertSame(applyResult, applied);
|
ArgumentCaptor<AutoTuneApplyRequest> captor = ArgumentCaptor.forClass(AutoTuneApplyRequest.class);
|
verify(autoTuneApplyService, times(2)).apply(captor.capture());
|
assertEquals(Boolean.TRUE, captor.getAllValues().get(0).getDryRun());
|
assertEquals(Boolean.FALSE, captor.getAllValues().get(1).getDryRun());
|
}
|
|
@Test
|
void applyToolRejectsMismatchedDryRunToken() {
|
AutoTuneApplyResult dryRunResult = new AutoTuneApplyResult();
|
dryRunResult.setDryRun(true);
|
dryRunResult.setSuccess(true);
|
when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class))).thenReturn(dryRunResult);
|
|
AutoTuneApplyResult preview = tools.applyAutoTuneChanges("preview", 10, "agent", true, null,
|
Collections.singletonList(change("sys_config", null, "conveyorStationTaskLimit", "12")));
|
|
IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
|
() -> tools.applyAutoTuneChanges("apply", 10, "agent", false, preview.getDryRunToken(),
|
Collections.singletonList(change("sys_config", null, "conveyorStationTaskLimit", "13"))));
|
|
assertTrue(exception.getMessage().contains("does not match"));
|
verify(autoTuneApplyService, times(1)).apply(any(AutoTuneApplyRequest.class));
|
}
|
|
@Test
|
void applyToolRejectsExpiredDryRunToken() {
|
AtomicLong currentTimeMillis = new AtomicLong(1_000L);
|
ReflectionTestUtils.invokeMethod(tools, "setCurrentTimeMillisSupplier", (LongSupplier) currentTimeMillis::get);
|
AutoTuneApplyResult dryRunResult = new AutoTuneApplyResult();
|
dryRunResult.setDryRun(true);
|
dryRunResult.setSuccess(true);
|
when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class))).thenReturn(dryRunResult);
|
List<AutoTuneChangeCommand> changes = Collections.singletonList(
|
change("sys_config", null, "conveyorStationTaskLimit", "12"));
|
|
AutoTuneApplyResult preview = tools.applyAutoTuneChanges("preview", 10, "agent", true, null, changes);
|
currentTimeMillis.addAndGet(10L * 60L * 1000L + 1L);
|
IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
|
() -> tools.applyAutoTuneChanges("apply", 10, "agent", false, preview.getDryRunToken(), changes));
|
|
assertTrue(exception.getMessage().contains("expired"));
|
verify(autoTuneApplyService, times(1)).apply(any(AutoTuneApplyRequest.class));
|
}
|
|
@Test
|
void rollbackToolDelegatesToApplyServiceRollback() {
|
AutoTuneApplyResult expected = new AutoTuneApplyResult();
|
when(autoTuneApplyService.rollbackLastSuccessfulJob("bad result")).thenReturn(expected);
|
|
AutoTuneApplyResult result = tools.revertLastAutoTuneJob("bad result");
|
|
assertSame(expected, result);
|
verify(autoTuneApplyService).rollbackLastSuccessfulJob("bad result");
|
}
|
|
@Test
|
void coordinatorSkipsWhenDisabled() {
|
when(configService.getConfigValue("aiAutoTuneEnabled", "N")).thenReturn("N");
|
|
AutoTuneCoordinatorService.AutoTuneCoordinatorResult result = coordinatorService().runAutoTuneIfEligible();
|
|
assertTrue(result.getSkipped());
|
assertEquals("disabled", result.getReason());
|
verify(wrkMastService, never()).count(any(Wrapper.class));
|
verify(autoTuneAgentService, never()).runAutoTune(anyString());
|
}
|
|
@Test
|
void coordinatorSkipsWhenNoActiveTasks() {
|
when(configService.getConfigValue("aiAutoTuneEnabled", "N")).thenReturn("Y");
|
when(configService.getConfigValue("aiAutoTuneIntervalMinutes", "10")).thenReturn("10");
|
when(wrkMastService.count(any(Wrapper.class))).thenReturn(0L);
|
|
AutoTuneCoordinatorService.AutoTuneCoordinatorResult result = coordinatorService().runAutoTuneIfEligible();
|
|
assertTrue(result.getSkipped());
|
assertEquals("no_active_tasks", result.getReason());
|
verify(autoTuneAgentService, never()).runAutoTune(anyString());
|
}
|
|
@Test
|
void coordinatorSkipsWhenIntervalNotReached() {
|
AiAutoTuneJob recentJob = new AiAutoTuneJob();
|
recentJob.setId(11L);
|
recentJob.setFinishTime(new Date());
|
when(configService.getConfigValue("aiAutoTuneEnabled", "N")).thenReturn("true");
|
when(configService.getConfigValue("aiAutoTuneIntervalMinutes", "10")).thenReturn("10");
|
when(wrkMastService.count(any(Wrapper.class))).thenReturn(1L);
|
when(redisUtil.get(RedisKeyType.AI_AUTO_TUNE_LAST_TRIGGER_GUARD.key)).thenReturn(null);
|
when(aiAutoTuneJobService.list(any(Wrapper.class))).thenReturn(Collections.singletonList(recentJob));
|
|
AutoTuneCoordinatorService.AutoTuneCoordinatorResult result = coordinatorService().runAutoTuneIfEligible();
|
|
assertTrue(result.getSkipped());
|
assertEquals("interval_not_reached", result.getReason());
|
verify(autoTuneAgentService, never()).runAutoTune(anyString());
|
}
|
|
@Test
|
void coordinatorTriggersAgentWhenEligible() {
|
AutoTuneAgentService.AutoTuneAgentResult agentResult = successfulAgentResult();
|
when(configService.getConfigValue("aiAutoTuneEnabled", "N")).thenReturn("1");
|
when(configService.getConfigValue("aiAutoTuneIntervalMinutes", "10")).thenReturn("10");
|
when(wrkMastService.count(any(Wrapper.class))).thenReturn(1L);
|
when(redisUtil.get(RedisKeyType.AI_AUTO_TUNE_LAST_TRIGGER_GUARD.key)).thenReturn(null);
|
when(aiAutoTuneJobService.list(any(Wrapper.class))).thenReturn(Collections.emptyList());
|
when(redisUtil.trySetStringIfAbsent(anyString(), anyString(), anyLong())).thenReturn(true);
|
when(autoTuneAgentService.runAutoTune(AutoTuneTriggerType.AUTO.getCode())).thenReturn(agentResult);
|
|
AutoTuneCoordinatorService.AutoTuneCoordinatorResult result = coordinatorService().runAutoTuneIfEligible();
|
|
assertFalse(result.getSkipped());
|
assertTrue(result.getTriggered());
|
assertSame(agentResult, result.getAgentResult());
|
verify(autoTuneAgentService).runAutoTune(AutoTuneTriggerType.AUTO.getCode());
|
verify(operateLogService).save(any());
|
verify(redisUtil).set(anyString(), any(), anyLong());
|
verify(redisUtil).compareAndDelete(anyString(), anyString());
|
}
|
|
@Test
|
void coordinatorKeepsAgentResultWhenOperateLogFails() {
|
AutoTuneAgentService.AutoTuneAgentResult agentResult = successfulAgentResult();
|
when(configService.getConfigValue("aiAutoTuneEnabled", "N")).thenReturn("Y");
|
when(configService.getConfigValue("aiAutoTuneIntervalMinutes", "10")).thenReturn("10");
|
when(wrkMastService.count(any(Wrapper.class))).thenReturn(1L);
|
when(redisUtil.get(RedisKeyType.AI_AUTO_TUNE_LAST_TRIGGER_GUARD.key)).thenReturn(null);
|
when(aiAutoTuneJobService.list(any(Wrapper.class))).thenReturn(Collections.emptyList());
|
when(redisUtil.trySetStringIfAbsent(anyString(), anyString(), anyLong())).thenReturn(true);
|
when(autoTuneAgentService.runAutoTune(AutoTuneTriggerType.AUTO.getCode())).thenReturn(agentResult);
|
doThrow(new RuntimeException("log failed")).when(operateLogService).save(any());
|
|
AutoTuneCoordinatorService.AutoTuneCoordinatorResult result = coordinatorService().runAutoTuneIfEligible();
|
|
assertFalse(result.getSkipped());
|
assertTrue(result.getTriggered());
|
assertSame(agentResult, result.getAgentResult());
|
verify(redisUtil).compareAndDelete(anyString(), anyString());
|
}
|
|
@Test
|
void coordinatorRunsAgentAndReleasesLockWhenGuardWriteFails() {
|
AutoTuneAgentService.AutoTuneAgentResult agentResult = successfulAgentResult();
|
when(configService.getConfigValue("aiAutoTuneEnabled", "N")).thenReturn("Y");
|
when(configService.getConfigValue("aiAutoTuneIntervalMinutes", "10")).thenReturn("10");
|
when(wrkMastService.count(any(Wrapper.class))).thenReturn(1L);
|
when(redisUtil.get(RedisKeyType.AI_AUTO_TUNE_LAST_TRIGGER_GUARD.key)).thenReturn(null);
|
when(aiAutoTuneJobService.list(any(Wrapper.class))).thenReturn(Collections.emptyList());
|
when(redisUtil.trySetStringIfAbsent(anyString(), anyString(), anyLong())).thenReturn(true);
|
doThrow(new RuntimeException("guard failed"))
|
.when(redisUtil)
|
.set(eq(RedisKeyType.AI_AUTO_TUNE_LAST_TRIGGER_GUARD.key), any(), eq(600L));
|
when(autoTuneAgentService.runAutoTune(AutoTuneTriggerType.AUTO.getCode())).thenReturn(agentResult);
|
|
AutoTuneCoordinatorService.AutoTuneCoordinatorResult result = coordinatorService().runAutoTuneIfEligible();
|
|
assertFalse(result.getSkipped());
|
assertTrue(result.getTriggered());
|
assertSame(agentResult, result.getAgentResult());
|
verify(autoTuneAgentService).runAutoTune(AutoTuneTriggerType.AUTO.getCode());
|
verify(redisUtil).compareAndDelete(anyString(), anyString());
|
}
|
|
@Test
|
void coordinatorSetsGuardWhenAgentReturnsFailure() {
|
AutoTuneAgentService.AutoTuneAgentResult agentResult = failedAgentResult();
|
when(configService.getConfigValue("aiAutoTuneEnabled", "N")).thenReturn("Y");
|
when(configService.getConfigValue("aiAutoTuneIntervalMinutes", "10")).thenReturn("10");
|
when(wrkMastService.count(any(Wrapper.class))).thenReturn(1L);
|
when(redisUtil.get(RedisKeyType.AI_AUTO_TUNE_LAST_TRIGGER_GUARD.key)).thenReturn(null);
|
when(aiAutoTuneJobService.list(any(Wrapper.class))).thenReturn(Collections.emptyList());
|
when(redisUtil.trySetStringIfAbsent(anyString(), anyString(), anyLong())).thenReturn(true);
|
when(autoTuneAgentService.runAutoTune(AutoTuneTriggerType.AUTO.getCode())).thenReturn(agentResult);
|
|
AutoTuneCoordinatorService.AutoTuneCoordinatorResult result = coordinatorService().runAutoTuneIfEligible();
|
|
assertFalse(result.getSkipped());
|
assertSame(agentResult, result.getAgentResult());
|
verify(redisUtil).set(eq(RedisKeyType.AI_AUTO_TUNE_LAST_TRIGGER_GUARD.key), any(), eq(600L));
|
}
|
|
@Test
|
void coordinatorSetsGuardWhenAgentThrows() {
|
when(configService.getConfigValue("aiAutoTuneEnabled", "N")).thenReturn("Y");
|
when(configService.getConfigValue("aiAutoTuneIntervalMinutes", "10")).thenReturn("10");
|
when(wrkMastService.count(any(Wrapper.class))).thenReturn(1L);
|
when(redisUtil.get(RedisKeyType.AI_AUTO_TUNE_LAST_TRIGGER_GUARD.key)).thenReturn(null);
|
when(aiAutoTuneJobService.list(any(Wrapper.class))).thenReturn(Collections.emptyList());
|
when(redisUtil.trySetStringIfAbsent(anyString(), anyString(), anyLong())).thenReturn(true);
|
when(autoTuneAgentService.runAutoTune(AutoTuneTriggerType.AUTO.getCode())).thenThrow(new RuntimeException("agent failed"));
|
|
AutoTuneCoordinatorService.AutoTuneCoordinatorResult result = coordinatorService().runAutoTuneIfEligible();
|
|
assertFalse(result.getSkipped());
|
assertFalse(result.getAgentResult().getSuccess());
|
verify(redisUtil).set(eq(RedisKeyType.AI_AUTO_TUNE_LAST_TRIGGER_GUARD.key), any(), eq(600L));
|
verify(redisUtil).compareAndDelete(anyString(), anyString());
|
}
|
|
@Test
|
void coordinatorSkipsWhenRunningLockIsNotAcquired() {
|
when(configService.getConfigValue("aiAutoTuneEnabled", "N")).thenReturn("Y");
|
when(configService.getConfigValue("aiAutoTuneIntervalMinutes", "10")).thenReturn("10");
|
when(wrkMastService.count(any(Wrapper.class))).thenReturn(1L);
|
when(redisUtil.get(RedisKeyType.AI_AUTO_TUNE_LAST_TRIGGER_GUARD.key)).thenReturn(null);
|
when(aiAutoTuneJobService.list(any(Wrapper.class))).thenReturn(Collections.emptyList());
|
when(redisUtil.trySetStringIfAbsent(anyString(), anyString(), anyLong())).thenReturn(false);
|
|
AutoTuneCoordinatorService.AutoTuneCoordinatorResult result = coordinatorService().runAutoTuneIfEligible();
|
|
assertTrue(result.getSkipped());
|
assertEquals("running_lock_not_acquired", result.getReason());
|
verify(autoTuneAgentService, never()).runAutoTune(anyString());
|
verify(redisUtil, never()).compareAndDelete(anyString(), anyString());
|
}
|
|
@Test
|
void agentExecutesSnapshotDryRunAndRealApplyToolSequence() {
|
AutoTuneAgentServiceImpl service = new AutoTuneAgentServiceImpl(
|
llmChatService,
|
mcpToolManager,
|
aiPromptTemplateService);
|
AiPromptTemplate promptTemplate = new AiPromptTemplate();
|
promptTemplate.setContent("system prompt");
|
when(aiPromptTemplateService.resolvePublished("wcs_auto_tune_dispatch")).thenReturn(promptTemplate);
|
when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
|
when(mcpToolManager.callTool(any(), any(JSONObject.class))).thenAnswer(invocation -> {
|
String toolName = invocation.getArgument(0);
|
JSONObject arguments = invocation.getArgument(1);
|
if ("wcs_local_dispatch_apply_auto_tune_changes".equals(toolName)
|
&& Boolean.TRUE.equals(arguments.getBoolean("dryRun"))) {
|
LinkedHashMap<String, Object> dryRunOutput = new LinkedHashMap<>();
|
dryRunOutput.put("success", true);
|
dryRunOutput.put("dryRun", true);
|
dryRunOutput.put("dryRunToken", "token-123");
|
return dryRunOutput;
|
}
|
return Collections.singletonMap("ok", true);
|
});
|
when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
|
.thenReturn(
|
response("read snapshot", toolCall("call_1", "wcs_local_dispatch_get_auto_tune_snapshot", "{}"), 10, 5),
|
response("dry-run first", toolCall("call_2", "wcs_local_dispatch_apply_auto_tune_changes",
|
"{\"dryRun\":true,\"changes\":[{\"targetType\":\"sys_config\",\"targetKey\":\"conveyorStationTaskLimit\",\"newValue\":\"12\"}]}"), 11, 6),
|
response("apply after dry-run", toolCall("call_3", "wcs_local_dispatch_apply_auto_tune_changes",
|
"{\"dryRun\":false,\"dryRunToken\":\"token-123\",\"changes\":[{\"targetType\":\"sys_config\",\"targetKey\":\"conveyorStationTaskLimit\",\"newValue\":\"12\"}]}"), 12, 7),
|
response("已完成自动调参", null, 13, 8)
|
);
|
|
AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");
|
|
assertTrue(result.getSuccess());
|
assertEquals("scheduler", result.getTriggerType());
|
assertEquals(3, result.getToolCallCount());
|
assertEquals(4, result.getLlmCallCount());
|
assertEquals(46L, result.getPromptTokens());
|
assertEquals(26L, result.getCompletionTokens());
|
assertEquals(72L, result.getTotalTokens());
|
assertTrue(result.getSummary().contains("已完成自动调参"));
|
|
ArgumentCaptor<String> toolNameCaptor = ArgumentCaptor.forClass(String.class);
|
ArgumentCaptor<JSONObject> argumentCaptor = ArgumentCaptor.forClass(JSONObject.class);
|
verify(mcpToolManager, times(3)).callTool(toolNameCaptor.capture(), argumentCaptor.capture());
|
assertEquals("wcs_local_dispatch_get_auto_tune_snapshot", toolNameCaptor.getAllValues().get(0));
|
assertEquals("wcs_local_dispatch_apply_auto_tune_changes", toolNameCaptor.getAllValues().get(1));
|
assertEquals(Boolean.TRUE, argumentCaptor.getAllValues().get(1).getBoolean("dryRun"));
|
assertEquals(Boolean.FALSE, argumentCaptor.getAllValues().get(2).getBoolean("dryRun"));
|
assertEquals("token-123", argumentCaptor.getAllValues().get(2).getString("dryRunToken"));
|
|
ArgumentCaptor<List<Object>> toolsCaptor = ArgumentCaptor.forClass(List.class);
|
verify(llmChatService, times(4)).chatCompletion(any(), anyDouble(), anyInt(), toolsCaptor.capture());
|
List<String> visibleToolNames = toolNames(toolsCaptor.getAllValues().get(0));
|
assertEquals(4, visibleToolNames.size());
|
assertTrue(visibleToolNames.contains("wcs_local_dispatch_get_auto_tune_snapshot"));
|
assertTrue(visibleToolNames.contains("wcs_local_dispatch_get_recent_auto_tune_jobs"));
|
assertTrue(visibleToolNames.contains("wcs_local_dispatch_apply_auto_tune_changes"));
|
assertTrue(visibleToolNames.contains("wcs_local_dispatch_revert_last_auto_tune_job"));
|
assertFalse(visibleToolNames.contains("wcs_local_device_get_crn_status"));
|
}
|
|
@Test
|
void agentForcesAutoTriggerTypeOnApplyTools() {
|
AutoTuneAgentServiceImpl service = agentService();
|
when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
|
when(mcpToolManager.callTool(any(), any(JSONObject.class))).thenReturn(Collections.singletonMap("ok", true));
|
when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
|
.thenReturn(
|
response("snapshot", toolCall("call_1", "wcs_local_dispatch_get_auto_tune_snapshot",
|
"{}"), 10, 5),
|
response("dry run", toolCall("call_2", "wcs_local_dispatch_apply_auto_tune_changes",
|
"{\"dryRun\":true,\"changes\":[]}"), 10, 5),
|
response("done", null, 10, 5)
|
);
|
|
AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("auto");
|
|
assertTrue(result.getSuccess());
|
ArgumentCaptor<JSONObject> argumentCaptor = ArgumentCaptor.forClass(JSONObject.class);
|
verify(mcpToolManager).callTool(eq("wcs_local_dispatch_apply_auto_tune_changes"), argumentCaptor.capture());
|
assertEquals("auto", argumentCaptor.getValue().getString("triggerType"));
|
}
|
|
@Test
|
void agentFailsAndDoesNotExecuteDisallowedToolCall() {
|
AutoTuneAgentServiceImpl service = agentService();
|
when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
|
when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
|
.thenReturn(response("bad tool", toolCall("call_1", "wcs_local_device_get_crn_status", "{}"), 10, 5));
|
|
AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");
|
|
assertFalse(result.getSuccess());
|
assertEquals(0, result.getToolCallCount());
|
assertTrue(result.getSummary().contains("Disallowed auto-tune MCP tool"));
|
verify(mcpToolManager, never()).callTool(any(), any(JSONObject.class));
|
}
|
|
@Test
|
void agentFailsWhenAllowedToolThrows() {
|
AutoTuneAgentServiceImpl service = agentService();
|
when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
|
when(mcpToolManager.callTool(any(), any(JSONObject.class))).thenThrow(new RuntimeException("boom"));
|
when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
|
.thenReturn(response("snapshot", toolCall("call_1", "wcs_local_dispatch_get_auto_tune_snapshot", "{}"), 10, 5));
|
|
AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");
|
|
assertFalse(result.getSuccess());
|
assertEquals(0, result.getToolCallCount());
|
assertTrue(result.getSummary().contains("Auto-tune MCP tool failed"));
|
verify(mcpToolManager).callTool(any(), any(JSONObject.class));
|
}
|
|
@Test
|
void agentFailsWhenLlmReturnsNoToolCalls() {
|
AutoTuneAgentServiceImpl service = agentService();
|
when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
|
when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
|
.thenReturn(response("no changes needed", null, 10, 5));
|
|
AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");
|
|
assertFalse(result.getSuccess());
|
assertEquals(0, result.getToolCallCount());
|
assertTrue(result.getSummary().contains("未调用任何允许的 MCP 工具"));
|
}
|
|
@Test
|
void agentFailsWhenMaxRoundsReached() {
|
AutoTuneAgentServiceImpl service = agentService();
|
when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
|
when(mcpToolManager.callTool(any(), any(JSONObject.class))).thenReturn(Collections.singletonMap("ok", true));
|
when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
|
.thenReturn(response("keep going", toolCall("call_1", "wcs_local_dispatch_get_auto_tune_snapshot", "{}"), 10, 5));
|
|
AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");
|
|
assertFalse(result.getSuccess());
|
assertEquals(10, result.getToolCallCount());
|
assertTrue(result.getMaxRoundsReached());
|
assertTrue(result.getSummary().contains("达到最大工具调用轮次"));
|
}
|
|
private AutoTuneAgentServiceImpl agentService() {
|
AiPromptTemplate promptTemplate = new AiPromptTemplate();
|
promptTemplate.setContent("system prompt");
|
when(aiPromptTemplateService.resolvePublished("wcs_auto_tune_dispatch")).thenReturn(promptTemplate);
|
return new AutoTuneAgentServiceImpl(llmChatService, mcpToolManager, aiPromptTemplateService);
|
}
|
|
private AutoTuneCoordinatorServiceImpl coordinatorService() {
|
return new AutoTuneCoordinatorServiceImpl(
|
configService,
|
wrkMastService,
|
aiAutoTuneJobService,
|
autoTuneAgentService,
|
redisUtil,
|
operateLogService);
|
}
|
|
private AutoTuneAgentService.AutoTuneAgentResult successfulAgentResult() {
|
AutoTuneAgentService.AutoTuneAgentResult result = new AutoTuneAgentService.AutoTuneAgentResult();
|
result.setSuccess(true);
|
result.setTriggerType(AutoTuneTriggerType.AUTO.getCode());
|
result.setSummary("no changes needed");
|
result.setToolCallCount(1);
|
result.setLlmCallCount(1);
|
result.setPromptTokens(10L);
|
result.setCompletionTokens(5L);
|
result.setTotalTokens(15L);
|
result.setMaxRoundsReached(false);
|
return result;
|
}
|
|
private AutoTuneAgentService.AutoTuneAgentResult failedAgentResult() {
|
AutoTuneAgentService.AutoTuneAgentResult result = successfulAgentResult();
|
result.setSuccess(false);
|
result.setSummary("failed");
|
return result;
|
}
|
|
private AutoTuneChangeCommand change(String targetType, String targetId, String targetKey, String newValue) {
|
AutoTuneChangeCommand command = new AutoTuneChangeCommand();
|
command.setTargetType(targetType);
|
command.setTargetId(targetId);
|
command.setTargetKey(targetKey);
|
command.setNewValue(newValue);
|
return command;
|
}
|
|
private List<Object> allowedOpenAiTools() {
|
List<Object> tools = new ArrayList<>();
|
tools.add(openAiTool("wcs_local_dispatch_get_auto_tune_snapshot"));
|
tools.add(openAiTool("wcs_local_dispatch_get_recent_auto_tune_jobs"));
|
tools.add(openAiTool("wcs_local_dispatch_apply_auto_tune_changes"));
|
tools.add(openAiTool("wcs_local_dispatch_revert_last_auto_tune_job"));
|
tools.add(openAiTool("wcs_local_device_get_crn_status"));
|
return tools;
|
}
|
|
private Map<String, Object> openAiTool(String name) {
|
LinkedHashMap<String, Object> function = new LinkedHashMap<>();
|
function.put("name", name);
|
function.put("parameters", Collections.emptyMap());
|
|
LinkedHashMap<String, Object> tool = new LinkedHashMap<>();
|
tool.put("type", "function");
|
tool.put("function", function);
|
return tool;
|
}
|
|
private List<String> toolNames(List<Object> tools) {
|
List<String> names = new ArrayList<>();
|
for (Object tool : tools) {
|
Map<?, ?> toolMap = (Map<?, ?>) tool;
|
Map<?, ?> function = (Map<?, ?>) toolMap.get("function");
|
names.add(String.valueOf(function.get("name")));
|
}
|
return names;
|
}
|
|
private ChatCompletionResponse response(String content,
|
ChatCompletionRequest.ToolCall toolCall,
|
int promptTokens,
|
int completionTokens) {
|
ChatCompletionRequest.Message message = new ChatCompletionRequest.Message();
|
message.setRole("assistant");
|
message.setContent(content);
|
if (toolCall != null) {
|
message.setTool_calls(Collections.singletonList(toolCall));
|
}
|
|
ChatCompletionResponse.Choice choice = new ChatCompletionResponse.Choice();
|
choice.setIndex(0);
|
choice.setMessage(message);
|
|
ChatCompletionResponse.Usage usage = new ChatCompletionResponse.Usage();
|
usage.setPromptTokens(promptTokens);
|
usage.setCompletionTokens(completionTokens);
|
usage.setTotalTokens(promptTokens + completionTokens);
|
|
ChatCompletionResponse response = new ChatCompletionResponse();
|
List<ChatCompletionResponse.Choice> choices = new ArrayList<>();
|
choices.add(choice);
|
response.setChoices(choices);
|
response.setUsage(usage);
|
return response;
|
}
|
|
private ChatCompletionRequest.ToolCall toolCall(String id, String name, String arguments) {
|
ChatCompletionRequest.Function function = new ChatCompletionRequest.Function();
|
function.setName(name);
|
function.setArguments(arguments);
|
|
ChatCompletionRequest.ToolCall toolCall = new ChatCompletionRequest.ToolCall();
|
toolCall.setId(id);
|
toolCall.setType("function");
|
toolCall.setFunction(function);
|
return toolCall;
|
}
|
}
|