| src/main/java/com/zy/ai/mcp/config/SpringAiMcpConfig.java | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| src/main/java/com/zy/ai/mcp/tool/AutoTuneMcpTools.java | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| src/main/java/com/zy/ai/service/AutoTuneAgentService.java | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| src/main/java/com/zy/ai/service/impl/AutoTuneAgentServiceImpl.java | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| src/test/java/com/zy/ai/service/AutoTuneCoordinatorServiceImplTest.java | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 |
src/main/java/com/zy/ai/mcp/config/SpringAiMcpConfig.java
@@ -1,5 +1,6 @@ package com.zy.ai.mcp.config; import com.zy.ai.mcp.tool.AutoTuneMcpTools; import com.zy.ai.mcp.tool.WcsMcpTools; import org.springframework.ai.support.ToolCallbacks; import org.springframework.ai.tool.StaticToolCallbackProvider; @@ -11,7 +12,8 @@ public class SpringAiMcpConfig { @Bean("wcsMcpToolCallbackProvider") public ToolCallbackProvider wcsMcpToolCallbackProvider(WcsMcpTools wcsMcpTools) { return new StaticToolCallbackProvider(ToolCallbacks.from(wcsMcpTools)); public ToolCallbackProvider wcsMcpToolCallbackProvider(WcsMcpTools wcsMcpTools, AutoTuneMcpTools autoTuneMcpTools) { return new StaticToolCallbackProvider(ToolCallbacks.from(wcsMcpTools, autoTuneMcpTools)); } } src/main/java/com/zy/ai/mcp/tool/AutoTuneMcpTools.java
New file @@ -0,0 +1,136 @@
package com.zy.ai.mcp.tool;

import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.zy.ai.domain.autotune.AutoTuneApplyRequest;
import com.zy.ai.domain.autotune.AutoTuneApplyResult;
import com.zy.ai.domain.autotune.AutoTuneChangeCommand;
import com.zy.ai.domain.autotune.AutoTuneSnapshot;
import com.zy.ai.entity.AiAutoTuneChange;
import com.zy.ai.entity.AiAutoTuneJob;
import com.zy.ai.service.AiAutoTuneChangeService;
import com.zy.ai.service.AiAutoTuneJobService;
import com.zy.ai.service.AutoTuneApplyService;
import com.zy.ai.service.AutoTuneSnapshotService;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.stereotype.Component;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * MCP tool facade for WCS auto-tuning.
 *
 * <p>Each {@code @Tool} method is a thin adapter over the injected services: it reads a
 * snapshot, lists recent jobs with their changes, submits a change request, or rolls back
 * the last successful job. No tuning logic lives here.
 */
@Component
@RequiredArgsConstructor
public class AutoTuneMcpTools {

    /** Number of jobs returned when the caller passes no limit (or a non-positive one). */
    private static final int DEFAULT_RECENT_JOB_LIMIT = 5;
    /** Hard upper bound on the number of jobs a single call may return. */
    private static final int MAX_RECENT_JOB_LIMIT = 20;

    private final AutoTuneSnapshotService autoTuneSnapshotService;
    private final AutoTuneApplyService autoTuneApplyService;
    private final AiAutoTuneJobService aiAutoTuneJobService;
    private final AiAutoTuneChangeService aiAutoTuneChangeService;

    /**
     * Returns the snapshot built by {@link AutoTuneSnapshotService#buildSnapshot()}.
     *
     * @return the snapshot exactly as produced by the snapshot service
     */
    @Tool(name = "dispatch_get_auto_tune_snapshot", description = "获取WCS自动调参所需的调度快照、站点运行态、拓扑容量和当前可写参数")
    public AutoTuneSnapshot getAutoTuneSnapshot() {
        return autoTuneSnapshotService.buildSnapshot();
    }

    /**
     * Lists the most recent auto-tune jobs (newest first by {@code start_time}, then
     * {@code id}) as compact maps, each carrying its change summaries under {@code "changes"}.
     *
     * <p>All changes for the returned jobs are loaded with one {@code IN} query rather than
     * one query per job.
     *
     * @param limit requested maximum; {@code null} or non-positive falls back to
     *              {@value #DEFAULT_RECENT_JOB_LIMIT}, larger values are capped at
     *              {@value #MAX_RECENT_JOB_LIMIT}
     * @return a mutable list of job summaries, empty when no job exists
     */
    @Tool(name = "dispatch_get_recent_auto_tune_jobs", description = "获取近期自动调参任务摘要及其变更结果,默认5条,最大20条")
    public List<Map<String, Object>> getRecentAutoTuneJobs(
            @ToolParam(description = "返回任务数量上限,默认5,最大20", required = false) Integer limit) {
        int safeLimit = boundLimit(limit);
        // safeLimit is a bounded int computed locally, so concatenating it into SQL is safe.
        List<AiAutoTuneJob> jobs = aiAutoTuneJobService.list(new QueryWrapper<AiAutoTuneJob>()
                .orderByDesc("start_time")
                .orderByDesc("id")
                .last("limit " + safeLimit));
        if (jobs == null || jobs.isEmpty()) {
            return new ArrayList<>();
        }
        Map<Long, List<Map<String, Object>>> changesByJobId = loadChangeSummaries(jobs);
        List<Map<String, Object>> result = new ArrayList<>(jobs.size());
        for (AiAutoTuneJob job : jobs) {
            result.add(toJobSummary(job, changesByJobId));
        }
        return result;
    }

    /**
     * Copies the tool arguments into an {@link AutoTuneApplyRequest} and forwards it to
     * {@link AutoTuneApplyService#apply}. Arguments are passed through unmodified,
     * including {@code null} values.
     *
     * @param reason                  reason or analysis summary for this request
     * @param analysisIntervalMinutes suggested analysis interval in minutes
     * @param triggerType             trigger type, e.g. scheduler/manual/agent
     * @param dryRun                  whether this is a trial run only
     * @param changes                 the requested changes
     * @return the result returned by the apply service
     */
    @Tool(name = "dispatch_apply_auto_tune_changes", description = "提交自动调参变更。实际应用前必须先使用 dryRun=true 验证")
    public AutoTuneApplyResult applyAutoTuneChanges(
            @ToolParam(description = "本次调参原因或分析摘要", required = false) String reason,
            @ToolParam(description = "建议自动调参分析间隔分钟", required = false) Integer analysisIntervalMinutes,
            @ToolParam(description = "触发类型,例如 scheduler/manual/agent", required = false) String triggerType,
            @ToolParam(description = "是否仅试算,实际应用前必须先传 true", required = false) Boolean dryRun,
            @ToolParam(description = "调参变更列表") List<AutoTuneChangeCommand> changes) {
        AutoTuneApplyRequest request = new AutoTuneApplyRequest();
        request.setReason(reason);
        request.setAnalysisIntervalMinutes(analysisIntervalMinutes);
        request.setTriggerType(triggerType);
        request.setDryRun(dryRun);
        request.setChanges(changes);
        return autoTuneApplyService.apply(request);
    }

    /**
     * Delegates to {@link AutoTuneApplyService#rollbackLastSuccessfulJob(String)}.
     *
     * @param reason rollback reason, forwarded as-is
     * @return the result returned by the apply service
     */
    @Tool(name = "dispatch_revert_last_auto_tune_job", description = "回滚最近一次成功的自动调参任务")
    public AutoTuneApplyResult revertLastAutoTuneJob(
            @ToolParam(description = "回滚原因,必须说明来自MCP事实的异常证据", required = false) String reason) {
        return autoTuneApplyService.rollbackLastSuccessfulJob(reason);
    }

    /**
     * Builds the compact, insertion-ordered view of one job.
     *
     * @param job            the job to summarise
     * @param changesByJobId change summaries grouped by job id; a job without an entry
     *                       gets an empty list
     */
    private Map<String, Object> toJobSummary(AiAutoTuneJob job,
                                             Map<Long, List<Map<String, Object>>> changesByJobId) {
        Map<String, Object> item = new LinkedHashMap<>();
        item.put("id", job.getId());
        item.put("triggerType", job.getTriggerType());
        item.put("status", job.getStatus());
        item.put("startTime", job.getStartTime());
        item.put("finishTime", job.getFinishTime());
        item.put("summary", job.getSummary());
        item.put("successCount", job.getSuccessCount());
        item.put("rejectCount", job.getRejectCount());
        item.put("errorMessage", job.getErrorMessage());
        List<Map<String, Object>> changes = changesByJobId.get(job.getId());
        item.put("changes", changes == null ? new ArrayList<>() : changes);
        return item;
    }

    /**
     * Loads the change summaries of all given jobs with a single query and groups them by
     * job id. Rows are fetched in ascending {@code id} order, so each group keeps that order.
     * Jobs with a {@code null} id are ignored; when no id remains, no query is issued.
     */
    private Map<Long, List<Map<String, Object>>> loadChangeSummaries(List<AiAutoTuneJob> jobs) {
        List<Long> jobIds = new ArrayList<>(jobs.size());
        for (AiAutoTuneJob job : jobs) {
            Long jobId = job.getId();
            if (jobId != null) {
                jobIds.add(jobId);
            }
        }
        Map<Long, List<Map<String, Object>>> grouped = new LinkedHashMap<>();
        if (jobIds.isEmpty()) {
            // An empty IN list would produce invalid SQL; there is nothing to load anyway.
            return grouped;
        }
        List<AiAutoTuneChange> changes = aiAutoTuneChangeService.list(new QueryWrapper<AiAutoTuneChange>()
                .in("job_id", jobIds)
                .orderByAsc("id"));
        if (changes == null) {
            return grouped;
        }
        for (AiAutoTuneChange change : changes) {
            Long ownerJobId = change.getJobId();
            grouped.computeIfAbsent(ownerJobId, key -> new ArrayList<>()).add(toChangeSummary(change));
        }
        return grouped;
    }

    /** Builds the compact, insertion-ordered view of one change row. */
    private Map<String, Object> toChangeSummary(AiAutoTuneChange change) {
        Map<String, Object> item = new LinkedHashMap<>();
        item.put("targetType", change.getTargetType());
        item.put("targetId", change.getTargetId());
        item.put("targetKey", change.getTargetKey());
        item.put("oldValue", change.getOldValue());
        item.put("requestedValue", change.getRequestedValue());
        item.put("appliedValue", change.getAppliedValue());
        item.put("resultStatus", change.getResultStatus());
        item.put("rejectReason", change.getRejectReason());
        item.put("cooldownExpireTime", change.getCooldownExpireTime());
        item.put("createTime", change.getCreateTime());
        return item;
    }

    /** Clamps the requested limit into {@code [1, MAX_RECENT_JOB_LIMIT]}, defaulting when absent or non-positive. */
    private int boundLimit(Integer limit) {
        if (limit == null || limit <= 0) {
            return DEFAULT_RECENT_JOB_LIMIT;
        }
        return Math.min(limit, MAX_RECENT_JOB_LIMIT);
    }
}
src/main/java/com/zy/ai/service/AutoTuneAgentService.java
New file @@ -0,0 +1,33 @@
package com.zy.ai.service;

import lombok.Data;

import java.io.Serializable;

/**
 * Entry point for running one auto-tune agent session.
 */
public interface AutoTuneAgentService {

    /**
     * Runs one auto-tune session.
     *
     * @param triggerType label describing what started the run; the implementation in this
     *                    change set trims it and substitutes {@code "agent"} when it is blank
     * @return the outcome of the run, including usage counters
     */
    AutoTuneAgentResult runAutoTune(String triggerType);

    /**
     * Outcome of a single auto-tune run. Getters, setters, {@code equals}, {@code hashCode}
     * and {@code toString} are generated by Lombok's {@code @Data}.
     */
    @Data
    class AutoTuneAgentResult implements Serializable {

        private static final long serialVersionUID = 1L;

        /** Whether the run finished without an exception. */
        private Boolean success;

        /** Trigger type the run was executed with. */
        private String triggerType;

        /** Human-readable summary text of the run. */
        private String summary;

        /** Number of tool invocations performed during the run. */
        private Integer toolCallCount;

        /** Number of LLM completions accounted for during the run. */
        private Integer llmCallCount;

        /** Prompt tokens summed over the accounted LLM calls. */
        private Long promptTokens;

        /** Completion tokens summed over the accounted LLM calls. */
        private Long completionTokens;

        /** Total tokens summed over the accounted LLM calls. */
        private Long totalTokens;

        /** Whether the run stopped because the tool-round limit was hit. */
        private Boolean maxRoundsReached;
    }
}
src/main/java/com/zy/ai/service/impl/AutoTuneAgentServiceImpl.java
New file @@ -0,0 +1,227 @@
package com.zy.ai.service.impl;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.zy.ai.entity.AiPromptTemplate;
import com.zy.ai.entity.ChatCompletionRequest;
import com.zy.ai.entity.ChatCompletionResponse;
import com.zy.ai.enums.AiPromptScene;
import com.zy.ai.mcp.service.SpringAiMcpToolManager;
import com.zy.ai.service.AiPromptTemplateService;
import com.zy.ai.service.AutoTuneAgentService;
import com.zy.ai.service.LlmChatService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;

/**
 * Drives one auto-tune session as an LLM tool-calling loop.
 *
 * <p>Each round sends the conversation to the LLM, executes every tool call in the reply
 * through {@link SpringAiMcpToolManager}, and appends the tool outputs to the conversation.
 * The loop ends when the LLM answers without tool calls or after {@value #MAX_TOOL_ROUNDS}
 * rounds.
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class AutoTuneAgentServiceImpl implements AutoTuneAgentService {

    /** Maximum number of LLM rounds in a single run. */
    private static final int MAX_TOOL_ROUNDS = 10;
    /** Sampling temperature passed to every completion request. */
    private static final double TEMPERATURE = 0.2D;
    /** Token limit passed to every completion request. */
    private static final int MAX_TOKENS = 2048;

    private final LlmChatService llmChatService;
    private final SpringAiMcpToolManager mcpToolManager;
    private final AiPromptTemplateService aiPromptTemplateService;

    /**
     * Runs the tool-calling loop once.
     *
     * <p>Any exception (no tools registered, empty LLM response, a tool call without a
     * name, ...) is logged and turned into a result with {@code success=false}. This method
     * does not throw.
     *
     * @param triggerType trigger label; blank values become {@code "agent"}
     * @return the run outcome with summary text and usage counters
     */
    @Override
    public AutoTuneAgentResult runAutoTune(String triggerType) {
        String normalizedTriggerType = normalizeTriggerType(triggerType);
        UsageCounter usageCounter = new UsageCounter();
        int toolCallCount = 0;
        boolean maxRoundsReached = false;
        StringBuilder summaryBuffer = new StringBuilder();
        try {
            List<Object> tools = mcpToolManager.buildOpenAiTools();
            if (tools == null || tools.isEmpty()) {
                throw new IllegalStateException("No MCP tools registered");
            }
            AiPromptTemplate promptTemplate =
                    aiPromptTemplateService.resolvePublished(AiPromptScene.AUTO_TUNE_DISPATCH.getCode());
            List<ChatCompletionRequest.Message> messages = buildMessages(promptTemplate, normalizedTriggerType);
            for (int round = 0; round < MAX_TOOL_ROUNDS; round++) {
                ChatCompletionResponse response =
                        llmChatService.chatCompletion(messages, TEMPERATURE, MAX_TOKENS, tools);
                ChatCompletionRequest.Message assistantMessage = extractAssistantMessage(response);
                usageCounter.add(response.getUsage());
                messages.add(assistantMessage);
                appendSummary(summaryBuffer, assistantMessage.getContent());

                List<ChatCompletionRequest.ToolCall> toolCalls = assistantMessage.getTool_calls();
                if (toolCalls == null || toolCalls.isEmpty()) {
                    // A reply without tool calls is the agent's final answer.
                    return buildResult(true, normalizedTriggerType, summaryBuffer, toolCallCount,
                            usageCounter, false);
                }
                for (ChatCompletionRequest.ToolCall toolCall : toolCalls) {
                    Object toolOutput = callMountedTool(toolCall);
                    toolCallCount++;
                    messages.add(buildToolMessage(toolCall, toolOutput));
                }
            }
            maxRoundsReached = true;
            return buildResult(true, normalizedTriggerType, summaryBuffer, toolCallCount,
                    usageCounter, maxRoundsReached);
        } catch (Exception exception) {
            log.error("Auto tune agent stopped with error", exception);
            appendSummary(summaryBuffer, "自动调参 Agent 执行异常: " + describeError(exception));
            return buildResult(false, normalizedTriggerType, summaryBuffer, toolCallCount,
                    usageCounter, maxRoundsReached);
        }
    }

    /**
     * Builds the initial conversation: the system prompt from the template (empty when the
     * template is {@code null}) followed by the fixed user instruction.
     */
    private List<ChatCompletionRequest.Message> buildMessages(AiPromptTemplate promptTemplate, String triggerType) {
        List<ChatCompletionRequest.Message> messages = new ArrayList<>();

        ChatCompletionRequest.Message systemMessage = new ChatCompletionRequest.Message();
        systemMessage.setRole("system");
        systemMessage.setContent(promptTemplate == null ? "" : promptTemplate.getContent());
        messages.add(systemMessage);

        ChatCompletionRequest.Message userMessage = new ChatCompletionRequest.Message();
        userMessage.setRole("user");
        userMessage.setContent("请执行一次后台 WCS 自动调参。triggerType=" + triggerType
                + "。必须先调用 wcs_local_dispatch_get_auto_tune_snapshot 获取事实;如需提交变更,"
                + "必须先 dry-run,再根据 dry-run 结果决定是否实际应用。不要输出自由格式 JSON 供外层解析。");
        messages.add(userMessage);
        return messages;
    }

    /**
     * Returns the message of the first choice.
     *
     * @throws IllegalStateException when the response has no choices or the first choice has no message
     */
    private ChatCompletionRequest.Message extractAssistantMessage(ChatCompletionResponse response) {
        if (response == null || response.getChoices() == null || response.getChoices().isEmpty()) {
            throw new IllegalStateException("LLM returned empty response");
        }
        ChatCompletionRequest.Message message = response.getChoices().get(0).getMessage();
        if (message == null) {
            throw new IllegalStateException("LLM returned empty message");
        }
        return message;
    }

    /**
     * Executes one tool call. A failure inside the tool is returned as an error map so the
     * LLM can react to it; a tool call without a name still throws and aborts the run.
     */
    private Object callMountedTool(ChatCompletionRequest.ToolCall toolCall) {
        String toolName = resolveToolName(toolCall);
        JSONObject arguments = parseArguments(toolCall);
        try {
            return mcpToolManager.callTool(toolName, arguments);
        } catch (Exception exception) {
            LinkedHashMap<String, Object> error = new LinkedHashMap<>();
            error.put("tool", toolName);
            // fastjson omits null-valued entries when serialising, so a null message would
            // make this look like a successful call; always report a non-null description.
            error.put("error", describeError(exception));
            return error;
        }
    }

    /** Wraps a tool output as a {@code tool} role message, serialised to JSON. */
    private ChatCompletionRequest.Message buildToolMessage(ChatCompletionRequest.ToolCall toolCall, Object toolOutput) {
        ChatCompletionRequest.Message toolMessage = new ChatCompletionRequest.Message();
        toolMessage.setRole("tool");
        toolMessage.setTool_call_id(toolCall == null ? null : toolCall.getId());
        toolMessage.setContent(JSON.toJSONString(toolOutput));
        return toolMessage;
    }

    /**
     * Returns the function name of the tool call.
     *
     * @throws IllegalArgumentException when the call, its function or its name is missing or blank
     */
    private String resolveToolName(ChatCompletionRequest.ToolCall toolCall) {
        if (toolCall == null || toolCall.getFunction() == null || isBlank(toolCall.getFunction().getName())) {
            throw new IllegalArgumentException("missing tool name");
        }
        return toolCall.getFunction().getName();
    }

    /**
     * Parses the raw argument string of a tool call. Blank input and a JSON {@code null}
     * literal both yield an empty object; unparseable input is preserved under
     * {@code "_raw"}. Never returns {@code null}.
     */
    private JSONObject parseArguments(ChatCompletionRequest.ToolCall toolCall) {
        String rawArguments = toolCall == null || toolCall.getFunction() == null
                ? null
                : toolCall.getFunction().getArguments();
        if (isBlank(rawArguments)) {
            return new JSONObject();
        }
        try {
            JSONObject parsed = JSON.parseObject(rawArguments);
            // JSON.parseObject returns null for the literal "null".
            return parsed == null ? new JSONObject() : parsed;
        } catch (Exception exception) {
            JSONObject arguments = new JSONObject();
            arguments.put("_raw", rawArguments);
            return arguments;
        }
    }

    /**
     * Assembles the result object. A successful run that made no tool calls is prefixed
     * with a notice, and a run that hit the round limit gets a notice appended.
     */
    private AutoTuneAgentResult buildResult(boolean success,
                                            String triggerType,
                                            StringBuilder summaryBuffer,
                                            int toolCallCount,
                                            UsageCounter usageCounter,
                                            boolean maxRoundsReached) {
        AutoTuneAgentResult result = new AutoTuneAgentResult();
        result.setSuccess(success);
        result.setTriggerType(triggerType);
        result.setToolCallCount(toolCallCount);
        result.setLlmCallCount(usageCounter.getLlmCallCount());
        result.setPromptTokens(usageCounter.getPromptTokens());
        result.setCompletionTokens(usageCounter.getCompletionTokens());
        result.setTotalTokens(usageCounter.getTotalTokens());
        result.setMaxRoundsReached(maxRoundsReached);

        String summary = summaryBuffer == null ? "" : summaryBuffer.toString().trim();
        if (toolCallCount <= 0 && success) {
            summary = "自动调参 Agent 未调用任何 MCP 工具,未执行调参。"
                    + (summary.isEmpty() ? "" : "\n" + summary);
        }
        if (maxRoundsReached) {
            // Only separate with a newline when there is preceding text, so the summary
            // never starts with a blank line.
            summary = (summary.isEmpty() ? "" : summary + "\n")
                    + "自动调参 Agent 达到最大工具调用轮次,已停止。";
        }
        result.setSummary(summary);
        return result;
    }

    /** Appends trimmed, non-blank content to the buffer, newline-separated. */
    private void appendSummary(StringBuilder summaryBuffer, String content) {
        if (summaryBuffer == null || isBlank(content)) {
            return;
        }
        if (summaryBuffer.length() > 0) {
            summaryBuffer.append('\n');
        }
        summaryBuffer.append(content.trim());
    }

    /** Returns the exception message, or the exception class name when the message is blank. */
    private String describeError(Exception exception) {
        String message = exception.getMessage();
        return isBlank(message) ? exception.getClass().getName() : message;
    }

    private String normalizeTriggerType(String triggerType) {
        return isBlank(triggerType) ? "agent" : triggerType.trim();
    }

    private boolean isBlank(String value) {
        return value == null || value.trim().isEmpty();
    }

    /** Accumulates token usage and the number of accounted LLM calls for one run. */
    private static class UsageCounter {
        private long promptTokens;
        private long completionTokens;
        private long totalTokens;
        private int llmCallCount;

        /** Counts one LLM call and adds its usage; a {@code null} usage or field counts as zero tokens. */
        void add(ChatCompletionResponse.Usage usage) {
            llmCallCount++;
            if (usage == null) {
                return;
            }
            promptTokens += usage.getPromptTokens() == null ? 0L : usage.getPromptTokens();
            completionTokens += usage.getCompletionTokens() == null ? 0L : usage.getCompletionTokens();
            totalTokens += usage.getTotalTokens() == null ? 0L : usage.getTotalTokens();
        }

        long getPromptTokens() {
            return promptTokens;
        }

        long getCompletionTokens() {
            return completionTokens;
        }

        long getTotalTokens() {
            return totalTokens;
        }

        int getLlmCallCount() {
            return llmCallCount;
        }
    }
}
src/test/java/com/zy/ai/service/AutoTuneCoordinatorServiceImplTest.java
New file @@ -0,0 +1,226 @@
package com.zy.ai.service;

