package com.zy.ai.service.impl;
|
|
import com.alibaba.fastjson.JSON;
|
import com.alibaba.fastjson.JSONObject;
|
import com.zy.ai.entity.AiPromptTemplate;
|
import com.zy.ai.entity.ChatCompletionRequest;
|
import com.zy.ai.entity.ChatCompletionResponse;
|
import com.zy.ai.enums.AiPromptScene;
|
import com.zy.ai.mcp.service.SpringAiMcpToolManager;
|
import com.zy.ai.service.AiPromptTemplateService;
|
import com.zy.ai.service.AutoTuneAgentService;
|
import com.zy.ai.service.LlmChatService;
|
import lombok.RequiredArgsConstructor;
|
import lombok.extern.slf4j.Slf4j;
|
import org.springframework.stereotype.Service;
|
|
import java.util.ArrayList;
|
import java.util.List;
|
import java.util.Map;
|
import java.util.Set;
|
|
@Slf4j
|
@Service
|
@RequiredArgsConstructor
|
public class AutoTuneAgentServiceImpl implements AutoTuneAgentService {
|
|
private static final int MAX_TOOL_ROUNDS = 10;
|
private static final double TEMPERATURE = 0.2D;
|
private static final int MAX_TOKENS = 2048;
|
private static final String TOOL_GET_SNAPSHOT = "wcs_local_dispatch_get_auto_tune_snapshot";
|
private static final String TOOL_GET_RECENT_JOBS = "wcs_local_dispatch_get_recent_auto_tune_jobs";
|
private static final String TOOL_APPLY_CHANGES = "wcs_local_dispatch_apply_auto_tune_changes";
|
private static final String TOOL_REVERT_LAST_JOB = "wcs_local_dispatch_revert_last_auto_tune_job";
|
private static final Set<String> ALLOWED_TOOL_NAMES = Set.of(
|
TOOL_GET_SNAPSHOT,
|
TOOL_GET_RECENT_JOBS,
|
TOOL_APPLY_CHANGES,
|
TOOL_REVERT_LAST_JOB
|
);
|
|
private final LlmChatService llmChatService;
|
private final SpringAiMcpToolManager mcpToolManager;
|
private final AiPromptTemplateService aiPromptTemplateService;
|
|
@Override
|
public AutoTuneAgentResult runAutoTune(String triggerType) {
|
String normalizedTriggerType = normalizeTriggerType(triggerType);
|
UsageCounter usageCounter = new UsageCounter();
|
RunState runState = new RunState();
|
boolean maxRoundsReached = false;
|
StringBuilder summaryBuffer = new StringBuilder();
|
|
try {
|
List<Object> tools = filterAllowedTools(mcpToolManager.buildOpenAiTools());
|
if (tools == null || tools.isEmpty()) {
|
throw new IllegalStateException("No auto-tune MCP tools registered");
|
}
|
|
AiPromptTemplate promptTemplate = aiPromptTemplateService.resolvePublished(AiPromptScene.AUTO_TUNE_DISPATCH.getCode());
|
List<ChatCompletionRequest.Message> messages = buildMessages(promptTemplate, normalizedTriggerType);
|
|
for (int round = 0; round < MAX_TOOL_ROUNDS; round++) {
|
ChatCompletionResponse response = llmChatService.chatCompletion(messages, TEMPERATURE, MAX_TOKENS, tools);
|
ChatCompletionRequest.Message assistantMessage = extractAssistantMessage(response);
|
usageCounter.add(response.getUsage());
|
messages.add(assistantMessage);
|
appendSummary(summaryBuffer, assistantMessage.getContent());
|
|
List<ChatCompletionRequest.ToolCall> toolCalls = assistantMessage.getTool_calls();
|
if (toolCalls == null || toolCalls.isEmpty()) {
|
return buildResult(runState.isSuccessful(), normalizedTriggerType, summaryBuffer, runState,
|
usageCounter, false);
|
}
|
|
for (ChatCompletionRequest.ToolCall toolCall : toolCalls) {
|
Object toolOutput = callMountedTool(toolCall, runState);
|
messages.add(buildToolMessage(toolCall, toolOutput));
|
}
|
}
|
maxRoundsReached = true;
|
return buildResult(false, normalizedTriggerType, summaryBuffer, runState, usageCounter, maxRoundsReached);
|
} catch (Exception exception) {
|
log.error("Auto tune agent stopped with error", exception);
|
appendSummary(summaryBuffer, "自动调参 Agent 执行异常: " + exception.getMessage());
|
runState.markToolError();
|
return buildResult(false, normalizedTriggerType, summaryBuffer, runState, usageCounter, maxRoundsReached);
|
}
|
}
|
|
private List<ChatCompletionRequest.Message> buildMessages(AiPromptTemplate promptTemplate, String triggerType) {
|
List<ChatCompletionRequest.Message> messages = new ArrayList<>();
|
|
ChatCompletionRequest.Message systemMessage = new ChatCompletionRequest.Message();
|
systemMessage.setRole("system");
|
systemMessage.setContent(promptTemplate == null ? "" : promptTemplate.getContent());
|
messages.add(systemMessage);
|
|
ChatCompletionRequest.Message userMessage = new ChatCompletionRequest.Message();
|
userMessage.setRole("user");
|
userMessage.setContent("请执行一次后台 WCS 自动调参。triggerType=" + triggerType
|
+ "。必须先调用 wcs_local_dispatch_get_auto_tune_snapshot 获取事实;如需提交变更,"
|
+ "必须先 dry-run,再根据 dry-run 结果决定是否实际应用;实际应用时必须带上 dry-run 返回的 dryRunToken。"
|
+ "不要输出自由格式 JSON 供外层解析。");
|
messages.add(userMessage);
|
return messages;
|
}
|
|
private ChatCompletionRequest.Message extractAssistantMessage(ChatCompletionResponse response) {
|
if (response == null || response.getChoices() == null || response.getChoices().isEmpty()) {
|
throw new IllegalStateException("LLM returned empty response");
|
}
|
ChatCompletionRequest.Message message = response.getChoices().get(0).getMessage();
|
if (message == null) {
|
throw new IllegalStateException("LLM returned empty message");
|
}
|
return message;
|
}
|
|
private Object callMountedTool(ChatCompletionRequest.ToolCall toolCall, RunState runState) {
|
String toolName = resolveToolName(toolCall);
|
if (!ALLOWED_TOOL_NAMES.contains(toolName)) {
|
throw new IllegalArgumentException("Disallowed auto-tune MCP tool: " + toolName);
|
}
|
JSONObject arguments = parseArguments(toolCall);
|
try {
|
Object output = mcpToolManager.callTool(toolName, arguments);
|
runState.markToolSuccess(toolName);
|
return output;
|
} catch (Exception exception) {
|
throw new IllegalStateException("Auto-tune MCP tool failed: " + toolName + ", " + exception.getMessage(),
|
exception);
|
}
|
}
|
|
private ChatCompletionRequest.Message buildToolMessage(ChatCompletionRequest.ToolCall toolCall, Object toolOutput) {
|
ChatCompletionRequest.Message toolMessage = new ChatCompletionRequest.Message();
|
toolMessage.setRole("tool");
|
toolMessage.setTool_call_id(toolCall == null ? null : toolCall.getId());
|
toolMessage.setContent(JSON.toJSONString(toolOutput));
|
return toolMessage;
|
}
|
|
private String resolveToolName(ChatCompletionRequest.ToolCall toolCall) {
|
if (toolCall == null || toolCall.getFunction() == null || isBlank(toolCall.getFunction().getName())) {
|
throw new IllegalArgumentException("missing tool name");
|
}
|
return toolCall.getFunction().getName();
|
}
|
|
private JSONObject parseArguments(ChatCompletionRequest.ToolCall toolCall) {
|
String rawArguments = toolCall == null || toolCall.getFunction() == null
|
? null
|
: toolCall.getFunction().getArguments();
|
if (isBlank(rawArguments)) {
|
return new JSONObject();
|
}
|
try {
|
return JSON.parseObject(rawArguments);
|
} catch (Exception exception) {
|
JSONObject arguments = new JSONObject();
|
arguments.put("_raw", rawArguments);
|
return arguments;
|
}
|
}
|
|
private AutoTuneAgentResult buildResult(boolean success,
|
String triggerType,
|
StringBuilder summaryBuffer,
|
RunState runState,
|
UsageCounter usageCounter,
|
boolean maxRoundsReached) {
|
AutoTuneAgentResult result = new AutoTuneAgentResult();
|
result.setSuccess(success);
|
result.setTriggerType(triggerType);
|
result.setToolCallCount(runState.getToolCallCount());
|
result.setLlmCallCount(usageCounter.getLlmCallCount());
|
result.setPromptTokens(usageCounter.getPromptTokens());
|
result.setCompletionTokens(usageCounter.getCompletionTokens());
|
result.setTotalTokens(usageCounter.getTotalTokens());
|
result.setMaxRoundsReached(maxRoundsReached);
|
|
String summary = summaryBuffer == null ? "" : summaryBuffer.toString().trim();
|
if (runState.getToolCallCount() <= 0) {
|
summary = "自动调参 Agent 未调用任何允许的 MCP 工具,未执行调参。" + (summary.isEmpty() ? "" : "\n" + summary);
|
} else if (!runState.isSnapshotCalled()) {
|
summary = summary + "\n自动调参 Agent 未调用快照工具,结果不完整。";
|
}
|
if (runState.hasToolError()) {
|
summary = summary + "\n自动调参 Agent 存在工具调用错误,已标记为失败。";
|
}
|
if (maxRoundsReached) {
|
summary = summary + "\n自动调参 Agent 达到最大工具调用轮次,已停止。";
|
}
|
result.setSummary(summary);
|
return result;
|
}
|
|
private List<Object> filterAllowedTools(List<Object> tools) {
|
List<Object> allowedTools = new ArrayList<>();
|
if (tools == null || tools.isEmpty()) {
|
return allowedTools;
|
}
|
for (Object tool : tools) {
|
String toolName = resolveOpenAiToolName(tool);
|
if (ALLOWED_TOOL_NAMES.contains(toolName)) {
|
allowedTools.add(tool);
|
}
|
}
|
return allowedTools;
|
}
|
|
private String resolveOpenAiToolName(Object tool) {
|
if (!(tool instanceof Map<?, ?> toolMap)) {
|
return null;
|
}
|
Object function = toolMap.get("function");
|
if (!(function instanceof Map<?, ?> functionMap)) {
|
return null;
|
}
|
Object name = functionMap.get("name");
|
return name == null ? null : String.valueOf(name);
|
}
|
|
private void appendSummary(StringBuilder summaryBuffer, String content) {
|
if (summaryBuffer == null || isBlank(content)) {
|
return;
|
}
|
if (summaryBuffer.length() > 0) {
|
summaryBuffer.append('\n');
|
}
|
summaryBuffer.append(content.trim());
|
}
|
|
private String normalizeTriggerType(String triggerType) {
|
return isBlank(triggerType) ? "agent" : triggerType.trim();
|
}
|
|
private boolean isBlank(String value) {
|
return value == null || value.trim().isEmpty();
|
}
|
|
private static class UsageCounter {
|
private long promptTokens;
|
private long completionTokens;
|
private long totalTokens;
|
private int llmCallCount;
|
|
void add(ChatCompletionResponse.Usage usage) {
|
llmCallCount++;
|
if (usage == null) {
|
return;
|
}
|
promptTokens += usage.getPromptTokens() == null ? 0L : usage.getPromptTokens();
|
completionTokens += usage.getCompletionTokens() == null ? 0L : usage.getCompletionTokens();
|
totalTokens += usage.getTotalTokens() == null ? 0L : usage.getTotalTokens();
|
}
|
|
long getPromptTokens() {
|
return promptTokens;
|
}
|
|
long getCompletionTokens() {
|
return completionTokens;
|
}
|
|
long getTotalTokens() {
|
return totalTokens;
|
}
|
|
int getLlmCallCount() {
|
return llmCallCount;
|
}
|
}
|
|
private static class RunState {
|
private int toolCallCount;
|
private boolean snapshotCalled;
|
private boolean toolError;
|
|
void markToolSuccess(String toolName) {
|
toolCallCount++;
|
if (TOOL_GET_SNAPSHOT.equals(toolName)) {
|
snapshotCalled = true;
|
}
|
}
|
|
void markToolError() {
|
toolError = true;
|
}
|
|
boolean isSuccessful() {
|
return toolCallCount > 0 && snapshotCalled && !toolError;
|
}
|
|
int getToolCallCount() {
|
return toolCallCount;
|
}
|
|
boolean isSnapshotCalled() {
|
return snapshotCalled;
|
}
|
|
boolean hasToolError() {
|
return toolError;
|
}
|
}
|
}
|