Junjie
2026-04-27 abe9409ecbe5ac752dd6f14100733d1b3739ac09
test: cover auto tune agent safety paths
2个文件已修改
83 ■■■■■ 已修改文件
src/main/java/com/zy/ai/mcp/tool/AutoTuneMcpTools.java 23 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
src/test/java/com/zy/ai/service/AutoTuneCoordinatorServiceImplTest.java 60 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
src/main/java/com/zy/ai/mcp/tool/AutoTuneMcpTools.java
@@ -26,6 +26,7 @@
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.LongSupplier;
@Component
@RequiredArgsConstructor
@@ -40,6 +41,7 @@
    private final AiAutoTuneJobService aiAutoTuneJobService;
    private final AiAutoTuneChangeService aiAutoTuneChangeService;
    private final ConcurrentMap<String, DryRunPreview> dryRunPreviews = new ConcurrentHashMap<>();
    private LongSupplier currentTimeMillisSupplier = System::currentTimeMillis;
    @Tool(name = "dispatch_get_auto_tune_snapshot", description = "获取WCS自动调参所需的调度快照、站点运行态、拓扑容量和当前可写参数")
    public AutoTuneSnapshot getAutoTuneSnapshot() {
@@ -164,7 +166,7 @@
        if (preview == null) {
            throw new IllegalArgumentException("dryRunToken is missing, expired, or already used.");
        }
        if (preview.isExpired()) {
        if (preview.isExpired(currentTimeMillis())) {
            throw new IllegalArgumentException("dryRunToken is expired. Run dryRun=true again.");
        }
        if (!preview.getFingerprint().equals(fingerprint)) {
@@ -175,13 +177,14 @@
    /**
     * Registers a new dry-run token for the given fingerprint and returns it.
     *
     * <p>Expired previews are purged first. A random UUID token is then stored together with an
     * expiry of "current time + DRY_RUN_TOKEN_TTL_MILLIS".
     *
     * @param fingerprint fingerprint recorded for the previewed request; a later apply call
     *     compares its own fingerprint against this value
     * @return the newly generated token
     */
    private String createDryRunToken(String fingerprint) {
        cleanExpiredDryRunPreviews();
        String token = UUID.randomUUID().toString();
        // NOTE(review): this commit view appears to show both sides of the change without +/- markers.
        // The System.currentTimeMillis() line is the removed version, and the currentTimeMillis()
        // line after it is its replacement, which uses the injectable clock so tests can control expiry.
        dryRunPreviews.put(token, new DryRunPreview(fingerprint, System.currentTimeMillis() + DRY_RUN_TOKEN_TTL_MILLIS));
        dryRunPreviews.put(token, new DryRunPreview(fingerprint, currentTimeMillis() + DRY_RUN_TOKEN_TTL_MILLIS));
        return token;
    }
    private void cleanExpiredDryRunPreviews() {
        long currentTimeMillis = currentTimeMillis();
        for (Map.Entry<String, DryRunPreview> entry : dryRunPreviews.entrySet()) {
            if (entry.getValue() == null || entry.getValue().isExpired()) {
            if (entry.getValue() == null || entry.getValue().isExpired(currentTimeMillis)) {
                dryRunPreviews.remove(entry.getKey());
            }
        }
@@ -228,6 +231,16 @@
        return value == null || value.trim().isEmpty();
    }
    private long currentTimeMillis() {
        return currentTimeMillisSupplier.getAsLong();
    }
    void setCurrentTimeMillisSupplier(LongSupplier currentTimeMillisSupplier) {
        this.currentTimeMillisSupplier = currentTimeMillisSupplier == null
                ? System::currentTimeMillis
                : currentTimeMillisSupplier;
    }
    private static class DryRunPreview {
        private final String fingerprint;
        private final long expireAtMillis;
@@ -241,8 +254,8 @@
            return fingerprint;
        }
        boolean isExpired() {
            return System.currentTimeMillis() > expireAtMillis;
        boolean isExpired(long currentTimeMillis) {
            return currentTimeMillis > expireAtMillis;
        }
    }
}
src/test/java/com/zy/ai/service/AutoTuneCoordinatorServiceImplTest.java
@@ -20,12 +20,15 @@
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.test.util.ReflectionTestUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.LongSupplier;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
@@ -205,6 +208,26 @@
    }
    @Test
    void applyToolRejectsExpiredDryRunToken() {
        AtomicLong currentTimeMillis = new AtomicLong(1_000L);
        ReflectionTestUtils.invokeMethod(tools, "setCurrentTimeMillisSupplier", (LongSupplier) currentTimeMillis::get);
        AutoTuneApplyResult dryRunResult = new AutoTuneApplyResult();
        dryRunResult.setDryRun(true);
        dryRunResult.setSuccess(true);
        when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class))).thenReturn(dryRunResult);
        List<AutoTuneChangeCommand> changes = Collections.singletonList(
                change("sys_config", null, "conveyorStationTaskLimit", "12"));
        AutoTuneApplyResult preview = tools.applyAutoTuneChanges("preview", 10, "agent", true, null, changes);
        currentTimeMillis.addAndGet(10L * 60L * 1000L + 1L);
        IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
                () -> tools.applyAutoTuneChanges("apply", 10, "agent", false, preview.getDryRunToken(), changes));
        assertTrue(exception.getMessage().contains("expired"));
        verify(autoTuneApplyService, times(1)).apply(any(AutoTuneApplyRequest.class));
    }
    @Test
    void rollbackToolDelegatesToApplyServiceRollback() {
        AutoTuneApplyResult expected = new AutoTuneApplyResult();
        when(autoTuneApplyService.rollbackLastSuccessfulJob("bad result")).thenReturn(expected);
@@ -225,14 +248,26 @@
        promptTemplate.setContent("system prompt");
        when(aiPromptTemplateService.resolvePublished("wcs_auto_tune_dispatch")).thenReturn(promptTemplate);
        when(mcpToolManager.buildOpenAiTools()).thenReturn(allowedOpenAiTools());
        when(mcpToolManager.callTool(any(), any(JSONObject.class))).thenReturn(Collections.singletonMap("ok", true));
        when(mcpToolManager.callTool(any(), any(JSONObject.class))).thenAnswer(invocation -> {
            String toolName = invocation.getArgument(0);
            JSONObject arguments = invocation.getArgument(1);
            if ("wcs_local_dispatch_apply_auto_tune_changes".equals(toolName)
                    && Boolean.TRUE.equals(arguments.getBoolean("dryRun"))) {
                LinkedHashMap<String, Object> dryRunOutput = new LinkedHashMap<>();
                dryRunOutput.put("success", true);
                dryRunOutput.put("dryRun", true);
                dryRunOutput.put("dryRunToken", "token-123");
                return dryRunOutput;
            }
            return Collections.singletonMap("ok", true);
        });
        when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
                .thenReturn(
                        response("read snapshot", toolCall("call_1", "wcs_local_dispatch_get_auto_tune_snapshot", "{}"), 10, 5),
                        response("dry-run first", toolCall("call_2", "wcs_local_dispatch_apply_auto_tune_changes",
                                "{\"dryRun\":true,\"changes\":[{\"targetType\":\"sys_config\",\"targetKey\":\"conveyorStationTaskLimit\",\"newValue\":\"12\"}]}"), 11, 6),
                        response("apply after dry-run", toolCall("call_3", "wcs_local_dispatch_apply_auto_tune_changes",
                                "{\"dryRun\":false,\"changes\":[{\"targetType\":\"sys_config\",\"targetKey\":\"conveyorStationTaskLimit\",\"newValue\":\"12\"}]}"), 12, 7),
                                "{\"dryRun\":false,\"dryRunToken\":\"token-123\",\"changes\":[{\"targetType\":\"sys_config\",\"targetKey\":\"conveyorStationTaskLimit\",\"newValue\":\"12\"}]}"), 12, 7),
                        response("已完成自动调参", null, 13, 8)
                );