import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.conditions.Wrapper;
import com.zy.ai.domain.autotune.AutoTuneApplyRequest;
import com.zy.ai.domain.autotune.AutoTuneApplyResult;
import com.zy.ai.domain.autotune.AutoTuneChangeCommand;
import com.zy.ai.domain.autotune.AutoTuneSnapshot;
import com.zy.ai.entity.AiAutoTuneChange;
import com.zy.ai.entity.AiAutoTuneJob;
import com.zy.ai.entity.AiPromptTemplate;
import com.zy.ai.entity.ChatCompletionRequest;
import com.zy.ai.entity.ChatCompletionResponse;
import com.zy.ai.mcp.service.SpringAiMcpToolManager;
import com.zy.ai.mcp.tool.AutoTuneMcpTools;
import com.zy.ai.service.impl.AutoTuneAgentServiceImpl;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyDouble;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
 * Unit tests for {@link AutoTuneMcpTools} and {@link AutoTuneAgentServiceImpl}, with every
 * collaborator replaced by a Mockito mock.
 */
@ExtendWith(MockitoExtension.class)
class AutoTuneCoordinatorServiceImplTest {

    // Tool facade under test; rebuilt from the mocks before every test.
    private AutoTuneMcpTools tools;

    @Mock
    private AutoTuneSnapshotService autoTuneSnapshotService;
    @Mock
    private AutoTuneApplyService autoTuneApplyService;
    @Mock
    private AiAutoTuneJobService aiAutoTuneJobService;
    @Mock
    private AiAutoTuneChangeService aiAutoTuneChangeService;
    @Mock
    private LlmChatService llmChatService;
    @Mock
    private SpringAiMcpToolManager mcpToolManager;
    @Mock
    private AiPromptTemplateService aiPromptTemplateService;

