package com.zy.ai.mcp.service;
|
|
import com.alibaba.fastjson.JSON;
|
import com.alibaba.fastjson.JSONObject;
|
import com.zy.ai.entity.AiMcpMount;
|
import com.zy.ai.enums.AiMcpTransportType;
|
import com.zy.ai.service.AiMcpMountService;
|
import io.modelcontextprotocol.client.McpClient;
|
import io.modelcontextprotocol.client.McpSyncClient;
|
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
|
import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport;
|
import io.modelcontextprotocol.spec.McpSchema;
|
import lombok.RequiredArgsConstructor;
|
import lombok.extern.slf4j.Slf4j;
|
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
|
import org.springframework.ai.tool.ToolCallback;
|
import org.springframework.ai.tool.definition.DefaultToolDefinition;
|
import org.springframework.ai.tool.definition.ToolDefinition;
|
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.stereotype.Service;
|
|
import jakarta.annotation.PreDestroy;
|
import java.net.URI;
|
import java.time.Duration;
|
import java.util.ArrayList;
|
import java.util.Collections;
|
import java.util.Comparator;
|
import java.util.LinkedHashMap;
|
import java.util.LinkedHashSet;
|
import java.util.List;
|
import java.util.Map;
|
|
@Slf4j
|
@Service
|
@RequiredArgsConstructor
|
public class SpringAiMcpToolManager {
|
|
private static final McpSchema.Implementation CLIENT_INFO =
|
new McpSchema.Implementation("wcs-ai-assistant", "1.0.0");
|
|
private final Object clientMonitor = new Object();
|
private final AiMcpMountService aiMcpMountService;
|
|
@Value("${spring.ai.mcp.server.request-timeout:20s}")
|
private Duration defaultRequestTimeout;
|
|
private volatile ClientRegistry clientRegistry;
|
|
public List<Map<String, Object>> listTools() {
|
return ensureClientRegistry().toolList;
|
}
|
|
public List<Object> buildOpenAiTools() {
|
List<Object> tools = new ArrayList<Object>();
|
for (Map<String, Object> item : listTools()) {
|
Object name = item.get("name");
|
if (name == null) {
|
continue;
|
}
|
|
Map<String, Object> function = new LinkedHashMap<String, Object>();
|
function.put("name", String.valueOf(name));
|
Object description = item.get("description");
|
if (description != null) {
|
function.put("description", String.valueOf(description));
|
}
|
Object inputSchema = item.get("inputSchema");
|
function.put("parameters", inputSchema == null ? new LinkedHashMap<String, Object>() : inputSchema);
|
|
Map<String, Object> tool = new LinkedHashMap<String, Object>();
|
tool.put("type", "function");
|
tool.put("function", function);
|
tools.add(tool);
|
}
|
return tools;
|
}
|
|
public Object callTool(String toolName, JSONObject arguments) {
|
if (isBlank(toolName)) {
|
throw new IllegalArgumentException("missing tool name");
|
}
|
|
MountedTool mountedTool = ensureClientRegistry().toolMap.get(toolName);
|
if (mountedTool == null) {
|
throw new IllegalArgumentException("tool not found: " + toolName);
|
}
|
|
try {
|
String rawResult = mountedTool.callback.call(arguments == null ? "{}" : arguments.toJSONString());
|
return parseToolResult(rawResult);
|
} catch (Exception e) {
|
evictCache();
|
throw e;
|
}
|
}
|
|
public Map<String, Object> testMount(AiMcpMount mount) {
|
if (mount == null) {
|
throw new IllegalArgumentException("参数不能为空");
|
}
|
MountSession session = null;
|
long start = System.currentTimeMillis();
|
try {
|
session = openSession(mount);
|
LinkedHashMap<String, Object> result = new LinkedHashMap<String, Object>();
|
result.put("ok", true);
|
result.put("message", "连接成功,已发现 " + session.callbacks.length + " 个工具");
|
result.put("latencyMs", System.currentTimeMillis() - start);
|
result.put("url", session.url);
|
result.put("transportType", mount.getTransportType());
|
result.put("toolCount", session.callbacks.length);
|
result.put("toolNames", collectToolNames(session.callbacks, mount, reservedToolNames(mount.getMountCode())));
|
return result;
|
} catch (Exception e) {
|
TransportTarget target = resolveTransportTarget(mount);
|
LinkedHashMap<String, Object> result = new LinkedHashMap<String, Object>();
|
result.put("ok", false);
|
result.put("message", safeMessage(e));
|
result.put("latencyMs", System.currentTimeMillis() - start);
|
result.put("url", target.url);
|
result.put("transportType", mount.getTransportType());
|
result.put("toolCount", 0);
|
result.put("toolNames", new ArrayList<String>());
|
return result;
|
} finally {
|
if (session != null) {
|
session.close();
|
}
|
}
|
}
|
|
public void evictCache() {
|
resetClientRegistry();
|
}
|
|
private ClientRegistry ensureClientRegistry() {
|
ClientRegistry current = clientRegistry;
|
if (current != null) {
|
return current;
|
}
|
|
synchronized (clientMonitor) {
|
current = clientRegistry;
|
if (current != null) {
|
return current;
|
}
|
|
current = buildClientRegistry();
|
clientRegistry = current;
|
return current;
|
}
|
}
|
|
private ClientRegistry buildClientRegistry() {
|
List<AiMcpMount> mounts = loadEnabledMounts();
|
List<MountSession> sessions = new ArrayList<MountSession>();
|
LinkedHashMap<String, MountedTool> toolMap = new LinkedHashMap<String, MountedTool>();
|
List<Map<String, Object>> toolList = new ArrayList<Map<String, Object>>();
|
|
for (AiMcpMount mount : mounts) {
|
if (mount == null) {
|
continue;
|
}
|
MountSession session = null;
|
try {
|
session = openSession(mount);
|
sessions.add(session);
|
for (ToolCallback callback : session.callbacks) {
|
MountedTool mountedTool = buildMountedTool(mount, callback, toolMap.keySet());
|
toolMap.put(mountedTool.toolName, mountedTool);
|
toolList.add(toToolDescriptor(mountedTool));
|
}
|
log.info("MCP mount loaded, mountCode={}, transport={}, url={}, tools={}",
|
mount.getMountCode(), mount.getTransportType(), session.url, session.callbacks.length);
|
} catch (Exception e) {
|
log.warn("Failed to load MCP mount, mountCode={}, transport={}, url={}",
|
mount.getMountCode(), mount.getTransportType(), resolveUrl(mount), e);
|
if (session != null) {
|
session.close();
|
}
|
}
|
}
|
|
toolList.sort(new Comparator<Map<String, Object>>() {
|
@Override
|
public int compare(Map<String, Object> left, Map<String, Object> right) {
|
return String.valueOf(left.get("name")).compareTo(String.valueOf(right.get("name")));
|
}
|
});
|
return new ClientRegistry(sessions, toolMap, toolList);
|
}
|
|
private List<AiMcpMount> loadEnabledMounts() {
|
try {
|
List<AiMcpMount> mounts = aiMcpMountService.listEnabledOrdered();
|
if (mounts == null || mounts.isEmpty()) {
|
aiMcpMountService.initDefaultsIfMissing();
|
mounts = aiMcpMountService.listEnabledOrdered();
|
}
|
return mounts == null ? Collections.<AiMcpMount>emptyList() : mounts;
|
} catch (Exception e) {
|
log.warn("Failed to query MCP mount configuration", e);
|
return Collections.emptyList();
|
}
|
}
|
|
private MountedTool buildMountedTool(AiMcpMount mount, ToolCallback callback, java.util.Set<String> usedNames) {
|
ToolDefinition definition = callback == null ? null : callback.getToolDefinition();
|
if (definition == null || isBlank(definition.name())) {
|
throw new IllegalArgumentException("invalid tool definition");
|
}
|
|
String originalName = definition.name();
|
String preferredName = mount.getMountCode() + "_" + originalName;
|
String finalName = ensureUniqueToolName(preferredName, mount, originalName, usedNames);
|
ToolCallback effectiveCallback = originalName.equals(finalName)
|
? callback
|
: new MountedToolCallback(finalName, callback);
|
|
return new MountedTool(mount, originalName, finalName, effectiveCallback);
|
}
|
|
private String ensureUniqueToolName(String preferredName,
|
AiMcpMount mount,
|
String originalName,
|
java.util.Set<String> usedNames) {
|
if (!usedNames.contains(preferredName)) {
|
return preferredName;
|
}
|
|
String fallbackBase = mount.getMountCode() + "_" + originalName;
|
if (!usedNames.contains(fallbackBase)) {
|
log.warn("Duplicate MCP tool name detected, fallback rename applied, mountCode={}, originalName={}, finalName={}",
|
mount.getMountCode(), originalName, fallbackBase);
|
return fallbackBase;
|
}
|
|
int index = 2;
|
String candidate = fallbackBase + "_" + index;
|
while (usedNames.contains(candidate)) {
|
index++;
|
candidate = fallbackBase + "_" + index;
|
}
|
log.warn("Duplicate MCP tool name detected, numbered rename applied, mountCode={}, originalName={}, finalName={}",
|
mount.getMountCode(), originalName, candidate);
|
return candidate;
|
}
|
|
private Map<String, Object> toToolDescriptor(MountedTool mountedTool) {
|
ToolDefinition definition = mountedTool.callback.getToolDefinition();
|
Map<String, Object> item = new LinkedHashMap<String, Object>();
|
item.put("name", definition.name());
|
item.put("originalName", mountedTool.originalName);
|
item.put("mountCode", mountedTool.mount.getMountCode());
|
item.put("mountName", mountedTool.mount.getName());
|
item.put("transportType", mountedTool.mount.getTransportType());
|
item.put("description", definition.description());
|
item.put("inputSchema", parseSchema(definition.inputSchema()));
|
return item;
|
}
|
|
private List<String> collectToolNames(ToolCallback[] callbacks, AiMcpMount mount, java.util.Set<String> reservedNames) {
|
List<String> names = new ArrayList<String>();
|
java.util.Set<String> used = new LinkedHashSet<String>();
|
if (reservedNames != null && !reservedNames.isEmpty()) {
|
used.addAll(reservedNames);
|
}
|
if (callbacks == null) {
|
return names;
|
}
|
for (ToolCallback callback : callbacks) {
|
ToolDefinition definition = callback == null ? null : callback.getToolDefinition();
|
if (definition == null || isBlank(definition.name())) {
|
continue;
|
}
|
String preferred = mount.getMountCode() + "_" + definition.name();
|
String finalName = ensureUniqueToolName(preferred, mount, definition.name(), used);
|
used.add(finalName);
|
names.add(finalName);
|
}
|
return names;
|
}
|
|
private java.util.Set<String> reservedToolNames(String mountCode) {
|
LinkedHashSet<String> reserved = new LinkedHashSet<String>();
|
ClientRegistry current = clientRegistry;
|
if (current == null || current.toolMap == null || current.toolMap.isEmpty()) {
|
return reserved;
|
}
|
for (MountedTool mountedTool : current.toolMap.values()) {
|
if (mountedTool == null || mountedTool.mount == null) {
|
continue;
|
}
|
if (mountCode != null && mountCode.equals(mountedTool.mount.getMountCode())) {
|
continue;
|
}
|
reserved.add(mountedTool.toolName);
|
}
|
return reserved;
|
}
|
|
private MountSession openSession(AiMcpMount mount) {
|
Duration timeout = resolveTimeout(mount);
|
TransportTarget target = resolveTransportTarget(mount);
|
String baseUrl = target.baseUrl;
|
String endpoint = target.endpoint;
|
AiMcpTransportType transportType = AiMcpTransportType.ofCode(mount.getTransportType());
|
McpSyncClient syncClient;
|
|
if (transportType == AiMcpTransportType.STREAMABLE_HTTP) {
|
HttpClientStreamableHttpTransport transport = HttpClientStreamableHttpTransport.builder(baseUrl)
|
.endpoint(endpoint)
|
.connectTimeout(timeout)
|
.build();
|
syncClient = McpClient.sync(transport)
|
.clientInfo(CLIENT_INFO)
|
.requestTimeout(timeout)
|
.initializationTimeout(timeout)
|
.build();
|
} else {
|
HttpClientSseClientTransport transport = HttpClientSseClientTransport.builder(baseUrl)
|
.sseEndpoint(endpoint)
|
.connectTimeout(timeout)
|
.build();
|
syncClient = McpClient.sync(transport)
|
.clientInfo(CLIENT_INFO)
|
.requestTimeout(timeout)
|
.initializationTimeout(timeout)
|
.build();
|
}
|
|
syncClient.initialize();
|
SyncMcpToolCallbackProvider callbackProvider = new SyncMcpToolCallbackProvider(syncClient);
|
ToolCallback[] callbacks = callbackProvider.getToolCallbacks();
|
return new MountSession(mount, syncClient, target.url, callbacks == null ? new ToolCallback[0] : callbacks);
|
}
|
|
private Duration resolveTimeout(AiMcpMount mount) {
|
if (mount != null && mount.getRequestTimeoutMs() != null && mount.getRequestTimeoutMs() > 0) {
|
return Duration.ofMillis(mount.getRequestTimeoutMs());
|
}
|
return defaultRequestTimeout == null ? Duration.ofSeconds(20) : defaultRequestTimeout;
|
}
|
|
private String resolveUrl(AiMcpMount mount) {
|
return trim(mount == null ? null : mount.getUrl());
|
}
|
|
private TransportTarget resolveTransportTarget(AiMcpMount mount) {
|
String rawUrl = resolveUrl(mount);
|
if (isBlank(rawUrl)) {
|
throw new IllegalArgumentException("missing url");
|
}
|
try {
|
URI finalUri = URI.create(rawUrl);
|
String authority = finalUri.getRawAuthority();
|
String scheme = finalUri.getScheme();
|
if (isBlank(scheme) || isBlank(authority)) {
|
throw new IllegalArgumentException("invalid MCP url");
|
}
|
String baseUrl = scheme + "://" + authority;
|
String endpoint = finalUri.getRawPath();
|
if (isBlank(endpoint)) {
|
throw new IllegalArgumentException("missing MCP path");
|
}
|
if (finalUri.getRawQuery() != null && !finalUri.getRawQuery().isEmpty()) {
|
endpoint = endpoint + "?" + finalUri.getRawQuery();
|
}
|
return new TransportTarget(rawUrl, baseUrl, endpoint);
|
} catch (IllegalArgumentException e) {
|
throw e;
|
} catch (Exception e) {
|
throw new IllegalArgumentException("invalid url: " + rawUrl, e);
|
}
|
}
|
|
private Object parseToolResult(String rawResult) {
|
if (rawResult == null || rawResult.trim().isEmpty()) {
|
return rawResult;
|
}
|
try {
|
return JSON.parse(rawResult);
|
} catch (Exception ignore) {
|
return rawResult;
|
}
|
}
|
|
@SuppressWarnings("unchecked")
|
private Map<String, Object> parseSchema(String inputSchema) {
|
if (inputSchema == null || inputSchema.trim().isEmpty()) {
|
return Collections.emptyMap();
|
}
|
try {
|
Object parsed = JSON.parse(inputSchema);
|
if (parsed instanceof Map) {
|
return new LinkedHashMap<String, Object>((Map<String, Object>) parsed);
|
}
|
} catch (Exception e) {
|
log.warn("Failed to parse MCP tool schema: {}", inputSchema, e);
|
}
|
Map<String, Object> fallback = new LinkedHashMap<String, Object>();
|
fallback.put("type", "object");
|
return fallback;
|
}
|
|
private String normalizePath(String path) {
|
if (isBlank(path)) {
|
return "/";
|
}
|
String value = path.trim();
|
if (!value.startsWith("/")) {
|
value = "/" + value;
|
}
|
return value;
|
}
|
|
private String safeMessage(Throwable throwable) {
|
if (throwable == null) {
|
return "unknown error";
|
}
|
if (!isBlank(throwable.getMessage())) {
|
return throwable.getMessage();
|
}
|
return throwable.getClass().getSimpleName();
|
}
|
|
private String trim(String text) {
|
return text == null ? null : text.trim();
|
}
|
|
private boolean isBlank(String text) {
|
return text == null || text.trim().isEmpty();
|
}
|
|
private void resetClientRegistry() {
|
synchronized (clientMonitor) {
|
ClientRegistry current = clientRegistry;
|
clientRegistry = null;
|
if (current != null) {
|
current.close();
|
}
|
}
|
}
|
|
@PreDestroy
|
public void destroy() {
|
resetClientRegistry();
|
}
|
|
private static final class ClientRegistry implements AutoCloseable {
|
|
private final List<MountSession> sessions;
|
private final Map<String, MountedTool> toolMap;
|
private final List<Map<String, Object>> toolList;
|
|
private ClientRegistry(List<MountSession> sessions,
|
Map<String, MountedTool> toolMap,
|
List<Map<String, Object>> toolList) {
|
this.sessions = sessions;
|
this.toolMap = toolMap;
|
this.toolList = toolList;
|
}
|
|
@Override
|
public void close() {
|
if (sessions == null) {
|
return;
|
}
|
for (MountSession session : sessions) {
|
if (session != null) {
|
session.close();
|
}
|
}
|
}
|
}
|
|
private static final class MountSession implements AutoCloseable {
|
|
private final AiMcpMount mount;
|
private final McpSyncClient syncClient;
|
private final String url;
|
private final ToolCallback[] callbacks;
|
|
private MountSession(AiMcpMount mount,
|
McpSyncClient syncClient,
|
String url,
|
ToolCallback[] callbacks) {
|
this.mount = mount;
|
this.syncClient = syncClient;
|
this.url = url;
|
this.callbacks = callbacks;
|
}
|
|
@Override
|
public void close() {
|
try {
|
syncClient.closeGracefully();
|
} catch (Exception e) {
|
log.debug("Close MCP client gracefully failed, mountCode={}", mount == null ? null : mount.getMountCode(), e);
|
}
|
try {
|
syncClient.close();
|
} catch (Exception e) {
|
log.debug("Close MCP client failed, mountCode={}", mount == null ? null : mount.getMountCode(), e);
|
}
|
}
|
}
|
|
private static final class MountedTool {
|
|
private final AiMcpMount mount;
|
private final String originalName;
|
private final String toolName;
|
private final ToolCallback callback;
|
|
private MountedTool(AiMcpMount mount, String originalName, String toolName, ToolCallback callback) {
|
this.mount = mount;
|
this.originalName = originalName;
|
this.toolName = toolName;
|
this.callback = callback;
|
}
|
}
|
|
private static final class MountedToolCallback implements ToolCallback {
|
|
private final ToolCallback delegate;
|
private final ToolDefinition definition;
|
|
private MountedToolCallback(String name, ToolCallback delegate) {
|
this.delegate = delegate;
|
ToolDefinition source = delegate.getToolDefinition();
|
this.definition = DefaultToolDefinition.builder()
|
.name(name)
|
.description(source == null ? "" : source.description())
|
.inputSchema(source == null ? "{}" : source.inputSchema())
|
.build();
|
}
|
|
@Override
|
public ToolDefinition getToolDefinition() {
|
return definition;
|
}
|
|
@Override
|
public org.springframework.ai.tool.metadata.ToolMetadata getToolMetadata() {
|
return delegate.getToolMetadata();
|
}
|
|
@Override
|
public String call(String toolInput) {
|
return delegate.call(toolInput);
|
}
|
|
@Override
|
public String call(String toolInput, org.springframework.ai.chat.model.ToolContext toolContext) {
|
return delegate.call(toolInput, toolContext);
|
}
|
}
|
|
private static final class TransportTarget {
|
|
private final String url;
|
private final String baseUrl;
|
private final String endpoint;
|
|
private TransportTarget(String url, String baseUrl, String endpoint) {
|
this.url = url;
|
this.baseUrl = baseUrl;
|
this.endpoint = endpoint;
|
}
|
}
|
}
|