| | |
| | | import com.vincent.rsf.server.ai.config.AiDefaults; |
| | | import com.vincent.rsf.server.ai.entity.AiMcpMount; |
| | | import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry; |
| | | import com.vincent.rsf.server.ai.service.MountedToolCallback; |
| | | import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory; |
| | | import io.modelcontextprotocol.client.McpClient; |
| | | import io.modelcontextprotocol.client.McpSyncClient; |
| | |
| | | for (AiMcpMount mount : mounts) { |
| | | try { |
| | | if (AiDefaults.MCP_TRANSPORT_BUILTIN.equals(mount.getTransportType())) { |
| | | callbacks.addAll(builtinMcpToolRegistry.createToolCallbacks(mount, userId)); |
| | | callbacks.addAll(wrapMountedCallbacks( |
| | | builtinMcpToolRegistry.createToolCallbacks(mount, userId), |
| | | mount.getName() |
| | | )); |
| | | mountedNames.add(mount.getName()); |
| | | continue; |
| | | } |
| | |
| | | client.initialize(); |
| | | client.listTools(); |
| | | clients.add(client); |
| | | callbacks.addAll(wrapMountedCallbacks( |
| | | Arrays.asList(SyncMcpToolCallbackProvider.builder().mcpClients(List.of(client)).build().getToolCallbacks()), |
| | | mount.getName() |
| | | )); |
| | | mountedNames.add(mount.getName()); |
| | | } catch (Exception e) { |
| | | String message = mount.getName() + " 挂载失败: " + e.getMessage(); |
| | |
| | | log.warn(message, e); |
| | | } |
| | | } |
| | | if (!clients.isEmpty()) { |
| | | callbacks.addAll(Arrays.asList( |
| | | SyncMcpToolCallbackProvider.builder().mcpClients(clients).build().getToolCallbacks() |
| | | )); |
| | | } |
| | | ensureUniqueToolNames(callbacks); |
| | | return new DefaultMcpMountRuntime(clients, callbacks.toArray(new ToolCallback[0]), mountedNames, errors); |
| | | } |
| | | |
| | | private List<ToolCallback> wrapMountedCallbacks(List<ToolCallback> source, String mountName) { |
| | | List<ToolCallback> mountedCallbacks = new ArrayList<>(); |
| | | for (ToolCallback callback : source) { |
| | | if (callback == null) { |
| | | continue; |
| | | } |
| | | mountedCallbacks.add(new MountedToolCallbackImpl(callback, mountName)); |
| | | } |
| | | return mountedCallbacks; |
| | | } |
| | | |
| | | private void ensureUniqueToolNames(List<ToolCallback> callbacks) { |
| | |
| | | } |
| | | } |
| | | } |
| | | |
| | | private static class MountedToolCallbackImpl implements MountedToolCallback { |
| | | |
| | | private final ToolCallback delegate; |
| | | private final String mountName; |
| | | |
| | | private MountedToolCallbackImpl(ToolCallback delegate, String mountName) { |
| | | this.delegate = delegate; |
| | | this.mountName = mountName; |
| | | } |
| | | |
| | | @Override |
| | | public String getMountName() { |
| | | return mountName; |
| | | } |
| | | |
| | | @Override |
| | | public org.springframework.ai.tool.definition.ToolDefinition getToolDefinition() { |
| | | return delegate.getToolDefinition(); |
| | | } |
| | | |
| | | @Override |
| | | public org.springframework.ai.tool.metadata.ToolMetadata getToolMetadata() { |
| | | return delegate.getToolMetadata(); |
| | | } |
| | | |
| | | @Override |
| | | public String call(String toolInput) { |
| | | return delegate.call(toolInput); |
| | | } |
| | | |
| | | @Override |
| | | public String call(String toolInput, org.springframework.ai.chat.model.ToolContext toolContext) { |
| | | return delegate.call(toolInput, toolContext); |
| | | } |
| | | } |
| | | } |