    @BeforeEach
    void setUp() {
        tools = new AutoTuneMcpTools(
                autoTuneSnapshotService,
                autoTuneApplyService,
                aiAutoTuneJobService,
                aiAutoTuneChangeService);
    }

    /** The snapshot tool returns the very instance produced by the snapshot service. */
    @Test
    void snapshotToolDelegatesToSnapshotService() {
        AutoTuneSnapshot snapshot = new AutoTuneSnapshot();
        when(autoTuneSnapshotService.buildSnapshot()).thenReturn(snapshot);

        AutoTuneSnapshot result = tools.getAutoTuneSnapshot();

        assertSame(snapshot, result);
        verify(autoTuneSnapshotService).buildSnapshot();
    }

    /**
     * Requesting 99 jobs is capped to 20 in the generated SQL, and each job summary carries
     * its change list but no {@code reasoningDigest} key.
     */
    @Test
    void recentJobsReturnsBoundedCompactSummariesWithChanges() {
        AiAutoTuneJob job = new AiAutoTuneJob();
        job.setId(7L);
        job.setTriggerType("agent");
        job.setStatus("success");
        job.setSummary("applied");
        job.setSuccessCount(1);
        job.setRejectCount(0);
        AiAutoTuneChange change = new AiAutoTuneChange();
        change.setJobId(7L);
        change.setTargetType("sys_config");
        change.setTargetKey("conveyorStationTaskLimit");
        change.setRequestedValue("12");
        change.setResultStatus("success");
        when(aiAutoTuneJobService.list(any(Wrapper.class))).thenReturn(Collections.singletonList(job));
        when(aiAutoTuneChangeService.list(any(Wrapper.class))).thenReturn(Collections.singletonList(change));

        List<Map<String, Object>> result = tools.getRecentAutoTuneJobs(99);

        assertEquals(1, result.size());
        assertEquals(7L, result.get(0).get("id"));
        assertFalse(result.get(0).containsKey("reasoningDigest"));
        List<?> changes = (List<?>) result.get(0).get("changes");
        assertEquals(1, changes.size());
        // Inspect the wrapper passed to the job query to confirm the limit was clamped.
        ArgumentCaptor<Wrapper<AiAutoTuneJob>> wrapperCaptor = ArgumentCaptor.forClass(Wrapper.class);
        verify(aiAutoTuneJobService).list(wrapperCaptor.capture());
        assertTrue(wrapperCaptor.getValue().getSqlSegment().contains("limit 20"));
    }

