#
Junjie
3 天以前 1b8a4677f362d234d834120deac4880d7ae89a50
src/main/java/com/zy/ai/mcp/service/SpringAiMcpToolManager.java
@@ -2,73 +2,52 @@
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.zy.ai.entity.AiMcpMount;
import com.zy.ai.enums.AiMcpTransportType;
import com.zy.ai.service.AiMcpMountService;
import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport;
import io.modelcontextprotocol.spec.McpSchema;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import jakarta.annotation.PreDestroy;
import java.net.URI;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
@Slf4j
@Service
@RequiredArgsConstructor
public class SpringAiMcpToolManager {
    private static final McpSchema.Implementation CLIENT_INFO =
            new McpSchema.Implementation("wcs-ai-assistant", "1.0.0");
    private final Object clientMonitor = new Object();
    @Value("${server.port:9090}")
    private Integer serverPort;
    @Value("${server.servlet.context-path:}")
    private String contextPath;
    @Value("${spring.ai.mcp.server.sse-endpoint:/ai/mcp/sse}")
    private String sseEndpoint;
    private final AiMcpMountService aiMcpMountService;
    @Value("${spring.ai.mcp.server.request-timeout:20s}")
    private Duration requestTimeout;
    private Duration defaultRequestTimeout;
    @Value("${app.ai.mcp.client.base-url:}")
    private String configuredBaseUrl;
    private volatile ClientSession clientSession;
    private volatile ClientRegistry clientRegistry;
    public List<Map<String, Object>> listTools() {
        List<Map<String, Object>> tools = new ArrayList<Map<String, Object>>();
        for (ToolCallback callback : getToolCallbacks()) {
            if (callback == null || callback.getToolDefinition() == null) {
                continue;
            }
            ToolDefinition definition = callback.getToolDefinition();
            Map<String, Object> item = new LinkedHashMap<String, Object>();
            item.put("name", definition.name());
            item.put("description", definition.description());
            item.put("inputSchema", parseSchema(definition.inputSchema()));
            tools.add(item);
        }
        tools.sort(new Comparator<Map<String, Object>>() {
            @Override
            public int compare(Map<String, Object> left, Map<String, Object> right) {
                return String.valueOf(left.get("name")).compareTo(String.valueOf(right.get("name")));
            }
        });
        return tools;
        return ensureClientRegistry().toolList;
    }
    public List<Object> buildOpenAiTools() {
@@ -97,40 +76,301 @@
    }
    public Object callTool(String toolName, JSONObject arguments) {
        if (toolName == null || toolName.trim().isEmpty()) {
        if (isBlank(toolName)) {
            throw new IllegalArgumentException("missing tool name");
        }
        ToolCallback callback = findCallback(toolName);
        if (callback == null) {
        MountedTool mountedTool = ensureClientRegistry().toolMap.get(toolName);
        if (mountedTool == null) {
            throw new IllegalArgumentException("tool not found: " + toolName);
        }
        String rawResult = callback.call(arguments == null ? "{}" : arguments.toJSONString());
        return parseToolResult(rawResult);
        try {
            String rawResult = mountedTool.callback.call(arguments == null ? "{}" : arguments.toJSONString());
            return parseToolResult(rawResult);
        } catch (Exception e) {
            evictCache();
            throw e;
        }
    }
    private ToolCallback findCallback(String toolName) {
        for (ToolCallback callback : getToolCallbacks()) {
            if (callback == null || callback.getToolDefinition() == null) {
                continue;
            }
            if (toolName.equals(callback.getToolDefinition().name())) {
                return callback;
    public Map<String, Object> testMount(AiMcpMount mount) {
        if (mount == null) {
            throw new IllegalArgumentException("参数不能为空");
        }
        MountSession session = null;
        long start = System.currentTimeMillis();
        try {
            session = openSession(mount);
            LinkedHashMap<String, Object> result = new LinkedHashMap<String, Object>();
            result.put("ok", true);
            result.put("message", "连接成功,已发现 " + session.callbacks.length + " 个工具");
            result.put("latencyMs", System.currentTimeMillis() - start);
            result.put("url", session.url);
            result.put("transportType", mount.getTransportType());
            result.put("toolCount", session.callbacks.length);
            result.put("toolNames", collectToolNames(session.callbacks, mount, reservedToolNames(mount.getMountCode())));
            return result;
        } catch (Exception e) {
            TransportTarget target = resolveTransportTarget(mount);
            LinkedHashMap<String, Object> result = new LinkedHashMap<String, Object>();
            result.put("ok", false);
            result.put("message", safeMessage(e));
            result.put("latencyMs", System.currentTimeMillis() - start);
            result.put("url", target.url);
            result.put("transportType", mount.getTransportType());
            result.put("toolCount", 0);
            result.put("toolNames", new ArrayList<String>());
            return result;
        } finally {
            if (session != null) {
                session.close();
            }
        }
        return null;
    }
    private ToolCallback[] getToolCallbacks() {
    public void evictCache() {
        resetClientRegistry();
    }
    private ClientRegistry ensureClientRegistry() {
        ClientRegistry current = clientRegistry;
        if (current != null) {
            return current;
        }
        synchronized (clientMonitor) {
            current = clientRegistry;
            if (current != null) {
                return current;
            }
            current = buildClientRegistry();
            clientRegistry = current;
            return current;
        }
    }
    private ClientRegistry buildClientRegistry() {
        List<AiMcpMount> mounts = loadEnabledMounts();
        List<MountSession> sessions = new ArrayList<MountSession>();
        LinkedHashMap<String, MountedTool> toolMap = new LinkedHashMap<String, MountedTool>();
        List<Map<String, Object>> toolList = new ArrayList<Map<String, Object>>();
        for (AiMcpMount mount : mounts) {
            if (mount == null) {
                continue;
            }
            MountSession session = null;
            try {
                session = openSession(mount);
                sessions.add(session);
                for (ToolCallback callback : session.callbacks) {
                    MountedTool mountedTool = buildMountedTool(mount, callback, toolMap.keySet());
                    toolMap.put(mountedTool.toolName, mountedTool);
                    toolList.add(toToolDescriptor(mountedTool));
                }
                log.info("MCP mount loaded, mountCode={}, transport={}, url={}, tools={}",
                        mount.getMountCode(), mount.getTransportType(), session.url, session.callbacks.length);
            } catch (Exception e) {
                log.warn("Failed to load MCP mount, mountCode={}, transport={}, url={}",
                        mount.getMountCode(), mount.getTransportType(), resolveUrl(mount), e);
                if (session != null) {
                    session.close();
                }
            }
        }
        toolList.sort(new Comparator<Map<String, Object>>() {
            @Override
            public int compare(Map<String, Object> left, Map<String, Object> right) {
                return String.valueOf(left.get("name")).compareTo(String.valueOf(right.get("name")));
            }
        });
        return new ClientRegistry(sessions, toolMap, toolList);
    }
    private List<AiMcpMount> loadEnabledMounts() {
        try {
            ToolCallback[] callbacks = ensureToolCallbackProvider().getToolCallbacks();
            return callbacks == null ? new ToolCallback[0] : callbacks;
            List<AiMcpMount> mounts = aiMcpMountService.listEnabledOrdered();
            if (mounts == null || mounts.isEmpty()) {
                aiMcpMountService.initDefaultsIfMissing();
                mounts = aiMcpMountService.listEnabledOrdered();
            }
            return mounts == null ? Collections.<AiMcpMount>emptyList() : mounts;
        } catch (Exception e) {
            log.warn("Failed to load MCP tools through SSE client, baseUrl={}, sseEndpoint={}",
                    resolveBaseUrl(), resolveClientSseEndpoint(), e);
            resetClientSession();
            return new ToolCallback[0];
            log.warn("Failed to query MCP mount configuration", e);
            return Collections.emptyList();
        }
    }
    private MountedTool buildMountedTool(AiMcpMount mount, ToolCallback callback, java.util.Set<String> usedNames) {
        ToolDefinition definition = callback == null ? null : callback.getToolDefinition();
        if (definition == null || isBlank(definition.name())) {
            throw new IllegalArgumentException("invalid tool definition");
        }
        String originalName = definition.name();
        String preferredName = mount.getMountCode() + "_" + originalName;
        String finalName = ensureUniqueToolName(preferredName, mount, originalName, usedNames);
        ToolCallback effectiveCallback = originalName.equals(finalName)
                ? callback
                : new MountedToolCallback(finalName, callback);
        return new MountedTool(mount, originalName, finalName, effectiveCallback);
    }
    private String ensureUniqueToolName(String preferredName,
                                        AiMcpMount mount,
                                        String originalName,
                                        java.util.Set<String> usedNames) {
        if (!usedNames.contains(preferredName)) {
            return preferredName;
        }
        String fallbackBase = mount.getMountCode() + "_" + originalName;
        if (!usedNames.contains(fallbackBase)) {
            log.warn("Duplicate MCP tool name detected, fallback rename applied, mountCode={}, originalName={}, finalName={}",
                    mount.getMountCode(), originalName, fallbackBase);
            return fallbackBase;
        }
        int index = 2;
        String candidate = fallbackBase + "_" + index;
        while (usedNames.contains(candidate)) {
            index++;
            candidate = fallbackBase + "_" + index;
        }
        log.warn("Duplicate MCP tool name detected, numbered rename applied, mountCode={}, originalName={}, finalName={}",
                mount.getMountCode(), originalName, candidate);
        return candidate;
    }
    private Map<String, Object> toToolDescriptor(MountedTool mountedTool) {
        ToolDefinition definition = mountedTool.callback.getToolDefinition();
        Map<String, Object> item = new LinkedHashMap<String, Object>();
        item.put("name", definition.name());
        item.put("originalName", mountedTool.originalName);
        item.put("mountCode", mountedTool.mount.getMountCode());
        item.put("mountName", mountedTool.mount.getName());
        item.put("transportType", mountedTool.mount.getTransportType());
        item.put("description", definition.description());
        item.put("inputSchema", parseSchema(definition.inputSchema()));
        return item;
    }
    private List<String> collectToolNames(ToolCallback[] callbacks, AiMcpMount mount, java.util.Set<String> reservedNames) {
        List<String> names = new ArrayList<String>();
        java.util.Set<String> used = new LinkedHashSet<String>();
        if (reservedNames != null && !reservedNames.isEmpty()) {
            used.addAll(reservedNames);
        }
        if (callbacks == null) {
            return names;
        }
        for (ToolCallback callback : callbacks) {
            ToolDefinition definition = callback == null ? null : callback.getToolDefinition();
            if (definition == null || isBlank(definition.name())) {
                continue;
            }
            String preferred = mount.getMountCode() + "_" + definition.name();
            String finalName = ensureUniqueToolName(preferred, mount, definition.name(), used);
            used.add(finalName);
            names.add(finalName);
        }
        return names;
    }
    private java.util.Set<String> reservedToolNames(String mountCode) {
        LinkedHashSet<String> reserved = new LinkedHashSet<String>();
        ClientRegistry current = clientRegistry;
        if (current == null || current.toolMap == null || current.toolMap.isEmpty()) {
            return reserved;
        }
        for (MountedTool mountedTool : current.toolMap.values()) {
            if (mountedTool == null || mountedTool.mount == null) {
                continue;
            }
            if (mountCode != null && mountCode.equals(mountedTool.mount.getMountCode())) {
                continue;
            }
            reserved.add(mountedTool.toolName);
        }
        return reserved;
    }
    private MountSession openSession(AiMcpMount mount) {
        Duration timeout = resolveTimeout(mount);
        TransportTarget target = resolveTransportTarget(mount);
        String baseUrl = target.baseUrl;
        String endpoint = target.endpoint;
        AiMcpTransportType transportType = AiMcpTransportType.ofCode(mount.getTransportType());
        McpSyncClient syncClient;
        if (transportType == AiMcpTransportType.STREAMABLE_HTTP) {
            HttpClientStreamableHttpTransport transport = HttpClientStreamableHttpTransport.builder(baseUrl)
                    .endpoint(endpoint)
                    .connectTimeout(timeout)
                    .build();
            syncClient = McpClient.sync(transport)
                    .clientInfo(CLIENT_INFO)
                    .requestTimeout(timeout)
                    .initializationTimeout(timeout)
                    .build();
        } else {
            HttpClientSseClientTransport transport = HttpClientSseClientTransport.builder(baseUrl)
                    .sseEndpoint(endpoint)
                    .connectTimeout(timeout)
                    .build();
            syncClient = McpClient.sync(transport)
                    .clientInfo(CLIENT_INFO)
                    .requestTimeout(timeout)
                    .initializationTimeout(timeout)
                    .build();
        }
        syncClient.initialize();
        SyncMcpToolCallbackProvider callbackProvider = new SyncMcpToolCallbackProvider(syncClient);
        ToolCallback[] callbacks = callbackProvider.getToolCallbacks();
        return new MountSession(mount, syncClient, target.url, callbacks == null ? new ToolCallback[0] : callbacks);
    }
    private Duration resolveTimeout(AiMcpMount mount) {
        if (mount != null && mount.getRequestTimeoutMs() != null && mount.getRequestTimeoutMs() > 0) {
            return Duration.ofMillis(mount.getRequestTimeoutMs());
        }
        return defaultRequestTimeout == null ? Duration.ofSeconds(20) : defaultRequestTimeout;
    }
    private String resolveUrl(AiMcpMount mount) {
        return trim(mount == null ? null : mount.getUrl());
    }
    private TransportTarget resolveTransportTarget(AiMcpMount mount) {
        String rawUrl = resolveUrl(mount);
        if (isBlank(rawUrl)) {
            throw new IllegalArgumentException("missing url");
        }
        try {
            URI finalUri = URI.create(rawUrl);
            String authority = finalUri.getRawAuthority();
            String scheme = finalUri.getScheme();
            if (isBlank(scheme) || isBlank(authority)) {
                throw new IllegalArgumentException("invalid MCP url");
            }
            String baseUrl = scheme + "://" + authority;
            String endpoint = finalUri.getRawPath();
            if (isBlank(endpoint)) {
                throw new IllegalArgumentException("missing MCP path");
            }
            if (finalUri.getRawQuery() != null && !finalUri.getRawQuery().isEmpty()) {
                endpoint = endpoint + "?" + finalUri.getRawQuery();
            }
            return new TransportTarget(rawUrl, baseUrl, endpoint);
        } catch (IllegalArgumentException e) {
            throw e;
        } catch (Exception e) {
            throw new IllegalArgumentException("invalid url: " + rawUrl, e);
        }
    }
@@ -163,89 +403,8 @@
        return fallback;
    }
    private ToolCallbackProvider ensureToolCallbackProvider() {
        return ensureClientSession().toolCallbackProvider;
    }
    /**
     * Legacy path: returns the single SSE MCP client session, creating and initializing it on
     * first use.
     *
     * <p>It connects to {@code resolveBaseUrl()} + {@code resolveClientSseEndpoint()}. By default
     * that is 127.0.0.1:server.port plus the spring.ai.mcp.server SSE endpoint.
     *
     * <p>Reads the volatile {@code clientSession} without locking. If it is null, re-checks
     * under {@code clientMonitor}, so concurrent callers create at most one session between
     * resets.
     *
     * <p>NOTE(review): the mount-based {@code ensureClientRegistry} path does not use this;
     * it looks like pre-registry code. Confirm before removing.
     */
    private ClientSession ensureClientSession() {
        ClientSession current = clientSession;
        // Fast path: already initialized, no locking needed.
        if (current != null) {
            return current;
        }
        synchronized (clientMonitor) {
            // Re-check under the lock: another thread may have finished initialization meanwhile.
            current = clientSession;
            if (current != null) {
                return current;
            }
            String baseUrl = resolveBaseUrl();
            String clientSseEndpoint = resolveClientSseEndpoint();
            HttpClientSseClientTransport transport = HttpClientSseClientTransport.builder(baseUrl)
                    .sseEndpoint(clientSseEndpoint)
                    .connectTimeout(requestTimeout)
                    .build();
            McpSyncClient syncClient = McpClient.sync(transport)
                    .clientInfo(CLIENT_INFO)
                    .requestTimeout(requestTimeout)
                    .initializationTimeout(requestTimeout)
                    .build();
            // Handshake bounded by requestTimeout. If it throws, clientSession stays null so a
            // later call retries.
            // NOTE(review): the syncClient built above is not closed on that failure path.
            syncClient.initialize();
            SyncMcpToolCallbackProvider callbackProvider = new SyncMcpToolCallbackProvider(syncClient);
            current = new ClientSession(syncClient, callbackProvider, baseUrl);
            // Publish only after the session is fully initialized.
            clientSession = current;
            log.info("Spring AI MCP SSE client initialized, baseUrl={}, sseEndpoint={}, tools={}",
                    baseUrl, clientSseEndpoint, current.toolCallbackProvider.getToolCallbacks().length);
            return current;
        }
    }
    private void resetClientSession() {
        synchronized (clientMonitor) {
            ClientSession current = clientSession;
            clientSession = null;
            if (current != null) {
                current.close();
            }
        }
    }
    private String resolveBaseUrl() {
        if (configuredBaseUrl != null && !configuredBaseUrl.trim().isEmpty()) {
            return trimTrailingSlash(configuredBaseUrl.trim());
        }
        StringBuilder url = new StringBuilder("http://127.0.0.1:");
        url.append(serverPort == null ? 9090 : serverPort);
        return trimTrailingSlash(url.toString());
    }
    private String resolveClientSseEndpoint() {
        String endpoint = normalizePath(sseEndpoint);
        if (configuredBaseUrl != null && !configuredBaseUrl.trim().isEmpty()) {
            return endpoint;
        }
        String context = normalizeContextPath(contextPath);
        if (context.isEmpty()) {
            return endpoint;
        }
        return context + endpoint;
    }
    private String normalizeContextPath(String path) {
        if (path == null || path.trim().isEmpty() || "/".equals(path.trim())) {
            return "";
        }
        String value = path.trim();
        if (!value.startsWith("/")) {
            value = "/" + value;
        }
        return trimTrailingSlash(value);
    }
    private String normalizePath(String path) {
        if (path == null || path.trim().isEmpty()) {
        if (isBlank(path)) {
            return "/";
        }
        String value = path.trim();
@@ -255,28 +414,81 @@
        return value;
    }
    private String trimTrailingSlash(String value) {
        if (value == null || value.isEmpty()) {
            return "";
    private String safeMessage(Throwable throwable) {
        if (throwable == null) {
            return "unknown error";
        }
        return value.endsWith("/") && value.length() > 1 ? value.substring(0, value.length() - 1) : value;
        if (!isBlank(throwable.getMessage())) {
            return throwable.getMessage();
        }
        return throwable.getClass().getSimpleName();
    }
    private String trim(String text) {
        return text == null ? null : text.trim();
    }
    private boolean isBlank(String text) {
        return text == null || text.trim().isEmpty();
    }
    private void resetClientRegistry() {
        synchronized (clientMonitor) {
            ClientRegistry current = clientRegistry;
            clientRegistry = null;
            if (current != null) {
                current.close();
            }
        }
    }
    @PreDestroy
    public void destroy() {
        resetClientSession();
        resetClientRegistry();
    }
    private static final class ClientSession implements AutoCloseable {
    private static final class ClientRegistry implements AutoCloseable {
        private final List<MountSession> sessions;
        private final Map<String, MountedTool> toolMap;
        private final List<Map<String, Object>> toolList;
        private ClientRegistry(List<MountSession> sessions,
                               Map<String, MountedTool> toolMap,
                               List<Map<String, Object>> toolList) {
            this.sessions = sessions;
            this.toolMap = toolMap;
            this.toolList = toolList;
        }
        @Override
        public void close() {
            if (sessions == null) {
                return;
            }
            for (MountSession session : sessions) {
                if (session != null) {
                    session.close();
                }
            }
        }
    }
    private static final class MountSession implements AutoCloseable {
        private final AiMcpMount mount;
        private final McpSyncClient syncClient;
        private final ToolCallbackProvider toolCallbackProvider;
        private final String baseUrl;
        private final String url;
        private final ToolCallback[] callbacks;
        private ClientSession(McpSyncClient syncClient, ToolCallbackProvider toolCallbackProvider, String baseUrl) {
        private MountSession(AiMcpMount mount,
                             McpSyncClient syncClient,
                             String url,
                             ToolCallback[] callbacks) {
            this.mount = mount;
            this.syncClient = syncClient;
            this.toolCallbackProvider = toolCallbackProvider;
            this.baseUrl = baseUrl;
            this.url = url;
            this.callbacks = callbacks;
        }
        @Override
@@ -284,13 +496,77 @@
            try {
                syncClient.closeGracefully();
            } catch (Exception e) {
                log.debug("Close MCP SSE client gracefully failed, baseUrl={}", baseUrl, e);
                log.debug("Close MCP client gracefully failed, mountCode={}", mount == null ? null : mount.getMountCode(), e);
            }
            try {
                syncClient.close();
            } catch (Exception e) {
                log.debug("Close MCP SSE client failed, baseUrl={}", baseUrl, e);
                log.debug("Close MCP client failed, mountCode={}", mount == null ? null : mount.getMountCode(), e);
            }
        }
    }
    private static final class MountedTool {
        private final AiMcpMount mount;
        private final String originalName;
        private final String toolName;
        private final ToolCallback callback;
        private MountedTool(AiMcpMount mount, String originalName, String toolName, ToolCallback callback) {
            this.mount = mount;
            this.originalName = originalName;
            this.toolName = toolName;
            this.callback = callback;
        }
    }
    private static final class MountedToolCallback implements ToolCallback {
        private final ToolCallback delegate;
        private final ToolDefinition definition;
        private MountedToolCallback(String name, ToolCallback delegate) {
            this.delegate = delegate;
            ToolDefinition source = delegate.getToolDefinition();
            this.definition = DefaultToolDefinition.builder()
                    .name(name)
                    .description(source == null ? "" : source.description())
                    .inputSchema(source == null ? "{}" : source.inputSchema())
                    .build();
        }
        @Override
        public ToolDefinition getToolDefinition() {
            return definition;
        }
        @Override
        public org.springframework.ai.tool.metadata.ToolMetadata getToolMetadata() {
            return delegate.getToolMetadata();
        }
        @Override
        public String call(String toolInput) {
            return delegate.call(toolInput);
        }
        @Override
        public String call(String toolInput, org.springframework.ai.chat.model.ToolContext toolContext) {
            return delegate.call(toolInput, toolContext);
        }
    }
    private static final class TransportTarget {
        private final String url;
        private final String baseUrl;
        private final String endpoint;
        private TransportTarget(String url, String baseUrl, String endpoint) {
            this.url = url;
            this.baseUrl = baseUrl;
            this.endpoint = endpoint;
        }
    }
}