zhou zhou
10 小时以前 80a6d9236ade191a5de0975abe4de5a6e7e63915
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/McpMountRuntimeFactoryImpl.java
@@ -6,6 +6,7 @@
import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.entity.AiMcpMount;
import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry;
import com.vincent.rsf.server.ai.service.MountedToolCallback;
import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.McpSyncClient;
@@ -38,6 +39,11 @@
    private final ObjectMapper objectMapper;
    private final BuiltinMcpToolRegistry builtinMcpToolRegistry;
    /**
     * 把一组 MCP 挂载记录解析成一次对话可直接使用的运行时对象。
     * 该方法统一处理内置 MCP、远程 SSE MCP 和本地 STDIO MCP,
     * 同时收集挂载成功项、失败项以及最终暴露给模型的工具回调列表。
     */
    @Override
    public McpMountRuntime create(List<AiMcpMount> mounts, Long userId) {
        List<McpSyncClient> clients = new ArrayList<>();
@@ -47,7 +53,10 @@
        for (AiMcpMount mount : mounts) {
            try {
                if (AiDefaults.MCP_TRANSPORT_BUILTIN.equals(mount.getTransportType())) {
                    callbacks.addAll(builtinMcpToolRegistry.createToolCallbacks(mount, userId));
                    callbacks.addAll(wrapMountedCallbacks(
                            builtinMcpToolRegistry.createToolCallbacks(mount, userId),
                            mount.getName()
                    ));
                    mountedNames.add(mount.getName());
                    continue;
                }
@@ -55,6 +64,10 @@
                client.initialize();
                client.listTools();
                clients.add(client);
                callbacks.addAll(wrapMountedCallbacks(
                        Arrays.asList(SyncMcpToolCallbackProvider.builder().mcpClients(List.of(client)).build().getToolCallbacks()),
                        mount.getName()
                ));
                mountedNames.add(mount.getName());
            } catch (Exception e) {
                String message = mount.getName() + " 挂载失败: " + e.getMessage();
@@ -62,16 +75,24 @@
                log.warn(message, e);
            }
        }
        if (!clients.isEmpty()) {
            callbacks.addAll(Arrays.asList(
                    SyncMcpToolCallbackProvider.builder().mcpClients(clients).build().getToolCallbacks()
            ));
        }
        ensureUniqueToolNames(callbacks);
        return new DefaultMcpMountRuntime(clients, callbacks.toArray(new ToolCallback[0]), mountedNames, errors);
    }
    private List<ToolCallback> wrapMountedCallbacks(List<ToolCallback> source, String mountName) {
        /** 为每个工具回调补上挂载来源,便于后续审计、观测和前端工具轨迹展示。 */
        List<ToolCallback> mountedCallbacks = new ArrayList<>();
        for (ToolCallback callback : source) {
            if (callback == null) {
                continue;
            }
            mountedCallbacks.add(new MountedToolCallbackImpl(callback, mountName));
        }
        return mountedCallbacks;
    }
    private void ensureUniqueToolNames(List<ToolCallback> callbacks) {
        /** 确保多挂载聚合后不会出现同名工具,否则模型侧无法正确分辨工具定义。 */
        LinkedHashSet<String> duplicateNames = new LinkedHashSet<>();
        LinkedHashSet<String> seenNames = new LinkedHashSet<>();
        for (ToolCallback callback : callbacks) {
@@ -92,6 +113,10 @@
    }
    private McpSyncClient createClient(AiMcpMount mount) {
        /**
         * 按挂载配置动态创建 MCP Client。
         * 该方法只负责 transport 层初始化,不负责工具去重和错误聚合。
         */
        Duration timeout = Duration.ofMillis(mount.getRequestTimeoutMs() == null
                ? AiDefaults.DEFAULT_TIMEOUT_MS
                : mount.getRequestTimeoutMs());
@@ -139,6 +164,7 @@
    }
    private List<String> readStringList(String json) {
        /** 解析挂载表里的 JSON 数组配置,例如 STDIO args。 */
        if (!StringUtils.hasText(json)) {
            return Collections.emptyList();
        }
@@ -151,6 +177,7 @@
    }
    private Map<String, String> readStringMap(String json) {
        /** 解析挂载表里的 JSON Map 配置,例如 headers 或环境变量。 */
        if (!StringUtils.hasText(json)) {
            return Collections.emptyMap();
        }
@@ -170,6 +197,7 @@
        private final List<String> mountedNames;
        private final List<String> errors;
        /** 运行时对象本身只做数据封装和资源释放,不引入额外业务逻辑。 */
        private DefaultMcpMountRuntime(List<McpSyncClient> clients, ToolCallback[] callbacks, List<String> mountedNames, List<String> errors) {
            this.clients = clients;
            this.callbacks = callbacks;
@@ -199,6 +227,7 @@
        @Override
        public void close() {
            /** 统一关闭本次运行时里创建的外部 MCP Client,避免连接泄漏。 */
            for (McpSyncClient client : clients) {
                try {
                    client.close();
@@ -208,4 +237,41 @@
            }
        }
    }
    private static class MountedToolCallbackImpl implements MountedToolCallback {
        private final ToolCallback delegate;
        private final String mountName;
        /** 装饰器仅补充挂载来源,不改变底层工具定义和调用行为。 */
        private MountedToolCallbackImpl(ToolCallback delegate, String mountName) {
            this.delegate = delegate;
            this.mountName = mountName;
        }
        @Override
        public String getMountName() {
            return mountName;
        }
        @Override
        public org.springframework.ai.tool.definition.ToolDefinition getToolDefinition() {
            return delegate.getToolDefinition();
        }
        @Override
        public org.springframework.ai.tool.metadata.ToolMetadata getToolMetadata() {
            return delegate.getToolMetadata();
        }
        @Override
        public String call(String toolInput) {
            return delegate.call(toolInput);
        }
        @Override
        public String call(String toolInput, org.springframework.ai.chat.model.ToolContext toolContext) {
            return delegate.call(toolInput, toolContext);
        }
    }
}