package com.zy.ai.service.impl; import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import com.zy.ai.domain.autotune.AutoTuneTriggerType; import com.zy.ai.entity.AiPromptTemplate; import com.zy.ai.entity.ChatCompletionRequest; import com.zy.ai.entity.ChatCompletionResponse; import com.zy.ai.enums.AiPromptScene; import com.zy.ai.mcp.service.SpringAiMcpToolManager; import com.zy.ai.service.AiPromptTemplateService; import com.zy.ai.service.AutoTuneAgentService; import com.zy.ai.service.LlmChatService; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Set; @Slf4j @Service @RequiredArgsConstructor public class AutoTuneAgentServiceImpl implements AutoTuneAgentService { private static final int MAX_TOOL_ROUNDS = 10; private static final double TEMPERATURE = 0.2D; private static final int MAX_TOKENS = 2048; private static final String TOOL_GET_SNAPSHOT = "wcs_local_dispatch_get_auto_tune_snapshot"; private static final String TOOL_GET_RECENT_JOBS = "wcs_local_dispatch_get_recent_auto_tune_jobs"; private static final String TOOL_APPLY_CHANGES = "wcs_local_dispatch_apply_auto_tune_changes"; private static final String TOOL_REVERT_LAST_JOB = "wcs_local_dispatch_revert_last_auto_tune_job"; private static final Set ALLOWED_TOOL_NAMES = Set.of( TOOL_GET_SNAPSHOT, TOOL_GET_RECENT_JOBS, TOOL_APPLY_CHANGES, TOOL_REVERT_LAST_JOB ); private final LlmChatService llmChatService; private final SpringAiMcpToolManager mcpToolManager; private final AiPromptTemplateService aiPromptTemplateService; @Override public AutoTuneAgentResult runAutoTune(String triggerType) { String normalizedTriggerType = normalizeTriggerType(triggerType); UsageCounter usageCounter = new UsageCounter(); RunState runState = new RunState(); boolean maxRoundsReached = false; StringBuilder summaryBuffer = new StringBuilder(); try { List tools = filterAllowedTools(mcpToolManager.buildOpenAiTools()); if (tools == null || tools.isEmpty()) { throw new IllegalStateException("No auto-tune MCP tools registered"); } AiPromptTemplate promptTemplate = aiPromptTemplateService.resolvePublished(AiPromptScene.AUTO_TUNE_DISPATCH.getCode()); List messages = buildMessages(promptTemplate, normalizedTriggerType); for (int round = 0; round < MAX_TOOL_ROUNDS; round++) { ChatCompletionResponse response = llmChatService.chatCompletion(messages, TEMPERATURE, MAX_TOKENS, tools); ChatCompletionRequest.Message assistantMessage = extractAssistantMessage(response); usageCounter.add(response.getUsage()); messages.add(assistantMessage); appendSummary(summaryBuffer, assistantMessage.getContent()); List toolCalls = assistantMessage.getTool_calls(); if (toolCalls == null || toolCalls.isEmpty()) { return buildResult(runState.isSuccessful(), normalizedTriggerType, summaryBuffer, runState, usageCounter, false); } for (ChatCompletionRequest.ToolCall toolCall : toolCalls) { Object toolOutput = callMountedTool(toolCall, runState, normalizedTriggerType); messages.add(buildToolMessage(toolCall, toolOutput)); } } maxRoundsReached = true; return buildResult(false, normalizedTriggerType, summaryBuffer, runState, usageCounter, maxRoundsReached); } catch (Exception exception) { log.error("Auto tune agent stopped with error", exception); appendSummary(summaryBuffer, "自动调参 Agent 执行异常: " + exception.getMessage()); runState.markToolError(); return buildResult(false, normalizedTriggerType, summaryBuffer, runState, usageCounter, maxRoundsReached); } } private List buildMessages(AiPromptTemplate promptTemplate, String triggerType) { List messages = new ArrayList<>(); ChatCompletionRequest.Message systemMessage = new ChatCompletionRequest.Message(); systemMessage.setRole("system"); systemMessage.setContent(promptTemplate == null ? "" : promptTemplate.getContent()); messages.add(systemMessage); ChatCompletionRequest.Message userMessage = new ChatCompletionRequest.Message(); userMessage.setRole("user"); userMessage.setContent("请执行一次后台 WCS 自动调参。triggerType=" + triggerType + "。必须先调用 wcs_local_dispatch_get_auto_tune_snapshot 获取事实;如需提交变更," + "必须先 dry-run,再根据 dry-run 结果决定是否实际应用;实际应用时必须带上 dry-run 返回的 dryRunToken。" + "不要输出自由格式 JSON 供外层解析。"); messages.add(userMessage); return messages; } private ChatCompletionRequest.Message extractAssistantMessage(ChatCompletionResponse response) { if (response == null || response.getChoices() == null || response.getChoices().isEmpty()) { throw new IllegalStateException("LLM returned empty response"); } ChatCompletionRequest.Message message = response.getChoices().get(0).getMessage(); if (message == null) { throw new IllegalStateException("LLM returned empty message"); } return message; } private Object callMountedTool(ChatCompletionRequest.ToolCall toolCall, RunState runState, String triggerType) { String toolName = resolveToolName(toolCall); if (!ALLOWED_TOOL_NAMES.contains(toolName)) { throw new IllegalArgumentException("Disallowed auto-tune MCP tool: " + toolName); } JSONObject arguments = parseArguments(toolCall); applySchedulerTriggerType(toolName, triggerType, arguments); try { Object output = mcpToolManager.callTool(toolName, arguments); runState.markToolSuccess(toolName); return output; } catch (Exception exception) { throw new IllegalStateException("Auto-tune MCP tool failed: " + toolName + ", " + exception.getMessage(), exception); } } private void applySchedulerTriggerType(String toolName, String triggerType, JSONObject arguments) { if (!TOOL_APPLY_CHANGES.equals(toolName)) { return; } if (!AutoTuneTriggerType.AUTO.getCode().equals(triggerType)) { return; } arguments.put("triggerType", AutoTuneTriggerType.AUTO.getCode()); } private ChatCompletionRequest.Message buildToolMessage(ChatCompletionRequest.ToolCall toolCall, Object toolOutput) { ChatCompletionRequest.Message toolMessage = new ChatCompletionRequest.Message(); toolMessage.setRole("tool"); toolMessage.setTool_call_id(toolCall == null ? null : toolCall.getId()); toolMessage.setContent(JSON.toJSONString(toolOutput)); return toolMessage; } private String resolveToolName(ChatCompletionRequest.ToolCall toolCall) { if (toolCall == null || toolCall.getFunction() == null || isBlank(toolCall.getFunction().getName())) { throw new IllegalArgumentException("missing tool name"); } return toolCall.getFunction().getName(); } private JSONObject parseArguments(ChatCompletionRequest.ToolCall toolCall) { String rawArguments = toolCall == null || toolCall.getFunction() == null ? null : toolCall.getFunction().getArguments(); if (isBlank(rawArguments)) { return new JSONObject(); } try { return JSON.parseObject(rawArguments); } catch (Exception exception) { JSONObject arguments = new JSONObject(); arguments.put("_raw", rawArguments); return arguments; } } private AutoTuneAgentResult buildResult(boolean success, String triggerType, StringBuilder summaryBuffer, RunState runState, UsageCounter usageCounter, boolean maxRoundsReached) { AutoTuneAgentResult result = new AutoTuneAgentResult(); result.setSuccess(success); result.setTriggerType(triggerType); result.setToolCallCount(runState.getToolCallCount()); result.setLlmCallCount(usageCounter.getLlmCallCount()); result.setPromptTokens(usageCounter.getPromptTokens()); result.setCompletionTokens(usageCounter.getCompletionTokens()); result.setTotalTokens(usageCounter.getTotalTokens()); result.setMaxRoundsReached(maxRoundsReached); String summary = summaryBuffer == null ? "" : summaryBuffer.toString().trim(); if (runState.getToolCallCount() <= 0) { summary = "自动调参 Agent 未调用任何允许的 MCP 工具,未执行调参。" + (summary.isEmpty() ? "" : "\n" + summary); } else if (!runState.isSnapshotCalled()) { summary = summary + "\n自动调参 Agent 未调用快照工具,结果不完整。"; } if (runState.hasToolError()) { summary = summary + "\n自动调参 Agent 存在工具调用错误,已标记为失败。"; } if (maxRoundsReached) { summary = summary + "\n自动调参 Agent 达到最大工具调用轮次,已停止。"; } result.setSummary(summary); return result; } private List filterAllowedTools(List tools) { List allowedTools = new ArrayList<>(); if (tools == null || tools.isEmpty()) { return allowedTools; } for (Object tool : tools) { String toolName = resolveOpenAiToolName(tool); if (ALLOWED_TOOL_NAMES.contains(toolName)) { allowedTools.add(tool); } } return allowedTools; } private String resolveOpenAiToolName(Object tool) { if (!(tool instanceof Map toolMap)) { return null; } Object function = toolMap.get("function"); if (!(function instanceof Map functionMap)) { return null; } Object name = functionMap.get("name"); return name == null ? null : String.valueOf(name); } private void appendSummary(StringBuilder summaryBuffer, String content) { if (summaryBuffer == null || isBlank(content)) { return; } if (summaryBuffer.length() > 0) { summaryBuffer.append('\n'); } summaryBuffer.append(content.trim()); } private String normalizeTriggerType(String triggerType) { return isBlank(triggerType) ? "agent" : triggerType.trim(); } private boolean isBlank(String value) { return value == null || value.trim().isEmpty(); } private static class UsageCounter { private long promptTokens; private long completionTokens; private long totalTokens; private int llmCallCount; void add(ChatCompletionResponse.Usage usage) { llmCallCount++; if (usage == null) { return; } promptTokens += usage.getPromptTokens() == null ? 0L : usage.getPromptTokens(); completionTokens += usage.getCompletionTokens() == null ? 0L : usage.getCompletionTokens(); totalTokens += usage.getTotalTokens() == null ? 0L : usage.getTotalTokens(); } long getPromptTokens() { return promptTokens; } long getCompletionTokens() { return completionTokens; } long getTotalTokens() { return totalTokens; } int getLlmCallCount() { return llmCallCount; } } private static class RunState { private int toolCallCount; private boolean snapshotCalled; private boolean toolError; void markToolSuccess(String toolName) { toolCallCount++; if (TOOL_GET_SNAPSHOT.equals(toolName)) { snapshotCalled = true; } } void markToolError() { toolError = true; } boolean isSuccessful() { return toolCallCount > 0 && snapshotCalled && !toolError; } int getToolCallCount() { return toolCallCount; } boolean isSnapshotCalled() { return snapshotCalled; } boolean hasToolError() { return toolError; } } }