package com.zy.ai.mcp.service; import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import com.zy.ai.entity.AiMcpMount; import com.zy.ai.enums.AiMcpTransportType; import com.zy.ai.service.AiMcpMountService; import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; import io.modelcontextprotocol.spec.McpSchema; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.mcp.SyncMcpToolCallbackProvider; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; import jakarta.annotation.PreDestroy; import java.net.URI; import java.time.Duration; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @Slf4j @Service @RequiredArgsConstructor public class SpringAiMcpToolManager { private static final McpSchema.Implementation CLIENT_INFO = new McpSchema.Implementation("wcs-ai-assistant", "1.0.0"); private final Object clientMonitor = new Object(); private final AiMcpMountService aiMcpMountService; @Value("${spring.ai.mcp.server.request-timeout:20s}") private Duration defaultRequestTimeout; private volatile ClientRegistry clientRegistry; public List> listTools() { return ensureClientRegistry().toolList; } public List buildOpenAiTools() { List tools = new ArrayList(); for (Map item : listTools()) { Object name = item.get("name"); if (name == null) { continue; } Map function = new LinkedHashMap(); function.put("name", String.valueOf(name)); Object description = item.get("description"); if (description != null) { function.put("description", String.valueOf(description)); } Object inputSchema = item.get("inputSchema"); function.put("parameters", inputSchema == null ? new LinkedHashMap() : inputSchema); Map tool = new LinkedHashMap(); tool.put("type", "function"); tool.put("function", function); tools.add(tool); } return tools; } public Object callTool(String toolName, JSONObject arguments) { if (isBlank(toolName)) { throw new IllegalArgumentException("missing tool name"); } MountedTool mountedTool = ensureClientRegistry().toolMap.get(toolName); if (mountedTool == null) { throw new IllegalArgumentException("tool not found: " + toolName); } try { String rawResult = mountedTool.callback.call(arguments == null ? "{}" : arguments.toJSONString()); return parseToolResult(rawResult); } catch (Exception e) { evictCache(); throw e; } } public Map testMount(AiMcpMount mount) { if (mount == null) { throw new IllegalArgumentException("参数不能为空"); } MountSession session = null; long start = System.currentTimeMillis(); try { session = openSession(mount); LinkedHashMap result = new LinkedHashMap(); result.put("ok", true); result.put("message", "连接成功,已发现 " + session.callbacks.length + " 个工具"); result.put("latencyMs", System.currentTimeMillis() - start); result.put("url", session.url); result.put("transportType", mount.getTransportType()); result.put("toolCount", session.callbacks.length); result.put("toolNames", collectToolNames(session.callbacks, mount, reservedToolNames(mount.getMountCode()))); return result; } catch (Exception e) { TransportTarget target = resolveTransportTarget(mount); LinkedHashMap result = new LinkedHashMap(); result.put("ok", false); result.put("message", safeMessage(e)); result.put("latencyMs", System.currentTimeMillis() - start); result.put("url", target.url); result.put("transportType", mount.getTransportType()); result.put("toolCount", 0); result.put("toolNames", new ArrayList()); return result; } finally { if (session != null) { session.close(); } } } public void evictCache() { resetClientRegistry(); } private ClientRegistry ensureClientRegistry() { ClientRegistry current = clientRegistry; if (current != null) { return current; } synchronized (clientMonitor) { current = clientRegistry; if (current != null) { return current; } current = buildClientRegistry(); clientRegistry = current; return current; } } private ClientRegistry buildClientRegistry() { List mounts = loadEnabledMounts(); List sessions = new ArrayList(); LinkedHashMap toolMap = new LinkedHashMap(); List> toolList = new ArrayList>(); for (AiMcpMount mount : mounts) { if (mount == null) { continue; } MountSession session = null; try { session = openSession(mount); sessions.add(session); for (ToolCallback callback : session.callbacks) { MountedTool mountedTool = buildMountedTool(mount, callback, toolMap.keySet()); toolMap.put(mountedTool.toolName, mountedTool); toolList.add(toToolDescriptor(mountedTool)); } log.info("MCP mount loaded, mountCode={}, transport={}, url={}, tools={}", mount.getMountCode(), mount.getTransportType(), session.url, session.callbacks.length); } catch (Exception e) { log.warn("Failed to load MCP mount, mountCode={}, transport={}, url={}", mount.getMountCode(), mount.getTransportType(), resolveUrl(mount), e); if (session != null) { session.close(); } } } toolList.sort(new Comparator>() { @Override public int compare(Map left, Map right) { return String.valueOf(left.get("name")).compareTo(String.valueOf(right.get("name"))); } }); return new ClientRegistry(sessions, toolMap, toolList); } private List loadEnabledMounts() { try { List mounts = aiMcpMountService.listEnabledOrdered(); if (mounts == null || mounts.isEmpty()) { aiMcpMountService.initDefaultsIfMissing(); mounts = aiMcpMountService.listEnabledOrdered(); } return mounts == null ? Collections.emptyList() : mounts; } catch (Exception e) { log.warn("Failed to query MCP mount configuration", e); return Collections.emptyList(); } } private MountedTool buildMountedTool(AiMcpMount mount, ToolCallback callback, java.util.Set usedNames) { ToolDefinition definition = callback == null ? null : callback.getToolDefinition(); if (definition == null || isBlank(definition.name())) { throw new IllegalArgumentException("invalid tool definition"); } String originalName = definition.name(); String preferredName = mount.getMountCode() + "_" + originalName; String finalName = ensureUniqueToolName(preferredName, mount, originalName, usedNames); ToolCallback effectiveCallback = originalName.equals(finalName) ? callback : new MountedToolCallback(finalName, callback); return new MountedTool(mount, originalName, finalName, effectiveCallback); } private String ensureUniqueToolName(String preferredName, AiMcpMount mount, String originalName, java.util.Set usedNames) { if (!usedNames.contains(preferredName)) { return preferredName; } String fallbackBase = mount.getMountCode() + "_" + originalName; if (!usedNames.contains(fallbackBase)) { log.warn("Duplicate MCP tool name detected, fallback rename applied, mountCode={}, originalName={}, finalName={}", mount.getMountCode(), originalName, fallbackBase); return fallbackBase; } int index = 2; String candidate = fallbackBase + "_" + index; while (usedNames.contains(candidate)) { index++; candidate = fallbackBase + "_" + index; } log.warn("Duplicate MCP tool name detected, numbered rename applied, mountCode={}, originalName={}, finalName={}", mount.getMountCode(), originalName, candidate); return candidate; } private Map toToolDescriptor(MountedTool mountedTool) { ToolDefinition definition = mountedTool.callback.getToolDefinition(); Map item = new LinkedHashMap(); item.put("name", definition.name()); item.put("originalName", mountedTool.originalName); item.put("mountCode", mountedTool.mount.getMountCode()); item.put("mountName", mountedTool.mount.getName()); item.put("transportType", mountedTool.mount.getTransportType()); item.put("description", definition.description()); item.put("inputSchema", parseSchema(definition.inputSchema())); return item; } private List collectToolNames(ToolCallback[] callbacks, AiMcpMount mount, java.util.Set reservedNames) { List names = new ArrayList(); java.util.Set used = new LinkedHashSet(); if (reservedNames != null && !reservedNames.isEmpty()) { used.addAll(reservedNames); } if (callbacks == null) { return names; } for (ToolCallback callback : callbacks) { ToolDefinition definition = callback == null ? null : callback.getToolDefinition(); if (definition == null || isBlank(definition.name())) { continue; } String preferred = mount.getMountCode() + "_" + definition.name(); String finalName = ensureUniqueToolName(preferred, mount, definition.name(), used); used.add(finalName); names.add(finalName); } return names; } private java.util.Set reservedToolNames(String mountCode) { LinkedHashSet reserved = new LinkedHashSet(); ClientRegistry current = clientRegistry; if (current == null || current.toolMap == null || current.toolMap.isEmpty()) { return reserved; } for (MountedTool mountedTool : current.toolMap.values()) { if (mountedTool == null || mountedTool.mount == null) { continue; } if (mountCode != null && mountCode.equals(mountedTool.mount.getMountCode())) { continue; } reserved.add(mountedTool.toolName); } return reserved; } private MountSession openSession(AiMcpMount mount) { Duration timeout = resolveTimeout(mount); TransportTarget target = resolveTransportTarget(mount); String baseUrl = target.baseUrl; String endpoint = target.endpoint; AiMcpTransportType transportType = AiMcpTransportType.ofCode(mount.getTransportType()); McpSyncClient syncClient; if (transportType == AiMcpTransportType.STREAMABLE_HTTP) { HttpClientStreamableHttpTransport transport = HttpClientStreamableHttpTransport.builder(baseUrl) .endpoint(endpoint) .connectTimeout(timeout) .build(); syncClient = McpClient.sync(transport) .clientInfo(CLIENT_INFO) .requestTimeout(timeout) .initializationTimeout(timeout) .build(); } else { HttpClientSseClientTransport transport = HttpClientSseClientTransport.builder(baseUrl) .sseEndpoint(endpoint) .connectTimeout(timeout) .build(); syncClient = McpClient.sync(transport) .clientInfo(CLIENT_INFO) .requestTimeout(timeout) .initializationTimeout(timeout) .build(); } syncClient.initialize(); SyncMcpToolCallbackProvider callbackProvider = new SyncMcpToolCallbackProvider(syncClient); ToolCallback[] callbacks = callbackProvider.getToolCallbacks(); return new MountSession(mount, syncClient, target.url, callbacks == null ? new ToolCallback[0] : callbacks); } private Duration resolveTimeout(AiMcpMount mount) { if (mount != null && mount.getRequestTimeoutMs() != null && mount.getRequestTimeoutMs() > 0) { return Duration.ofMillis(mount.getRequestTimeoutMs()); } return defaultRequestTimeout == null ? Duration.ofSeconds(20) : defaultRequestTimeout; } private String resolveUrl(AiMcpMount mount) { return trim(mount == null ? null : mount.getUrl()); } private TransportTarget resolveTransportTarget(AiMcpMount mount) { String rawUrl = resolveUrl(mount); if (isBlank(rawUrl)) { throw new IllegalArgumentException("missing url"); } try { URI finalUri = URI.create(rawUrl); String authority = finalUri.getRawAuthority(); String scheme = finalUri.getScheme(); if (isBlank(scheme) || isBlank(authority)) { throw new IllegalArgumentException("invalid MCP url"); } String baseUrl = scheme + "://" + authority; String endpoint = finalUri.getRawPath(); if (isBlank(endpoint)) { throw new IllegalArgumentException("missing MCP path"); } if (finalUri.getRawQuery() != null && !finalUri.getRawQuery().isEmpty()) { endpoint = endpoint + "?" + finalUri.getRawQuery(); } return new TransportTarget(rawUrl, baseUrl, endpoint); } catch (IllegalArgumentException e) { throw e; } catch (Exception e) { throw new IllegalArgumentException("invalid url: " + rawUrl, e); } } private Object parseToolResult(String rawResult) { if (rawResult == null || rawResult.trim().isEmpty()) { return rawResult; } try { return JSON.parse(rawResult); } catch (Exception ignore) { return rawResult; } } @SuppressWarnings("unchecked") private Map parseSchema(String inputSchema) { if (inputSchema == null || inputSchema.trim().isEmpty()) { return Collections.emptyMap(); } try { Object parsed = JSON.parse(inputSchema); if (parsed instanceof Map) { return new LinkedHashMap((Map) parsed); } } catch (Exception e) { log.warn("Failed to parse MCP tool schema: {}", inputSchema, e); } Map fallback = new LinkedHashMap(); fallback.put("type", "object"); return fallback; } private String normalizePath(String path) { if (isBlank(path)) { return "/"; } String value = path.trim(); if (!value.startsWith("/")) { value = "/" + value; } return value; } private String safeMessage(Throwable throwable) { if (throwable == null) { return "unknown error"; } if (!isBlank(throwable.getMessage())) { return throwable.getMessage(); } return throwable.getClass().getSimpleName(); } private String trim(String text) { return text == null ? null : text.trim(); } private boolean isBlank(String text) { return text == null || text.trim().isEmpty(); } private void resetClientRegistry() { synchronized (clientMonitor) { ClientRegistry current = clientRegistry; clientRegistry = null; if (current != null) { current.close(); } } } @PreDestroy public void destroy() { resetClientRegistry(); } private static final class ClientRegistry implements AutoCloseable { private final List sessions; private final Map toolMap; private final List> toolList; private ClientRegistry(List sessions, Map toolMap, List> toolList) { this.sessions = sessions; this.toolMap = toolMap; this.toolList = toolList; } @Override public void close() { if (sessions == null) { return; } for (MountSession session : sessions) { if (session != null) { session.close(); } } } } private static final class MountSession implements AutoCloseable { private final AiMcpMount mount; private final McpSyncClient syncClient; private final String url; private final ToolCallback[] callbacks; private MountSession(AiMcpMount mount, McpSyncClient syncClient, String url, ToolCallback[] callbacks) { this.mount = mount; this.syncClient = syncClient; this.url = url; this.callbacks = callbacks; } @Override public void close() { try { syncClient.closeGracefully(); } catch (Exception e) { log.debug("Close MCP client gracefully failed, mountCode={}", mount == null ? null : mount.getMountCode(), e); } try { syncClient.close(); } catch (Exception e) { log.debug("Close MCP client failed, mountCode={}", mount == null ? null : mount.getMountCode(), e); } } } private static final class MountedTool { private final AiMcpMount mount; private final String originalName; private final String toolName; private final ToolCallback callback; private MountedTool(AiMcpMount mount, String originalName, String toolName, ToolCallback callback) { this.mount = mount; this.originalName = originalName; this.toolName = toolName; this.callback = callback; } } private static final class MountedToolCallback implements ToolCallback { private final ToolCallback delegate; private final ToolDefinition definition; private MountedToolCallback(String name, ToolCallback delegate) { this.delegate = delegate; ToolDefinition source = delegate.getToolDefinition(); this.definition = DefaultToolDefinition.builder() .name(name) .description(source == null ? "" : source.description()) .inputSchema(source == null ? "{}" : source.inputSchema()) .build(); } @Override public ToolDefinition getToolDefinition() { return definition; } @Override public org.springframework.ai.tool.metadata.ToolMetadata getToolMetadata() { return delegate.getToolMetadata(); } @Override public String call(String toolInput) { return delegate.call(toolInput); } @Override public String call(String toolInput, org.springframework.ai.chat.model.ToolContext toolContext) { return delegate.call(toolInput, toolContext); } } private static final class TransportTarget { private final String url; private final String baseUrl; private final String endpoint; private TransportTarget(String url, String baseUrl, String endpoint) { this.url = url; this.baseUrl = baseUrl; this.endpoint = endpoint; } } }