| | |
| | | import org.springframework.stereotype.Service; |
| | | |
| | | import java.util.ArrayList; |
| | | import java.util.LinkedHashMap; |
| | | import java.util.List; |
| | | import java.util.Map; |
| | | import java.util.Set; |
| | | |
| | | @Slf4j |
| | | @Service |
| | |
| | | private static final int MAX_TOOL_ROUNDS = 10; |
| | | private static final double TEMPERATURE = 0.2D; |
| | | private static final int MAX_TOKENS = 2048; |
| | | private static final String TOOL_GET_SNAPSHOT = "wcs_local_dispatch_get_auto_tune_snapshot"; |
| | | private static final String TOOL_GET_RECENT_JOBS = "wcs_local_dispatch_get_recent_auto_tune_jobs"; |
| | | private static final String TOOL_APPLY_CHANGES = "wcs_local_dispatch_apply_auto_tune_changes"; |
| | | private static final String TOOL_REVERT_LAST_JOB = "wcs_local_dispatch_revert_last_auto_tune_job"; |
| | | private static final Set<String> ALLOWED_TOOL_NAMES = Set.of( |
| | | TOOL_GET_SNAPSHOT, |
| | | TOOL_GET_RECENT_JOBS, |
| | | TOOL_APPLY_CHANGES, |
| | | TOOL_REVERT_LAST_JOB |
| | | ); |
| | | |
| | | private final LlmChatService llmChatService; |
| | | private final SpringAiMcpToolManager mcpToolManager; |
| | |
| | | public AutoTuneAgentResult runAutoTune(String triggerType) { |
| | | String normalizedTriggerType = normalizeTriggerType(triggerType); |
| | | UsageCounter usageCounter = new UsageCounter(); |
| | | int toolCallCount = 0; |
| | | RunState runState = new RunState(); |
| | | boolean maxRoundsReached = false; |
| | | StringBuilder summaryBuffer = new StringBuilder(); |
| | | |
| | | try { |
| | | List<Object> tools = mcpToolManager.buildOpenAiTools(); |
| | | List<Object> tools = filterAllowedTools(mcpToolManager.buildOpenAiTools()); |
| | | if (tools == null || tools.isEmpty()) { |
| | | throw new IllegalStateException("No MCP tools registered"); |
| | | throw new IllegalStateException("No auto-tune MCP tools registered"); |
| | | } |
| | | |
| | | AiPromptTemplate promptTemplate = aiPromptTemplateService.resolvePublished(AiPromptScene.AUTO_TUNE_DISPATCH.getCode()); |
| | |
| | | |
| | | List<ChatCompletionRequest.ToolCall> toolCalls = assistantMessage.getTool_calls(); |
| | | if (toolCalls == null || toolCalls.isEmpty()) { |
| | | return buildResult(true, normalizedTriggerType, summaryBuffer, toolCallCount, usageCounter, false); |
| | | return buildResult(runState.isSuccessful(), normalizedTriggerType, summaryBuffer, runState, |
| | | usageCounter, false); |
| | | } |
| | | |
| | | for (ChatCompletionRequest.ToolCall toolCall : toolCalls) { |
| | | Object toolOutput = callMountedTool(toolCall); |
| | | toolCallCount++; |
| | | Object toolOutput = callMountedTool(toolCall, runState); |
| | | messages.add(buildToolMessage(toolCall, toolOutput)); |
| | | } |
| | | } |
| | | maxRoundsReached = true; |
| | | return buildResult(true, normalizedTriggerType, summaryBuffer, toolCallCount, usageCounter, maxRoundsReached); |
| | | return buildResult(false, normalizedTriggerType, summaryBuffer, runState, usageCounter, maxRoundsReached); |
| | | } catch (Exception exception) { |
| | | log.error("Auto tune agent stopped with error", exception); |
| | | appendSummary(summaryBuffer, "自动调参 Agent 执行异常: " + exception.getMessage()); |
| | | return buildResult(false, normalizedTriggerType, summaryBuffer, toolCallCount, usageCounter, maxRoundsReached); |
| | | runState.markToolError(); |
| | | return buildResult(false, normalizedTriggerType, summaryBuffer, runState, usageCounter, maxRoundsReached); |
| | | } |
| | | } |
| | | |
| | |
| | | userMessage.setRole("user"); |
| | | userMessage.setContent("请执行一次后台 WCS 自动调参。triggerType=" + triggerType |
| | | + "。必须先调用 wcs_local_dispatch_get_auto_tune_snapshot 获取事实;如需提交变更," |
| | | + "必须先 dry-run,再根据 dry-run 结果决定是否实际应用。不要输出自由格式 JSON 供外层解析。"); |
| | | + "必须先 dry-run,再根据 dry-run 结果决定是否实际应用;实际应用时必须带上 dry-run 返回的 dryRunToken。" |
| | | + "不要输出自由格式 JSON 供外层解析。"); |
| | | messages.add(userMessage); |
| | | return messages; |
| | | } |
| | |
| | | return message; |
| | | } |
| | | |
| | | private Object callMountedTool(ChatCompletionRequest.ToolCall toolCall) { |
| | | private Object callMountedTool(ChatCompletionRequest.ToolCall toolCall, RunState runState) { |
| | | String toolName = resolveToolName(toolCall); |
| | | if (!ALLOWED_TOOL_NAMES.contains(toolName)) { |
| | | throw new IllegalArgumentException("Disallowed auto-tune MCP tool: " + toolName); |
| | | } |
| | | JSONObject arguments = parseArguments(toolCall); |
| | | try { |
| | | return mcpToolManager.callTool(toolName, arguments); |
| | | Object output = mcpToolManager.callTool(toolName, arguments); |
| | | runState.markToolSuccess(toolName); |
| | | return output; |
| | | } catch (Exception exception) { |
| | | LinkedHashMap<String, Object> error = new LinkedHashMap<>(); |
| | | error.put("tool", toolName); |
| | | error.put("error", exception.getMessage()); |
| | | return error; |
| | | throw new IllegalStateException("Auto-tune MCP tool failed: " + toolName + ", " + exception.getMessage(), |
| | | exception); |
| | | } |
| | | } |
| | | |
| | |
| | | private AutoTuneAgentResult buildResult(boolean success, |
| | | String triggerType, |
| | | StringBuilder summaryBuffer, |
| | | int toolCallCount, |
| | | RunState runState, |
| | | UsageCounter usageCounter, |
| | | boolean maxRoundsReached) { |
| | | AutoTuneAgentResult result = new AutoTuneAgentResult(); |
| | | result.setSuccess(success); |
| | | result.setTriggerType(triggerType); |
| | | result.setToolCallCount(toolCallCount); |
| | | result.setToolCallCount(runState.getToolCallCount()); |
| | | result.setLlmCallCount(usageCounter.getLlmCallCount()); |
| | | result.setPromptTokens(usageCounter.getPromptTokens()); |
| | | result.setCompletionTokens(usageCounter.getCompletionTokens()); |
| | |
| | | result.setMaxRoundsReached(maxRoundsReached); |
| | | |
| | | String summary = summaryBuffer == null ? "" : summaryBuffer.toString().trim(); |
| | | if (toolCallCount <= 0 && success) { |
| | | summary = "自动调参 Agent 未调用任何 MCP 工具,未执行调参。" + (summary.isEmpty() ? "" : "\n" + summary); |
| | | if (runState.getToolCallCount() <= 0) { |
| | | summary = "自动调参 Agent 未调用任何允许的 MCP 工具,未执行调参。" + (summary.isEmpty() ? "" : "\n" + summary); |
| | | } else if (!runState.isSnapshotCalled()) { |
| | | summary = summary + "\n自动调参 Agent 未调用快照工具,结果不完整。"; |
| | | } |
| | | if (runState.hasToolError()) { |
| | | summary = summary + "\n自动调参 Agent 存在工具调用错误,已标记为失败。"; |
| | | } |
| | | if (maxRoundsReached) { |
| | | summary = summary + "\n自动调参 Agent 达到最大工具调用轮次,已停止。"; |
| | | } |
| | | result.setSummary(summary); |
| | | return result; |
| | | } |
| | | |
| | | private List<Object> filterAllowedTools(List<Object> tools) { |
| | | List<Object> allowedTools = new ArrayList<>(); |
| | | if (tools == null || tools.isEmpty()) { |
| | | return allowedTools; |
| | | } |
| | | for (Object tool : tools) { |
| | | String toolName = resolveOpenAiToolName(tool); |
| | | if (ALLOWED_TOOL_NAMES.contains(toolName)) { |
| | | allowedTools.add(tool); |
| | | } |
| | | } |
| | | return allowedTools; |
| | | } |
| | | |
| | | private String resolveOpenAiToolName(Object tool) { |
| | | if (!(tool instanceof Map<?, ?> toolMap)) { |
| | | return null; |
| | | } |
| | | Object function = toolMap.get("function"); |
| | | if (!(function instanceof Map<?, ?> functionMap)) { |
| | | return null; |
| | | } |
| | | Object name = functionMap.get("name"); |
| | | return name == null ? null : String.valueOf(name); |
| | | } |
| | | |
| | | private void appendSummary(StringBuilder summaryBuffer, String content) { |
| | |
| | | return llmCallCount; |
| | | } |
| | | } |
| | | |
| | | private static class RunState { |
| | | private int toolCallCount; |
| | | private boolean snapshotCalled; |
| | | private boolean toolError; |
| | | |
| | | void markToolSuccess(String toolName) { |
| | | toolCallCount++; |
| | | if (TOOL_GET_SNAPSHOT.equals(toolName)) { |
| | | snapshotCalled = true; |
| | | } |
| | | } |
| | | |
| | | void markToolError() { |
| | | toolError = true; |
| | | } |
| | | |
| | | boolean isSuccessful() { |
| | | return toolCallCount > 0 && snapshotCalled && !toolError; |
| | | } |
| | | |
| | | int getToolCallCount() { |
| | | return toolCallCount; |
| | | } |
| | | |
| | | boolean isSnapshotCalled() { |
| | | return snapshotCalled; |
| | | } |
| | | |
| | | boolean hasToolError() { |
| | | return toolError; |
| | | } |
| | | } |
| | | } |