From caf3bdd9bbb629c8bc6f1a19b3ccdf441bf7650c Mon Sep 17 00:00:00 2001
From: Junjie <fallin.jie@qq.com>
Date: 星期日, 15 三月 2026 17:46:47 +0800
Subject: [PATCH] #
---
src/main/java/com/zy/ai/mcp/service/SpringAiMcpToolManager.java | 574 ++++++++++++++++++++++++++++++++++++++++++--------------
1 files changed, 425 insertions(+), 149 deletions(-)
diff --git a/src/main/java/com/zy/ai/mcp/service/SpringAiMcpToolManager.java b/src/main/java/com/zy/ai/mcp/service/SpringAiMcpToolManager.java
index a40a717..cb91533 100644
--- a/src/main/java/com/zy/ai/mcp/service/SpringAiMcpToolManager.java
+++ b/src/main/java/com/zy/ai/mcp/service/SpringAiMcpToolManager.java
@@ -2,73 +2,52 @@
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
+import com.zy.ai.entity.AiMcpMount;
+import com.zy.ai.enums.AiMcpTransportType;
+import com.zy.ai.service.AiMcpMountService;
import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
+import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport;
import io.modelcontextprotocol.spec.McpSchema;
+import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
import org.springframework.ai.tool.ToolCallback;
-import org.springframework.ai.tool.ToolCallbackProvider;
+import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import jakarta.annotation.PreDestroy;
+import java.net.URI;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
@Slf4j
@Service
+@RequiredArgsConstructor
public class SpringAiMcpToolManager {
private static final McpSchema.Implementation CLIENT_INFO =
new McpSchema.Implementation("wcs-ai-assistant", "1.0.0");
private final Object clientMonitor = new Object();
-
- @Value("${server.port:9090}")
- private Integer serverPort;
-
- @Value("${server.servlet.context-path:}")
- private String contextPath;
-
- @Value("${spring.ai.mcp.server.sse-endpoint:/ai/mcp/sse}")
- private String sseEndpoint;
+ private final AiMcpMountService aiMcpMountService;
@Value("${spring.ai.mcp.server.request-timeout:20s}")
- private Duration requestTimeout;
+ private Duration defaultRequestTimeout;
- @Value("${app.ai.mcp.client.base-url:}")
- private String configuredBaseUrl;
-
- private volatile ClientSession clientSession;
+ private volatile ClientRegistry clientRegistry;
public List<Map<String, Object>> listTools() {
- List<Map<String, Object>> tools = new ArrayList<Map<String, Object>>();
- for (ToolCallback callback : getToolCallbacks()) {
- if (callback == null || callback.getToolDefinition() == null) {
- continue;
- }
- ToolDefinition definition = callback.getToolDefinition();
- Map<String, Object> item = new LinkedHashMap<String, Object>();
- item.put("name", definition.name());
- item.put("description", definition.description());
- item.put("inputSchema", parseSchema(definition.inputSchema()));
- tools.add(item);
- }
- tools.sort(new Comparator<Map<String, Object>>() {
- @Override
- public int compare(Map<String, Object> left, Map<String, Object> right) {
- return String.valueOf(left.get("name")).compareTo(String.valueOf(right.get("name")));
- }
- });
- return tools;
+ return ensureClientRegistry().toolList;
}
public List<Object> buildOpenAiTools() {
@@ -97,40 +76,301 @@
}
public Object callTool(String toolName, JSONObject arguments) {
- if (toolName == null || toolName.trim().isEmpty()) {
+ if (isBlank(toolName)) {
throw new IllegalArgumentException("missing tool name");
}
- ToolCallback callback = findCallback(toolName);
- if (callback == null) {
+ MountedTool mountedTool = ensureClientRegistry().toolMap.get(toolName);
+ if (mountedTool == null) {
throw new IllegalArgumentException("tool not found: " + toolName);
}
- String rawResult = callback.call(arguments == null ? "{}" : arguments.toJSONString());
- return parseToolResult(rawResult);
+ try {
+ String rawResult = mountedTool.callback.call(arguments == null ? "{}" : arguments.toJSONString());
+ return parseToolResult(rawResult);
+ } catch (Exception e) {
+ evictCache();
+ throw e;
+ }
}
- private ToolCallback findCallback(String toolName) {
- for (ToolCallback callback : getToolCallbacks()) {
- if (callback == null || callback.getToolDefinition() == null) {
- continue;
- }
- if (toolName.equals(callback.getToolDefinition().name())) {
- return callback;
+ public Map<String, Object> testMount(AiMcpMount mount) {
+ if (mount == null) {
+ throw new IllegalArgumentException("鍙傛暟涓嶈兘涓虹┖");
+ }
+ MountSession session = null;
+ long start = System.currentTimeMillis();
+ try {
+ session = openSession(mount);
+ LinkedHashMap<String, Object> result = new LinkedHashMap<String, Object>();
+ result.put("ok", true);
+ result.put("message", "杩炴帴鎴愬姛锛屽凡鍙戠幇 " + session.callbacks.length + " 涓伐鍏�");
+ result.put("latencyMs", System.currentTimeMillis() - start);
+ result.put("url", session.url);
+ result.put("transportType", mount.getTransportType());
+ result.put("toolCount", session.callbacks.length);
+ result.put("toolNames", collectToolNames(session.callbacks, mount, reservedToolNames(mount.getMountCode())));
+ return result;
+ } catch (Exception e) {
+ TransportTarget target = resolveTransportTarget(mount);
+ LinkedHashMap<String, Object> result = new LinkedHashMap<String, Object>();
+ result.put("ok", false);
+ result.put("message", safeMessage(e));
+ result.put("latencyMs", System.currentTimeMillis() - start);
+ result.put("url", target.url);
+ result.put("transportType", mount.getTransportType());
+ result.put("toolCount", 0);
+ result.put("toolNames", new ArrayList<String>());
+ return result;
+ } finally {
+ if (session != null) {
+ session.close();
}
}
- return null;
}
- private ToolCallback[] getToolCallbacks() {
+ public void evictCache() {
+ resetClientRegistry();
+ }
+
+ private ClientRegistry ensureClientRegistry() {
+ ClientRegistry current = clientRegistry;
+ if (current != null) {
+ return current;
+ }
+
+ synchronized (clientMonitor) {
+ current = clientRegistry;
+ if (current != null) {
+ return current;
+ }
+
+ current = buildClientRegistry();
+ clientRegistry = current;
+ return current;
+ }
+ }
+
+ private ClientRegistry buildClientRegistry() {
+ List<AiMcpMount> mounts = loadEnabledMounts();
+ List<MountSession> sessions = new ArrayList<MountSession>();
+ LinkedHashMap<String, MountedTool> toolMap = new LinkedHashMap<String, MountedTool>();
+ List<Map<String, Object>> toolList = new ArrayList<Map<String, Object>>();
+
+ for (AiMcpMount mount : mounts) {
+ if (mount == null) {
+ continue;
+ }
+ MountSession session = null;
+ try {
+ session = openSession(mount);
+ sessions.add(session);
+ for (ToolCallback callback : session.callbacks) {
+ MountedTool mountedTool = buildMountedTool(mount, callback, toolMap.keySet());
+ toolMap.put(mountedTool.toolName, mountedTool);
+ toolList.add(toToolDescriptor(mountedTool));
+ }
+ log.info("MCP mount loaded, mountCode={}, transport={}, url={}, tools={}",
+ mount.getMountCode(), mount.getTransportType(), session.url, session.callbacks.length);
+ } catch (Exception e) {
+ log.warn("Failed to load MCP mount, mountCode={}, transport={}, url={}",
+ mount.getMountCode(), mount.getTransportType(), resolveUrl(mount), e);
+ if (session != null) {
+ session.close();
+ }
+ }
+ }
+
+ toolList.sort(new Comparator<Map<String, Object>>() {
+ @Override
+ public int compare(Map<String, Object> left, Map<String, Object> right) {
+ return String.valueOf(left.get("name")).compareTo(String.valueOf(right.get("name")));
+ }
+ });
+ return new ClientRegistry(sessions, toolMap, toolList);
+ }
+
+ private List<AiMcpMount> loadEnabledMounts() {
try {
- ToolCallback[] callbacks = ensureToolCallbackProvider().getToolCallbacks();
- return callbacks == null ? new ToolCallback[0] : callbacks;
+ List<AiMcpMount> mounts = aiMcpMountService.listEnabledOrdered();
+ if (mounts == null || mounts.isEmpty()) {
+ aiMcpMountService.initDefaultsIfMissing();
+ mounts = aiMcpMountService.listEnabledOrdered();
+ }
+ return mounts == null ? Collections.<AiMcpMount>emptyList() : mounts;
} catch (Exception e) {
- log.warn("Failed to load MCP tools through SSE client, baseUrl={}, sseEndpoint={}",
- resolveBaseUrl(), resolveClientSseEndpoint(), e);
- resetClientSession();
- return new ToolCallback[0];
+ log.warn("Failed to query MCP mount configuration", e);
+ return Collections.emptyList();
+ }
+ }
+
+ private MountedTool buildMountedTool(AiMcpMount mount, ToolCallback callback, java.util.Set<String> usedNames) {
+ ToolDefinition definition = callback == null ? null : callback.getToolDefinition();
+ if (definition == null || isBlank(definition.name())) {
+ throw new IllegalArgumentException("invalid tool definition");
+ }
+
+ String originalName = definition.name();
+ String preferredName = mount.getMountCode() + "_" + originalName;
+ String finalName = ensureUniqueToolName(preferredName, mount, originalName, usedNames);
+ ToolCallback effectiveCallback = originalName.equals(finalName)
+ ? callback
+ : new MountedToolCallback(finalName, callback);
+
+ return new MountedTool(mount, originalName, finalName, effectiveCallback);
+ }
+
+ private String ensureUniqueToolName(String preferredName,
+ AiMcpMount mount,
+ String originalName,
+ java.util.Set<String> usedNames) {
+ if (!usedNames.contains(preferredName)) {
+ return preferredName;
+ }
+
+ String fallbackBase = mount.getMountCode() + "_" + originalName;
+ if (!usedNames.contains(fallbackBase)) {
+ log.warn("Duplicate MCP tool name detected, fallback rename applied, mountCode={}, originalName={}, finalName={}",
+ mount.getMountCode(), originalName, fallbackBase);
+ return fallbackBase;
+ }
+
+ int index = 2;
+ String candidate = fallbackBase + "_" + index;
+ while (usedNames.contains(candidate)) {
+ index++;
+ candidate = fallbackBase + "_" + index;
+ }
+ log.warn("Duplicate MCP tool name detected, numbered rename applied, mountCode={}, originalName={}, finalName={}",
+ mount.getMountCode(), originalName, candidate);
+ return candidate;
+ }
+
+ private Map<String, Object> toToolDescriptor(MountedTool mountedTool) {
+ ToolDefinition definition = mountedTool.callback.getToolDefinition();
+ Map<String, Object> item = new LinkedHashMap<String, Object>();
+ item.put("name", definition.name());
+ item.put("originalName", mountedTool.originalName);
+ item.put("mountCode", mountedTool.mount.getMountCode());
+ item.put("mountName", mountedTool.mount.getName());
+ item.put("transportType", mountedTool.mount.getTransportType());
+ item.put("description", definition.description());
+ item.put("inputSchema", parseSchema(definition.inputSchema()));
+ return item;
+ }
+
+ private List<String> collectToolNames(ToolCallback[] callbacks, AiMcpMount mount, java.util.Set<String> reservedNames) {
+ List<String> names = new ArrayList<String>();
+ java.util.Set<String> used = new LinkedHashSet<String>();
+ if (reservedNames != null && !reservedNames.isEmpty()) {
+ used.addAll(reservedNames);
+ }
+ if (callbacks == null) {
+ return names;
+ }
+ for (ToolCallback callback : callbacks) {
+ ToolDefinition definition = callback == null ? null : callback.getToolDefinition();
+ if (definition == null || isBlank(definition.name())) {
+ continue;
+ }
+ String preferred = mount.getMountCode() + "_" + definition.name();
+ String finalName = ensureUniqueToolName(preferred, mount, definition.name(), used);
+ used.add(finalName);
+ names.add(finalName);
+ }
+ return names;
+ }
+
+ private java.util.Set<String> reservedToolNames(String mountCode) {
+ LinkedHashSet<String> reserved = new LinkedHashSet<String>();
+ ClientRegistry current = clientRegistry;
+ if (current == null || current.toolMap == null || current.toolMap.isEmpty()) {
+ return reserved;
+ }
+ for (MountedTool mountedTool : current.toolMap.values()) {
+ if (mountedTool == null || mountedTool.mount == null) {
+ continue;
+ }
+ if (mountCode != null && mountCode.equals(mountedTool.mount.getMountCode())) {
+ continue;
+ }
+ reserved.add(mountedTool.toolName);
+ }
+ return reserved;
+ }
+
+ private MountSession openSession(AiMcpMount mount) {
+ Duration timeout = resolveTimeout(mount);
+ TransportTarget target = resolveTransportTarget(mount);
+ String baseUrl = target.baseUrl;
+ String endpoint = target.endpoint;
+ AiMcpTransportType transportType = AiMcpTransportType.ofCode(mount.getTransportType());
+ McpSyncClient syncClient;
+
+ if (transportType == AiMcpTransportType.STREAMABLE_HTTP) {
+ HttpClientStreamableHttpTransport transport = HttpClientStreamableHttpTransport.builder(baseUrl)
+ .endpoint(endpoint)
+ .connectTimeout(timeout)
+ .build();
+ syncClient = McpClient.sync(transport)
+ .clientInfo(CLIENT_INFO)
+ .requestTimeout(timeout)
+ .initializationTimeout(timeout)
+ .build();
+ } else {
+ HttpClientSseClientTransport transport = HttpClientSseClientTransport.builder(baseUrl)
+ .sseEndpoint(endpoint)
+ .connectTimeout(timeout)
+ .build();
+ syncClient = McpClient.sync(transport)
+ .clientInfo(CLIENT_INFO)
+ .requestTimeout(timeout)
+ .initializationTimeout(timeout)
+ .build();
+ }
+
+ syncClient.initialize();
+ SyncMcpToolCallbackProvider callbackProvider = new SyncMcpToolCallbackProvider(syncClient);
+ ToolCallback[] callbacks = callbackProvider.getToolCallbacks();
+ return new MountSession(mount, syncClient, target.url, callbacks == null ? new ToolCallback[0] : callbacks);
+ }
+
+ private Duration resolveTimeout(AiMcpMount mount) {
+ if (mount != null && mount.getRequestTimeoutMs() != null && mount.getRequestTimeoutMs() > 0) {
+ return Duration.ofMillis(mount.getRequestTimeoutMs());
+ }
+ return defaultRequestTimeout == null ? Duration.ofSeconds(20) : defaultRequestTimeout;
+ }
+
+ private String resolveUrl(AiMcpMount mount) {
+ return trim(mount == null ? null : mount.getUrl());
+ }
+
+ private TransportTarget resolveTransportTarget(AiMcpMount mount) {
+ String rawUrl = resolveUrl(mount);
+ if (isBlank(rawUrl)) {
+ throw new IllegalArgumentException("missing url");
+ }
+ try {
+ URI finalUri = URI.create(rawUrl);
+ String authority = finalUri.getRawAuthority();
+ String scheme = finalUri.getScheme();
+ if (isBlank(scheme) || isBlank(authority)) {
+ throw new IllegalArgumentException("invalid MCP url");
+ }
+ String baseUrl = scheme + "://" + authority;
+ String endpoint = finalUri.getRawPath();
+ if (isBlank(endpoint)) {
+ throw new IllegalArgumentException("missing MCP path");
+ }
+ if (finalUri.getRawQuery() != null && !finalUri.getRawQuery().isEmpty()) {
+ endpoint = endpoint + "?" + finalUri.getRawQuery();
+ }
+ return new TransportTarget(rawUrl, baseUrl, endpoint);
+ } catch (IllegalArgumentException e) {
+ throw e;
+ } catch (Exception e) {
+ throw new IllegalArgumentException("invalid url: " + rawUrl, e);
}
}
@@ -163,89 +403,8 @@
return fallback;
}
- private ToolCallbackProvider ensureToolCallbackProvider() {
- return ensureClientSession().toolCallbackProvider;
- }
-
- private ClientSession ensureClientSession() {
- ClientSession current = clientSession;
- if (current != null) {
- return current;
- }
-
- synchronized (clientMonitor) {
- current = clientSession;
- if (current != null) {
- return current;
- }
-
- String baseUrl = resolveBaseUrl();
- String clientSseEndpoint = resolveClientSseEndpoint();
- HttpClientSseClientTransport transport = HttpClientSseClientTransport.builder(baseUrl)
- .sseEndpoint(clientSseEndpoint)
- .connectTimeout(requestTimeout)
- .build();
-
- McpSyncClient syncClient = McpClient.sync(transport)
- .clientInfo(CLIENT_INFO)
- .requestTimeout(requestTimeout)
- .initializationTimeout(requestTimeout)
- .build();
- syncClient.initialize();
-
- SyncMcpToolCallbackProvider callbackProvider = new SyncMcpToolCallbackProvider(syncClient);
- current = new ClientSession(syncClient, callbackProvider, baseUrl);
- clientSession = current;
- log.info("Spring AI MCP SSE client initialized, baseUrl={}, sseEndpoint={}, tools={}",
- baseUrl, clientSseEndpoint, current.toolCallbackProvider.getToolCallbacks().length);
- return current;
- }
- }
-
- private void resetClientSession() {
- synchronized (clientMonitor) {
- ClientSession current = clientSession;
- clientSession = null;
- if (current != null) {
- current.close();
- }
- }
- }
-
- private String resolveBaseUrl() {
- if (configuredBaseUrl != null && !configuredBaseUrl.trim().isEmpty()) {
- return trimTrailingSlash(configuredBaseUrl.trim());
- }
- StringBuilder url = new StringBuilder("http://127.0.0.1:");
- url.append(serverPort == null ? 9090 : serverPort);
- return trimTrailingSlash(url.toString());
- }
-
- private String resolveClientSseEndpoint() {
- String endpoint = normalizePath(sseEndpoint);
- if (configuredBaseUrl != null && !configuredBaseUrl.trim().isEmpty()) {
- return endpoint;
- }
- String context = normalizeContextPath(contextPath);
- if (context.isEmpty()) {
- return endpoint;
- }
- return context + endpoint;
- }
-
- private String normalizeContextPath(String path) {
- if (path == null || path.trim().isEmpty() || "/".equals(path.trim())) {
- return "";
- }
- String value = path.trim();
- if (!value.startsWith("/")) {
- value = "/" + value;
- }
- return trimTrailingSlash(value);
- }
-
private String normalizePath(String path) {
- if (path == null || path.trim().isEmpty()) {
+ if (isBlank(path)) {
return "/";
}
String value = path.trim();
@@ -255,28 +414,81 @@
return value;
}
- private String trimTrailingSlash(String value) {
- if (value == null || value.isEmpty()) {
- return "";
+ private String safeMessage(Throwable throwable) {
+ if (throwable == null) {
+ return "unknown error";
}
- return value.endsWith("/") && value.length() > 1 ? value.substring(0, value.length() - 1) : value;
+ if (!isBlank(throwable.getMessage())) {
+ return throwable.getMessage();
+ }
+ return throwable.getClass().getSimpleName();
+ }
+
+ private String trim(String text) {
+ return text == null ? null : text.trim();
+ }
+
+ private boolean isBlank(String text) {
+ return text == null || text.trim().isEmpty();
+ }
+
+ private void resetClientRegistry() {
+ synchronized (clientMonitor) {
+ ClientRegistry current = clientRegistry;
+ clientRegistry = null;
+ if (current != null) {
+ current.close();
+ }
+ }
}
@PreDestroy
public void destroy() {
- resetClientSession();
+ resetClientRegistry();
}
- private static final class ClientSession implements AutoCloseable {
+ private static final class ClientRegistry implements AutoCloseable {
+ private final List<MountSession> sessions;
+ private final Map<String, MountedTool> toolMap;
+ private final List<Map<String, Object>> toolList;
+
+ private ClientRegistry(List<MountSession> sessions,
+ Map<String, MountedTool> toolMap,
+ List<Map<String, Object>> toolList) {
+ this.sessions = sessions;
+ this.toolMap = toolMap;
+ this.toolList = toolList;
+ }
+
+ @Override
+ public void close() {
+ if (sessions == null) {
+ return;
+ }
+ for (MountSession session : sessions) {
+ if (session != null) {
+ session.close();
+ }
+ }
+ }
+ }
+
+ private static final class MountSession implements AutoCloseable {
+
+ private final AiMcpMount mount;
private final McpSyncClient syncClient;
- private final ToolCallbackProvider toolCallbackProvider;
- private final String baseUrl;
+ private final String url;
+ private final ToolCallback[] callbacks;
- private ClientSession(McpSyncClient syncClient, ToolCallbackProvider toolCallbackProvider, String baseUrl) {
+ private MountSession(AiMcpMount mount,
+ McpSyncClient syncClient,
+ String url,
+ ToolCallback[] callbacks) {
+ this.mount = mount;
this.syncClient = syncClient;
- this.toolCallbackProvider = toolCallbackProvider;
- this.baseUrl = baseUrl;
+ this.url = url;
+ this.callbacks = callbacks;
}
@Override
@@ -284,13 +496,77 @@
try {
syncClient.closeGracefully();
} catch (Exception e) {
- log.debug("Close MCP SSE client gracefully failed, baseUrl={}", baseUrl, e);
+ log.debug("Close MCP client gracefully failed, mountCode={}", mount == null ? null : mount.getMountCode(), e);
}
try {
syncClient.close();
} catch (Exception e) {
- log.debug("Close MCP SSE client failed, baseUrl={}", baseUrl, e);
+ log.debug("Close MCP client failed, mountCode={}", mount == null ? null : mount.getMountCode(), e);
}
}
}
+
+ private static final class MountedTool {
+
+ private final AiMcpMount mount;
+ private final String originalName;
+ private final String toolName;
+ private final ToolCallback callback;
+
+ private MountedTool(AiMcpMount mount, String originalName, String toolName, ToolCallback callback) {
+ this.mount = mount;
+ this.originalName = originalName;
+ this.toolName = toolName;
+ this.callback = callback;
+ }
+ }
+
+ private static final class MountedToolCallback implements ToolCallback {
+
+ private final ToolCallback delegate;
+ private final ToolDefinition definition;
+
+ private MountedToolCallback(String name, ToolCallback delegate) {
+ this.delegate = delegate;
+ ToolDefinition source = delegate.getToolDefinition();
+ this.definition = DefaultToolDefinition.builder()
+ .name(name)
+ .description(source == null ? "" : source.description())
+ .inputSchema(source == null ? "{}" : source.inputSchema())
+ .build();
+ }
+
+ @Override
+ public ToolDefinition getToolDefinition() {
+ return definition;
+ }
+
+ @Override
+ public org.springframework.ai.tool.metadata.ToolMetadata getToolMetadata() {
+ return delegate.getToolMetadata();
+ }
+
+ @Override
+ public String call(String toolInput) {
+ return delegate.call(toolInput);
+ }
+
+ @Override
+ public String call(String toolInput, org.springframework.ai.chat.model.ToolContext toolContext) {
+ return delegate.call(toolInput, toolContext);
+ }
+ }
+
+ private static final class TransportTarget {
+
+ private final String url;
+ private final String baseUrl;
+ private final String endpoint;
+
+ private TransportTarget(String url, String baseUrl, String endpoint) {
+ this.url = url;
+ this.baseUrl = baseUrl;
+ this.endpoint = endpoint;
+ }
+ }
}
--
Gitblit v1.9.1