    /** Every argument of the apply tool is copied unchanged into the request sent to the apply service. */
    @Test
    void applyToolDelegatesToApplyServiceWithDryRunAndChanges() {
        AutoTuneApplyResult expected = new AutoTuneApplyResult();
        expected.setDryRun(true);
        when(autoTuneApplyService.apply(any(AutoTuneApplyRequest.class))).thenReturn(expected);
        AutoTuneChangeCommand command = new AutoTuneChangeCommand();
        command.setTargetType("sys_config");
        command.setTargetKey("conveyorStationTaskLimit");
        command.setNewValue("12");
        List<AutoTuneChangeCommand> changes = Collections.singletonList(command);

        AutoTuneApplyResult result = tools.applyAutoTuneChanges("reduce congestion", 10, "agent", true, changes);

        assertSame(expected, result);
        ArgumentCaptor<AutoTuneApplyRequest> captor = ArgumentCaptor.forClass(AutoTuneApplyRequest.class);
        verify(autoTuneApplyService).apply(captor.capture());
        assertEquals("reduce congestion", captor.getValue().getReason());
        assertEquals(10, captor.getValue().getAnalysisIntervalMinutes());
        assertEquals("agent", captor.getValue().getTriggerType());
        assertEquals(Boolean.TRUE, captor.getValue().getDryRun());
        assertSame(changes, captor.getValue().getChanges());
    }

