Junjie
2026-04-27 e4e91b46d0ce781e7dc87dcdf0d2909b01911d4b
src/main/java/com/zy/ai/service/impl/AutoTuneAgentServiceImpl.java
@@ -2,6 +2,7 @@
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.zy.ai.domain.autotune.AutoTuneTriggerType;
import com.zy.ai.entity.AiPromptTemplate;
import com.zy.ai.entity.ChatCompletionRequest;
import com.zy.ai.entity.ChatCompletionResponse;
@@ -15,8 +16,9 @@
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
@Slf4j
@Service
@@ -26,6 +28,16 @@
    // Tuning constants for the agent loop and LLM request.
    // NOTE(review): the code that consumes these three values is outside this view —
    // presumably the round cap of the tool loop and the sampling settings of the chat request; confirm.
    private static final int MAX_TOOL_ROUNDS = 10;
    private static final double TEMPERATURE = 0.2D;
    private static final int MAX_TOKENS = 2048;
    // Names of the MCP tools the auto-tune agent is permitted to use.
    private static final String TOOL_GET_SNAPSHOT = "wcs_local_dispatch_get_auto_tune_snapshot";
    private static final String TOOL_GET_RECENT_JOBS = "wcs_local_dispatch_get_recent_auto_tune_jobs";
    private static final String TOOL_APPLY_CHANGES = "wcs_local_dispatch_apply_auto_tune_changes";
    private static final String TOOL_REVERT_LAST_JOB = "wcs_local_dispatch_revert_last_auto_tune_job";
    // Allow-list consulted both when filtering the tools offered to the model and when
    // validating each tool call before it is executed. Set.of(...) produces an unmodifiable,
    // null-hostile set: looking up a null name throws NullPointerException.
    private static final Set<String> ALLOWED_TOOL_NAMES = Set.of(
            TOOL_GET_SNAPSHOT,
            TOOL_GET_RECENT_JOBS,
            TOOL_APPLY_CHANGES,
            TOOL_REVERT_LAST_JOB
    );
    private final LlmChatService llmChatService;
    private final SpringAiMcpToolManager mcpToolManager;
@@ -35,14 +47,14 @@
    public AutoTuneAgentResult runAutoTune(String triggerType) {
        String normalizedTriggerType = normalizeTriggerType(triggerType);
        UsageCounter usageCounter = new UsageCounter();
        int toolCallCount = 0;
        RunState runState = new RunState();
        boolean maxRoundsReached = false;
        StringBuilder summaryBuffer = new StringBuilder();
        try {
            List<Object> tools = mcpToolManager.buildOpenAiTools();
            List<Object> tools = filterAllowedTools(mcpToolManager.buildOpenAiTools());
            if (tools == null || tools.isEmpty()) {
                throw new IllegalStateException("No MCP tools registered");
                throw new IllegalStateException("No auto-tune MCP tools registered");
            }
            AiPromptTemplate promptTemplate = aiPromptTemplateService.resolvePublished(AiPromptScene.AUTO_TUNE_DISPATCH.getCode());
@@ -57,21 +69,22 @@
                List<ChatCompletionRequest.ToolCall> toolCalls = assistantMessage.getTool_calls();
                if (toolCalls == null || toolCalls.isEmpty()) {
                    return buildResult(true, normalizedTriggerType, summaryBuffer, toolCallCount, usageCounter, false);
                    return buildResult(runState.isSuccessful(), normalizedTriggerType, summaryBuffer, runState,
                            usageCounter, false);
                }
                for (ChatCompletionRequest.ToolCall toolCall : toolCalls) {
                    Object toolOutput = callMountedTool(toolCall);
                    toolCallCount++;
                    Object toolOutput = callMountedTool(toolCall, runState, normalizedTriggerType);
                    messages.add(buildToolMessage(toolCall, toolOutput));
                }
            }
            maxRoundsReached = true;
            return buildResult(true, normalizedTriggerType, summaryBuffer, toolCallCount, usageCounter, maxRoundsReached);
            return buildResult(false, normalizedTriggerType, summaryBuffer, runState, usageCounter, maxRoundsReached);
        } catch (Exception exception) {
            log.error("Auto tune agent stopped with error", exception);
            appendSummary(summaryBuffer, "自动调参 Agent 执行异常: " + exception.getMessage());
            return buildResult(false, normalizedTriggerType, summaryBuffer, toolCallCount, usageCounter, maxRoundsReached);
            runState.markToolError();
            return buildResult(false, normalizedTriggerType, summaryBuffer, runState, usageCounter, maxRoundsReached);
        }
    }
@@ -87,7 +100,8 @@
        userMessage.setRole("user");
        userMessage.setContent("请执行一次后台 WCS 自动调参。triggerType=" + triggerType
                + "。必须先调用 wcs_local_dispatch_get_auto_tune_snapshot 获取事实;如需提交变更,"
                + "必须先 dry-run,再根据 dry-run 结果决定是否实际应用。不要输出自由格式 JSON 供外层解析。");
                + "必须先 dry-run,再根据 dry-run 结果决定是否实际应用;实际应用时必须带上 dry-run 返回的 dryRunToken。"
                + "不要输出自由格式 JSON 供外层解析。");
        messages.add(userMessage);
        return messages;
    }
