package com.zy.ai.service.impl; import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import com.zy.ai.entity.AiPromptTemplate; import com.zy.ai.entity.ChatCompletionRequest; import com.zy.ai.entity.ChatCompletionResponse; import com.zy.ai.enums.AiPromptScene; import com.zy.ai.mcp.service.SpringAiMcpToolManager; import com.zy.ai.service.AiPromptTemplateService; import com.zy.ai.service.AutoTuneAgentService; import com.zy.ai.service.LlmChatService; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; @Slf4j @Service @RequiredArgsConstructor public class AutoTuneAgentServiceImpl implements AutoTuneAgentService { private static final int MAX_TOOL_ROUNDS = 10; private static final double TEMPERATURE = 0.2D; private static final int MAX_TOKENS = 2048; private final LlmChatService llmChatService; private final SpringAiMcpToolManager mcpToolManager; private final AiPromptTemplateService aiPromptTemplateService; @Override public AutoTuneAgentResult runAutoTune(String triggerType) { String normalizedTriggerType = normalizeTriggerType(triggerType); UsageCounter usageCounter = new UsageCounter(); int toolCallCount = 0; boolean maxRoundsReached = false; StringBuilder summaryBuffer = new StringBuilder(); try { List tools = mcpToolManager.buildOpenAiTools(); if (tools == null || tools.isEmpty()) { throw new IllegalStateException("No MCP tools registered"); } AiPromptTemplate promptTemplate = aiPromptTemplateService.resolvePublished(AiPromptScene.AUTO_TUNE_DISPATCH.getCode()); List messages = buildMessages(promptTemplate, normalizedTriggerType); for (int round = 0; round < MAX_TOOL_ROUNDS; round++) { ChatCompletionResponse response = llmChatService.chatCompletion(messages, TEMPERATURE, MAX_TOKENS, tools); ChatCompletionRequest.Message assistantMessage = extractAssistantMessage(response); usageCounter.add(response.getUsage()); messages.add(assistantMessage); appendSummary(summaryBuffer, assistantMessage.getContent()); List toolCalls = assistantMessage.getTool_calls(); if (toolCalls == null || toolCalls.isEmpty()) { return buildResult(true, normalizedTriggerType, summaryBuffer, toolCallCount, usageCounter, false); } for (ChatCompletionRequest.ToolCall toolCall : toolCalls) { Object toolOutput = callMountedTool(toolCall); toolCallCount++; messages.add(buildToolMessage(toolCall, toolOutput)); } } maxRoundsReached = true; return buildResult(true, normalizedTriggerType, summaryBuffer, toolCallCount, usageCounter, maxRoundsReached); } catch (Exception exception) { log.error("Auto tune agent stopped with error", exception); appendSummary(summaryBuffer, "自动调参 Agent 执行异常: " + exception.getMessage()); return buildResult(false, normalizedTriggerType, summaryBuffer, toolCallCount, usageCounter, maxRoundsReached); } } private List buildMessages(AiPromptTemplate promptTemplate, String triggerType) { List messages = new ArrayList<>(); ChatCompletionRequest.Message systemMessage = new ChatCompletionRequest.Message(); systemMessage.setRole("system"); systemMessage.setContent(promptTemplate == null ? "" : promptTemplate.getContent()); messages.add(systemMessage); ChatCompletionRequest.Message userMessage = new ChatCompletionRequest.Message(); userMessage.setRole("user"); userMessage.setContent("请执行一次后台 WCS 自动调参。triggerType=" + triggerType + "。必须先调用 wcs_local_dispatch_get_auto_tune_snapshot 获取事实;如需提交变更," + "必须先 dry-run,再根据 dry-run 结果决定是否实际应用。不要输出自由格式 JSON 供外层解析。"); messages.add(userMessage); return messages; } private ChatCompletionRequest.Message extractAssistantMessage(ChatCompletionResponse response) { if (response == null || response.getChoices() == null || response.getChoices().isEmpty()) { throw new IllegalStateException("LLM returned empty response"); } ChatCompletionRequest.Message message = response.getChoices().get(0).getMessage(); if (message == null) { throw new IllegalStateException("LLM returned empty message"); } return message; } private Object callMountedTool(ChatCompletionRequest.ToolCall toolCall) { String toolName = resolveToolName(toolCall); JSONObject arguments = parseArguments(toolCall); try { return mcpToolManager.callTool(toolName, arguments); } catch (Exception exception) { LinkedHashMap error = new LinkedHashMap<>(); error.put("tool", toolName); error.put("error", exception.getMessage()); return error; } } private ChatCompletionRequest.Message buildToolMessage(ChatCompletionRequest.ToolCall toolCall, Object toolOutput) { ChatCompletionRequest.Message toolMessage = new ChatCompletionRequest.Message(); toolMessage.setRole("tool"); toolMessage.setTool_call_id(toolCall == null ? null : toolCall.getId()); toolMessage.setContent(JSON.toJSONString(toolOutput)); return toolMessage; } private String resolveToolName(ChatCompletionRequest.ToolCall toolCall) { if (toolCall == null || toolCall.getFunction() == null || isBlank(toolCall.getFunction().getName())) { throw new IllegalArgumentException("missing tool name"); } return toolCall.getFunction().getName(); } private JSONObject parseArguments(ChatCompletionRequest.ToolCall toolCall) { String rawArguments = toolCall == null || toolCall.getFunction() == null ? null : toolCall.getFunction().getArguments(); if (isBlank(rawArguments)) { return new JSONObject(); } try { return JSON.parseObject(rawArguments); } catch (Exception exception) { JSONObject arguments = new JSONObject(); arguments.put("_raw", rawArguments); return arguments; } } private AutoTuneAgentResult buildResult(boolean success, String triggerType, StringBuilder summaryBuffer, int toolCallCount, UsageCounter usageCounter, boolean maxRoundsReached) { AutoTuneAgentResult result = new AutoTuneAgentResult(); result.setSuccess(success); result.setTriggerType(triggerType); result.setToolCallCount(toolCallCount); result.setLlmCallCount(usageCounter.getLlmCallCount()); result.setPromptTokens(usageCounter.getPromptTokens()); result.setCompletionTokens(usageCounter.getCompletionTokens()); result.setTotalTokens(usageCounter.getTotalTokens()); result.setMaxRoundsReached(maxRoundsReached); String summary = summaryBuffer == null ? "" : summaryBuffer.toString().trim(); if (toolCallCount <= 0 && success) { summary = "自动调参 Agent 未调用任何 MCP 工具,未执行调参。" + (summary.isEmpty() ? "" : "\n" + summary); } if (maxRoundsReached) { summary = summary + "\n自动调参 Agent 达到最大工具调用轮次,已停止。"; } result.setSummary(summary); return result; } private void appendSummary(StringBuilder summaryBuffer, String content) { if (summaryBuffer == null || isBlank(content)) { return; } if (summaryBuffer.length() > 0) { summaryBuffer.append('\n'); } summaryBuffer.append(content.trim()); } private String normalizeTriggerType(String triggerType) { return isBlank(triggerType) ? "agent" : triggerType.trim(); } private boolean isBlank(String value) { return value == null || value.trim().isEmpty(); } private static class UsageCounter { private long promptTokens; private long completionTokens; private long totalTokens; private int llmCallCount; void add(ChatCompletionResponse.Usage usage) { llmCallCount++; if (usage == null) { return; } promptTokens += usage.getPromptTokens() == null ? 0L : usage.getPromptTokens(); completionTokens += usage.getCompletionTokens() == null ? 0L : usage.getCompletionTokens(); totalTokens += usage.getTotalTokens() == null ? 0L : usage.getTotalTokens(); } long getPromptTokens() { return promptTokens; } long getCompletionTokens() { return completionTokens; } long getTotalTokens() { return totalTokens; } int getLlmCallCount() { return llmCallCount; } } }