1
昨天 b2deb1cc93b3d2c3fb9dc795e3589e1c62329a8f
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/McpMountRuntimeFactoryImpl.java
@@ -1,20 +1,13 @@
package com.vincent.rsf.server.ai.service.impl;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.entity.AiMcpMount;
import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry;
import com.vincent.rsf.server.ai.service.MountedToolCallback;
import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
import io.modelcontextprotocol.client.transport.ServerParameters;
import io.modelcontextprotocol.client.transport.StdioClientTransport;
import io.modelcontextprotocol.json.jackson.JacksonMcpJsonMapper;
import io.modelcontextprotocol.spec.McpSchema;
import com.vincent.rsf.server.ai.service.impl.mcp.McpClientFactory;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
@@ -22,23 +15,25 @@
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@Slf4j
@Service
@RequiredArgsConstructor
public class McpMountRuntimeFactoryImpl implements McpMountRuntimeFactory {
    private final ObjectMapper objectMapper;
    private final BuiltinMcpToolRegistry builtinMcpToolRegistry;
    private final McpClientFactory mcpClientFactory;
    /**
     * 把一组 MCP 挂载记录解析成一次对话可直接使用的运行时对象。
     * 该方法统一处理内置 MCP、远程 SSE MCP 和本地 STDIO MCP,
     * 同时收集挂载成功项、失败项以及最终暴露给模型的工具回调列表。
     */
    @Override
    public McpMountRuntime create(List<AiMcpMount> mounts, Long userId) {
        List<McpSyncClient> clients = new ArrayList<>();
@@ -55,7 +50,7 @@
                    mountedNames.add(mount.getName());
                    continue;
                }
                McpSyncClient client = createClient(mount);
                McpSyncClient client = mcpClientFactory.createClient(mount);
                client.initialize();
                client.listTools();
                clients.add(client);
@@ -75,6 +70,7 @@
    }
    private List<ToolCallback> wrapMountedCallbacks(List<ToolCallback> source, String mountName) {
        /** 为每个工具回调补上挂载来源,便于后续审计、观测和前端工具轨迹展示。 */
        List<ToolCallback> mountedCallbacks = new ArrayList<>();
        for (ToolCallback callback : source) {
            if (callback == null) {
@@ -86,6 +82,7 @@
    }
    private void ensureUniqueToolNames(List<ToolCallback> callbacks) {
        /** 确保多挂载聚合后不会出现同名工具,否则模型侧无法正确分辨工具定义。 */
        LinkedHashSet<String> duplicateNames = new LinkedHashSet<>();
        LinkedHashSet<String> seenNames = new LinkedHashSet<>();
        for (ToolCallback callback : callbacks) {
@@ -105,78 +102,6 @@
        }
    }
    private McpSyncClient createClient(AiMcpMount mount) {
        Duration timeout = Duration.ofMillis(mount.getRequestTimeoutMs() == null
                ? AiDefaults.DEFAULT_TIMEOUT_MS
                : mount.getRequestTimeoutMs());
        JacksonMcpJsonMapper jsonMapper = new JacksonMcpJsonMapper(objectMapper);
        if (AiDefaults.MCP_TRANSPORT_STDIO.equals(mount.getTransportType())) {
            ServerParameters.Builder parametersBuilder = ServerParameters.builder(mount.getCommand());
            List<String> args = readStringList(mount.getArgsJson());
            if (!args.isEmpty()) {
                parametersBuilder.args(args);
            }
            Map<String, String> env = readStringMap(mount.getEnvJson());
            if (!env.isEmpty()) {
                parametersBuilder.env(env);
            }
            StdioClientTransport transport = new StdioClientTransport(parametersBuilder.build(), jsonMapper);
            transport.setStdErrorHandler(message -> log.warn("MCP STDIO stderr [{}]: {}", mount.getName(), message));
            return McpClient.sync(transport)
                    .requestTimeout(timeout)
                    .initializationTimeout(timeout)
                    .clientInfo(new McpSchema.Implementation("rsf-ai-client", "RSF AI Client", "1.0.0"))
                    .build();
        }
        if (!AiDefaults.MCP_TRANSPORT_SSE_HTTP.equals(mount.getTransportType())) {
            throw new CoolException("不支持的 MCP 传输类型: " + mount.getTransportType());
        }
        if (!StringUtils.hasText(mount.getServerUrl())) {
            throw new CoolException("MCP 服务地址不能为空");
        }
        HttpClientSseClientTransport.Builder transportBuilder = HttpClientSseClientTransport.builder(mount.getServerUrl())
                .jsonMapper(jsonMapper)
                .connectTimeout(timeout);
        if (StringUtils.hasText(mount.getEndpoint())) {
            transportBuilder.sseEndpoint(mount.getEndpoint());
        }
        Map<String, String> headers = readStringMap(mount.getHeadersJson());
        if (!headers.isEmpty()) {
            transportBuilder.customizeRequest(builder -> headers.forEach(builder::header));
        }
        return McpClient.sync(transportBuilder.build())
                .requestTimeout(timeout)
                .initializationTimeout(timeout)
                .clientInfo(new McpSchema.Implementation("rsf-ai-client", "RSF AI Client", "1.0.0"))
                .build();
    }
    private List<String> readStringList(String json) {
        if (!StringUtils.hasText(json)) {
            return Collections.emptyList();
        }
        try {
            return objectMapper.readValue(json, new TypeReference<List<String>>() {
            });
        } catch (Exception e) {
            throw new CoolException("解析 MCP 列表配置失败: " + e.getMessage());
        }
    }
    private Map<String, String> readStringMap(String json) {
        if (!StringUtils.hasText(json)) {
            return Collections.emptyMap();
        }
        try {
            Map<String, String> result = objectMapper.readValue(json, new TypeReference<LinkedHashMap<String, String>>() {
            });
            return result == null ? Collections.emptyMap() : result;
        } catch (Exception e) {
            throw new CoolException("解析 MCP Map 配置失败: " + e.getMessage());
        }
    }
    private static class DefaultMcpMountRuntime implements McpMountRuntime {
        private final List<McpSyncClient> clients;
@@ -184,6 +109,7 @@
        private final List<String> mountedNames;
        private final List<String> errors;
        /** 运行时对象本身只做数据封装和资源释放,不引入额外业务逻辑。 */
        private DefaultMcpMountRuntime(List<McpSyncClient> clients, ToolCallback[] callbacks, List<String> mountedNames, List<String> errors) {
            this.clients = clients;
            this.callbacks = callbacks;
@@ -213,6 +139,7 @@
        @Override
        public void close() {
            /** 统一关闭本次运行时里创建的外部 MCP Client,避免连接泄漏。 */
            for (McpSyncClient client : clients) {
                try {
                    client.close();
@@ -228,6 +155,7 @@
        private final ToolCallback delegate;
        private final String mountName;
        /** 装饰器仅补充挂载来源,不改变底层工具定义和调用行为。 */
        private MountedToolCallbackImpl(ToolCallback delegate, String mountName) {
            this.delegate = delegate;
            this.mountName = mountName;