@@ -254,6 +289,17 @@
        assertEquals("wcs_local_dispatch_apply_auto_tune_changes", toolNameCaptor.getAllValues().get(1));
        assertEquals(Boolean.TRUE, argumentCaptor.getAllValues().get(1).getBoolean("dryRun"));
        assertEquals(Boolean.FALSE, argumentCaptor.getAllValues().get(2).getBoolean("dryRun"));
        assertEquals("token-123", argumentCaptor.getAllValues().get(2).getString("dryRunToken"));
        ArgumentCaptor<List<Object>> toolsCaptor = ArgumentCaptor.forClass(List.class);
        verify(llmChatService, times(4)).chatCompletion(any(), anyDouble(), anyInt(), toolsCaptor.capture());
        List<String> visibleToolNames = toolNames(toolsCaptor.getAllValues().get(0));
        assertEquals(4, visibleToolNames.size());
        assertTrue(visibleToolNames.contains("wcs_local_dispatch_get_auto_tune_snapshot"));
        assertTrue(visibleToolNames.contains("wcs_local_dispatch_get_recent_auto_tune_jobs"));
        assertTrue(visibleToolNames.contains("wcs_local_dispatch_apply_auto_tune_changes"));
        assertTrue(visibleToolNames.contains("wcs_local_dispatch_revert_last_auto_tune_job"));
        assertFalse(visibleToolNames.contains("wcs_local_device_get_crn_status"));
    }
    @Test
@@ -354,6 +400,16 @@
        return tool;
    }
    private List<String> toolNames(List<Object> tools) {
        List<String> names = new ArrayList<>();
        for (Object tool : tools) {
            Map<?, ?> toolMap = (Map<?, ?>) tool;
            Map<?, ?> function = (Map<?, ?>) toolMap.get("function");
            names.add(String.valueOf(function.get("name")));
        }
        return names;
    }
    private ChatCompletionResponse response(String content,
                                            ChatCompletionRequest.ToolCall toolCall,
                                            int promptTokens,