    /** The revert tool forwards the reason to the rollback call and returns its result. */
    @Test
    void rollbackToolDelegatesToApplyServiceRollback() {
        AutoTuneApplyResult expected = new AutoTuneApplyResult();
        when(autoTuneApplyService.rollbackLastSuccessfulJob("bad result")).thenReturn(expected);

        AutoTuneApplyResult result = tools.revertLastAutoTuneJob("bad result");

        assertSame(expected, result);
        verify(autoTuneApplyService).rollbackLastSuccessfulJob("bad result");
    }

    /**
     * Scripts four consecutive LLM replies (snapshot, dry-run apply, real apply, final answer)
     * and checks that the agent executes the three tool calls in order and sums token usage.
     */
    @Test
    void agentExecutesSnapshotDryRunAndRealApplyToolSequence() {
        AutoTuneAgentServiceImpl service = new AutoTuneAgentServiceImpl(
                llmChatService,
                mcpToolManager,
                aiPromptTemplateService);
        AiPromptTemplate promptTemplate = new AiPromptTemplate();
        promptTemplate.setContent("system prompt");
        when(aiPromptTemplateService.resolvePublished("wcs_auto_tune_dispatch")).thenReturn(promptTemplate);
        when(mcpToolManager.buildOpenAiTools()).thenReturn(Collections.singletonList(Collections.singletonMap("type", "function")));
        when(mcpToolManager.callTool(any(), any(JSONObject.class))).thenReturn(Collections.singletonMap("ok", true));
        // Consecutive stubbing: each chatCompletion call returns the next scripted reply.
        when(llmChatService.chatCompletion(any(), anyDouble(), anyInt(), any()))
                .thenReturn(
                        response("read snapshot", toolCall("call_1", "wcs_local_dispatch_get_auto_tune_snapshot", "{}"), 10, 5),
                        response("dry-run first", toolCall("call_2", "wcs_local_dispatch_apply_auto_tune_changes", "{\"dryRun\":true,\"changes\":[{\"targetType\":\"sys_config\",\"targetKey\":\"conveyorStationTaskLimit\",\"newValue\":\"12\"}]}"), 11, 6),
                        response("apply after dry-run", toolCall("call_3", "wcs_local_dispatch_apply_auto_tune_changes", "{\"dryRun\":false,\"changes\":[{\"targetType\":\"sys_config\",\"targetKey\":\"conveyorStationTaskLimit\",\"newValue\":\"12\"}]}"), 12, 7),
                        response("已完成自动调参", null, 13, 8)
                );

        AutoTuneAgentService.AutoTuneAgentResult result = service.runAutoTune("scheduler");

        assertTrue(result.getSuccess());
        assertEquals("scheduler", result.getTriggerType());
        assertEquals(3, result.getToolCallCount());
        assertEquals(4, result.getLlmCallCount());
        // 10 + 11 + 12 + 13 prompt tokens, 5 + 6 + 7 + 8 completion tokens.
        assertEquals(46L, result.getPromptTokens());
        assertEquals(26L, result.getCompletionTokens());
        assertEquals(72L, result.getTotalTokens());
        assertTrue(result.getSummary().contains("已完成自动调参"));
        ArgumentCaptor<String> toolNameCaptor = ArgumentCaptor.forClass(String.class);
        ArgumentCaptor<JSONObject> argumentCaptor = ArgumentCaptor.forClass(JSONObject.class);
        verify(mcpToolManager, org.mockito.Mockito.times(3)).callTool(toolNameCaptor.capture(), argumentCaptor.capture());
        assertEquals("wcs_local_dispatch_get_auto_tune_snapshot", toolNameCaptor.getAllValues().get(0));
        assertEquals("wcs_local_dispatch_apply_auto_tune_changes", toolNameCaptor.getAllValues().get(1));
        assertEquals(Boolean.TRUE, argumentCaptor.getAllValues().get(1).getBoolean("dryRun"));
        assertEquals(Boolean.FALSE, argumentCaptor.getAllValues().get(2).getBoolean("dryRun"));
    }

