zhou zhou
7 小时以前 d5884d0974d17d96225a5d80e432de33a5ee6552
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/McpMountRuntimeFactoryImpl.java
@@ -6,6 +6,7 @@
import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.entity.AiMcpMount;
import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry;
import com.vincent.rsf.server.ai.service.MountedToolCallback;
import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.McpSyncClient;
@@ -47,7 +48,10 @@
        for (AiMcpMount mount : mounts) {
            try {
                if (AiDefaults.MCP_TRANSPORT_BUILTIN.equals(mount.getTransportType())) {
                    callbacks.addAll(builtinMcpToolRegistry.createToolCallbacks(mount, userId));
                    callbacks.addAll(wrapMountedCallbacks(
                            builtinMcpToolRegistry.createToolCallbacks(mount, userId),
                            mount.getName()
                    ));
                    mountedNames.add(mount.getName());
                    continue;
                }
@@ -55,6 +59,10 @@
                client.initialize();
                client.listTools();
                clients.add(client);
                callbacks.addAll(wrapMountedCallbacks(
                        Arrays.asList(SyncMcpToolCallbackProvider.builder().mcpClients(List.of(client)).build().getToolCallbacks()),
                        mount.getName()
                ));
                mountedNames.add(mount.getName());
            } catch (Exception e) {
                String message = mount.getName() + " 挂载失败: " + e.getMessage();
@@ -62,13 +70,19 @@
                log.warn(message, e);
            }
        }
        if (!clients.isEmpty()) {
            callbacks.addAll(Arrays.asList(
                    SyncMcpToolCallbackProvider.builder().mcpClients(clients).build().getToolCallbacks()
            ));
        }
        ensureUniqueToolNames(callbacks);
        return new DefaultMcpMountRuntime(clients, callbacks.toArray(new ToolCallback[0]), mountedNames, errors);
    }
    private List<ToolCallback> wrapMountedCallbacks(List<ToolCallback> source, String mountName) {
        List<ToolCallback> mountedCallbacks = new ArrayList<>();
        for (ToolCallback callback : source) {
            if (callback == null) {
                continue;
            }
            mountedCallbacks.add(new MountedToolCallbackImpl(callback, mountName));
        }
        return mountedCallbacks;
    }
    private void ensureUniqueToolNames(List<ToolCallback> callbacks) {
@@ -208,4 +222,40 @@
            }
        }
    }
    private static class MountedToolCallbackImpl implements MountedToolCallback {
        private final ToolCallback delegate;
        private final String mountName;
        private MountedToolCallbackImpl(ToolCallback delegate, String mountName) {
            this.delegate = delegate;
            this.mountName = mountName;
        }
        @Override
        public String getMountName() {
            return mountName;
        }
        @Override
        public org.springframework.ai.tool.definition.ToolDefinition getToolDefinition() {
            return delegate.getToolDefinition();
        }
        @Override
        public org.springframework.ai.tool.metadata.ToolMetadata getToolMetadata() {
            return delegate.getToolMetadata();
        }
        @Override
        public String call(String toolInput) {
            return delegate.call(toolInput);
        }
        @Override
        public String call(String toolInput, org.springframework.ai.chat.model.ToolContext toolContext) {
            return delegate.call(toolInput, toolContext);
        }
    }
}