@@ -103,17 +117,31 @@
        return message;
    }
    private Object callMountedTool(ChatCompletionRequest.ToolCall toolCall) {
    private Object callMountedTool(ChatCompletionRequest.ToolCall toolCall, RunState runState, String triggerType) {
        String toolName = resolveToolName(toolCall);
        JSONObject arguments = parseArguments(toolCall);
        try {
            return mcpToolManager.callTool(toolName, arguments);
        } catch (Exception exception) {
            LinkedHashMap<String, Object> error = new LinkedHashMap<>();
            error.put("tool", toolName);
            error.put("error", exception.getMessage());
            return error;
        if (!ALLOWED_TOOL_NAMES.contains(toolName)) {
            throw new IllegalArgumentException("Disallowed auto-tune MCP tool: " + toolName);
        }
        JSONObject arguments = parseArguments(toolCall);
        applySchedulerTriggerType(toolName, triggerType, arguments);
        try {
            Object output = mcpToolManager.callTool(toolName, arguments);
            runState.markToolSuccess(toolName);
            return output;
        } catch (Exception exception) {
            throw new IllegalStateException("Auto-tune MCP tool failed: " + toolName + ", " + exception.getMessage(),
                    exception);
        }
    }
    private void applySchedulerTriggerType(String toolName, String triggerType, JSONObject arguments) {
        if (!TOOL_APPLY_CHANGES.equals(toolName)) {
            return;
        }
        if (!AutoTuneTriggerType.AUTO.getCode().equals(triggerType)) {
            return;
        }
        arguments.put("triggerType", AutoTuneTriggerType.AUTO.getCode());
    }
    private ChatCompletionRequest.Message buildToolMessage(ChatCompletionRequest.ToolCall toolCall, Object toolOutput) {
@@ -150,13 +178,13 @@
    private AutoTuneAgentResult buildResult(boolean success,
                                            String triggerType,
                                            StringBuilder summaryBuffer,
                                            int toolCallCount,
                                            RunState runState,
                                            UsageCounter usageCounter,
                                            boolean maxRoundsReached) {
        AutoTuneAgentResult result = new AutoTuneAgentResult();
        result.setSuccess(success);
        result.setTriggerType(triggerType);
        result.setToolCallCount(toolCallCount);
        result.setToolCallCount(runState.getToolCallCount());
        result.setLlmCallCount(usageCounter.getLlmCallCount());
        result.setPromptTokens(usageCounter.getPromptTokens());
        result.setCompletionTokens(usageCounter.getCompletionTokens());
@@ -164,14 +192,45 @@
        result.setMaxRoundsReached(maxRoundsReached);
        String summary = summaryBuffer == null ? "" : summaryBuffer.toString().trim();
        if (toolCallCount <= 0 && success) {
            summary = "自动调参 Agent 未调用任何 MCP 工具,未执行调参。" + (summary.isEmpty() ? "" : "\n" + summary);
        if (runState.getToolCallCount() <= 0) {
            summary = "自动调参 Agent 未调用任何允许的 MCP 工具,未执行调参。" + (summary.isEmpty() ? "" : "\n" + summary);
        } else if (!runState.isSnapshotCalled()) {
            summary = summary + "\n自动调参 Agent 未调用快照工具,结果不完整。";
        }
        if (runState.hasToolError()) {
            summary = summary + "\n自动调参 Agent 存在工具调用错误,已标记为失败。";
        }
        if (maxRoundsReached) {
            summary = summary + "\n自动调参 Agent 达到最大工具调用轮次,已停止。";
        }
        result.setSummary(summary);
        return result;
    }
    private List<Object> filterAllowedTools(List<Object> tools) {
        List<Object> allowedTools = new ArrayList<>();
        if (tools == null || tools.isEmpty()) {
            return allowedTools;
        }
        for (Object tool : tools) {
            String toolName = resolveOpenAiToolName(tool);
            if (ALLOWED_TOOL_NAMES.contains(toolName)) {
                allowedTools.add(tool);
            }
        }
        return allowedTools;
    }
    private String resolveOpenAiToolName(Object tool) {
        if (!(tool instanceof Map<?, ?> toolMap)) {
            return null;
        }
        Object function = toolMap.get("function");
        if (!(function instanceof Map<?, ?> functionMap)) {
            return null;
        }
        Object name = functionMap.get("name");
        return name == null ? null : String.valueOf(name);
    }
    private void appendSummary(StringBuilder summaryBuffer, String content) {
@@ -224,4 +283,37 @@
            return llmCallCount;
        }
    }
    private static class RunState {
        private int toolCallCount;
        private boolean snapshotCalled;
        private boolean toolError;
        void markToolSuccess(String toolName) {
            toolCallCount++;
            if (TOOL_GET_SNAPSHOT.equals(toolName)) {
                snapshotCalled = true;
            }
        }
        void markToolError() {
            toolError = true;
        }
        boolean isSuccessful() {
            return toolCallCount > 0 && snapshotCalled && !toolError;
        }
        int getToolCallCount() {
            return toolCallCount;
        }
        boolean isSnapshotCalled() {
            return snapshotCalled;
        }
        boolean hasToolError() {
            return toolError;
        }
    }
}