    /**
     * Builds a single-choice completion response with the given assistant content, an optional
     * tool call, and usage whose total is the sum of prompt and completion tokens.
     */
    private ChatCompletionResponse response(String content,
                                            ChatCompletionRequest.ToolCall toolCall,
                                            int promptTokens,
                                            int completionTokens) {
        ChatCompletionRequest.Message message = new ChatCompletionRequest.Message();
        message.setRole("assistant");
        message.setContent(content);
        if (toolCall != null) {
            message.setTool_calls(Collections.singletonList(toolCall));
        }

        ChatCompletionResponse.Choice choice = new ChatCompletionResponse.Choice();
        choice.setIndex(0);
        choice.setMessage(message);

        ChatCompletionResponse.Usage usage = new ChatCompletionResponse.Usage();
        usage.setPromptTokens(promptTokens);
        usage.setCompletionTokens(completionTokens);
        usage.setTotalTokens(promptTokens + completionTokens);

        ChatCompletionResponse response = new ChatCompletionResponse();
        List<ChatCompletionResponse.Choice> choices = new ArrayList<>();
        choices.add(choice);
        response.setChoices(choices);
        response.setUsage(usage);
        return response;
    }

    /** Builds a {@code function}-type tool call with the given id, function name and raw JSON arguments. */
    private ChatCompletionRequest.ToolCall toolCall(String id, String name, String arguments) {
        ChatCompletionRequest.Function function = new ChatCompletionRequest.Function();
        function.setName(name);
        function.setArguments(arguments);

        ChatCompletionRequest.ToolCall toolCall = new ChatCompletionRequest.ToolCall();
        toolCall.setId(id);
        toolCall.setType("function");
        toolCall.setFunction(function);
        return toolCall;
    }
}