package com.zy.ai.service.impl;
|
|
import com.alibaba.fastjson.JSON;
|
import com.alibaba.fastjson.JSONObject;
|
import com.zy.ai.entity.AiPromptTemplate;
|
import com.zy.ai.entity.ChatCompletionRequest;
|
import com.zy.ai.entity.ChatCompletionResponse;
|
import com.zy.ai.enums.AiPromptScene;
|
import com.zy.ai.mcp.service.SpringAiMcpToolManager;
|
import com.zy.ai.service.AiPromptTemplateService;
|
import com.zy.ai.service.AutoTuneAgentService;
|
import com.zy.ai.service.LlmChatService;
|
import lombok.RequiredArgsConstructor;
|
import lombok.extern.slf4j.Slf4j;
|
import org.springframework.stereotype.Service;
|
|
import java.util.ArrayList;
|
import java.util.LinkedHashMap;
|
import java.util.List;
|
|
@Slf4j
|
@Service
|
@RequiredArgsConstructor
|
public class AutoTuneAgentServiceImpl implements AutoTuneAgentService {
|
|
private static final int MAX_TOOL_ROUNDS = 10;
|
private static final double TEMPERATURE = 0.2D;
|
private static final int MAX_TOKENS = 2048;
|
|
private final LlmChatService llmChatService;
|
private final SpringAiMcpToolManager mcpToolManager;
|
private final AiPromptTemplateService aiPromptTemplateService;
|
|
@Override
|
public AutoTuneAgentResult runAutoTune(String triggerType) {
|
String normalizedTriggerType = normalizeTriggerType(triggerType);
|
UsageCounter usageCounter = new UsageCounter();
|
int toolCallCount = 0;
|
boolean maxRoundsReached = false;
|
StringBuilder summaryBuffer = new StringBuilder();
|
|
try {
|
List<Object> tools = mcpToolManager.buildOpenAiTools();
|
if (tools == null || tools.isEmpty()) {
|
throw new IllegalStateException("No MCP tools registered");
|
}
|
|
AiPromptTemplate promptTemplate = aiPromptTemplateService.resolvePublished(AiPromptScene.AUTO_TUNE_DISPATCH.getCode());
|
List<ChatCompletionRequest.Message> messages = buildMessages(promptTemplate, normalizedTriggerType);
|
|
for (int round = 0; round < MAX_TOOL_ROUNDS; round++) {
|
ChatCompletionResponse response = llmChatService.chatCompletion(messages, TEMPERATURE, MAX_TOKENS, tools);
|
ChatCompletionRequest.Message assistantMessage = extractAssistantMessage(response);
|
usageCounter.add(response.getUsage());
|
messages.add(assistantMessage);
|
appendSummary(summaryBuffer, assistantMessage.getContent());
|
|
List<ChatCompletionRequest.ToolCall> toolCalls = assistantMessage.getTool_calls();
|
if (toolCalls == null || toolCalls.isEmpty()) {
|
return buildResult(true, normalizedTriggerType, summaryBuffer, toolCallCount, usageCounter, false);
|
}
|
|
for (ChatCompletionRequest.ToolCall toolCall : toolCalls) {
|
Object toolOutput = callMountedTool(toolCall);
|
toolCallCount++;
|
messages.add(buildToolMessage(toolCall, toolOutput));
|
}
|
}
|
maxRoundsReached = true;
|
return buildResult(true, normalizedTriggerType, summaryBuffer, toolCallCount, usageCounter, maxRoundsReached);
|
} catch (Exception exception) {
|
log.error("Auto tune agent stopped with error", exception);
|
appendSummary(summaryBuffer, "自动调参 Agent 执行异常: " + exception.getMessage());
|
return buildResult(false, normalizedTriggerType, summaryBuffer, toolCallCount, usageCounter, maxRoundsReached);
|
}
|
}
|
|
private List<ChatCompletionRequest.Message> buildMessages(AiPromptTemplate promptTemplate, String triggerType) {
|
List<ChatCompletionRequest.Message> messages = new ArrayList<>();
|
|
ChatCompletionRequest.Message systemMessage = new ChatCompletionRequest.Message();
|
systemMessage.setRole("system");
|
systemMessage.setContent(promptTemplate == null ? "" : promptTemplate.getContent());
|
messages.add(systemMessage);
|
|
ChatCompletionRequest.Message userMessage = new ChatCompletionRequest.Message();
|
userMessage.setRole("user");
|
userMessage.setContent("请执行一次后台 WCS 自动调参。triggerType=" + triggerType
|
+ "。必须先调用 wcs_local_dispatch_get_auto_tune_snapshot 获取事实;如需提交变更,"
|
+ "必须先 dry-run,再根据 dry-run 结果决定是否实际应用。不要输出自由格式 JSON 供外层解析。");
|
messages.add(userMessage);
|
return messages;
|
}
|
|
private ChatCompletionRequest.Message extractAssistantMessage(ChatCompletionResponse response) {
|
if (response == null || response.getChoices() == null || response.getChoices().isEmpty()) {
|
throw new IllegalStateException("LLM returned empty response");
|
}
|
ChatCompletionRequest.Message message = response.getChoices().get(0).getMessage();
|
if (message == null) {
|
throw new IllegalStateException("LLM returned empty message");
|
}
|
return message;
|
}
|
|
private Object callMountedTool(ChatCompletionRequest.ToolCall toolCall) {
|
String toolName = resolveToolName(toolCall);
|
JSONObject arguments = parseArguments(toolCall);
|
try {
|
return mcpToolManager.callTool(toolName, arguments);
|
} catch (Exception exception) {
|
LinkedHashMap<String, Object> error = new LinkedHashMap<>();
|
error.put("tool", toolName);
|
error.put("error", exception.getMessage());
|
return error;
|
}
|
}
|
|
private ChatCompletionRequest.Message buildToolMessage(ChatCompletionRequest.ToolCall toolCall, Object toolOutput) {
|
ChatCompletionRequest.Message toolMessage = new ChatCompletionRequest.Message();
|
toolMessage.setRole("tool");
|
toolMessage.setTool_call_id(toolCall == null ? null : toolCall.getId());
|
toolMessage.setContent(JSON.toJSONString(toolOutput));
|
return toolMessage;
|
}
|
|
private String resolveToolName(ChatCompletionRequest.ToolCall toolCall) {
|
if (toolCall == null || toolCall.getFunction() == null || isBlank(toolCall.getFunction().getName())) {
|
throw new IllegalArgumentException("missing tool name");
|
}
|
return toolCall.getFunction().getName();
|
}
|
|
private JSONObject parseArguments(ChatCompletionRequest.ToolCall toolCall) {
|
String rawArguments = toolCall == null || toolCall.getFunction() == null
|
? null
|
: toolCall.getFunction().getArguments();
|
if (isBlank(rawArguments)) {
|
return new JSONObject();
|
}
|
try {
|
return JSON.parseObject(rawArguments);
|
} catch (Exception exception) {
|
JSONObject arguments = new JSONObject();
|
arguments.put("_raw", rawArguments);
|
return arguments;
|
}
|
}
|
|
private AutoTuneAgentResult buildResult(boolean success,
|
String triggerType,
|
StringBuilder summaryBuffer,
|
int toolCallCount,
|
UsageCounter usageCounter,
|
boolean maxRoundsReached) {
|
AutoTuneAgentResult result = new AutoTuneAgentResult();
|
result.setSuccess(success);
|
result.setTriggerType(triggerType);
|
result.setToolCallCount(toolCallCount);
|
result.setLlmCallCount(usageCounter.getLlmCallCount());
|
result.setPromptTokens(usageCounter.getPromptTokens());
|
result.setCompletionTokens(usageCounter.getCompletionTokens());
|
result.setTotalTokens(usageCounter.getTotalTokens());
|
result.setMaxRoundsReached(maxRoundsReached);
|
|
String summary = summaryBuffer == null ? "" : summaryBuffer.toString().trim();
|
if (toolCallCount <= 0 && success) {
|
summary = "自动调参 Agent 未调用任何 MCP 工具,未执行调参。" + (summary.isEmpty() ? "" : "\n" + summary);
|
}
|
if (maxRoundsReached) {
|
summary = summary + "\n自动调参 Agent 达到最大工具调用轮次,已停止。";
|
}
|
result.setSummary(summary);
|
return result;
|
}
|
|
private void appendSummary(StringBuilder summaryBuffer, String content) {
|
if (summaryBuffer == null || isBlank(content)) {
|
return;
|
}
|
if (summaryBuffer.length() > 0) {
|
summaryBuffer.append('\n');
|
}
|
summaryBuffer.append(content.trim());
|
}
|
|
private String normalizeTriggerType(String triggerType) {
|
return isBlank(triggerType) ? "agent" : triggerType.trim();
|
}
|
|
private boolean isBlank(String value) {
|
return value == null || value.trim().isEmpty();
|
}
|
|
private static class UsageCounter {
|
private long promptTokens;
|
private long completionTokens;
|
private long totalTokens;
|
private int llmCallCount;
|
|
void add(ChatCompletionResponse.Usage usage) {
|
llmCallCount++;
|
if (usage == null) {
|
return;
|
}
|
promptTokens += usage.getPromptTokens() == null ? 0L : usage.getPromptTokens();
|
completionTokens += usage.getCompletionTokens() == null ? 0L : usage.getCompletionTokens();
|
totalTokens += usage.getTotalTokens() == null ? 0L : usage.getTotalTokens();
|
}
|
|
long getPromptTokens() {
|
return promptTokens;
|
}
|
|
long getCompletionTokens() {
|
return completionTokens;
|
}
|
|
long getTotalTokens() {
|
return totalTokens;
|
}
|
|
int getLlmCallCount() {
|
return llmCallCount;
|
}
|
}
|
}
|