package com.vincent.rsf.server.ai.service.impl; import com.vincent.rsf.framework.exception.CoolException; import com.vincent.rsf.server.ai.config.AiDefaults; import com.vincent.rsf.server.ai.entity.AiMcpMount; import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry; import com.vincent.rsf.server.ai.service.MountedToolCallback; import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory; import io.modelcontextprotocol.client.McpSyncClient; import com.vincent.rsf.server.ai.service.impl.mcp.McpClientFactory; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.mcp.SyncMcpToolCallbackProvider; import org.springframework.ai.tool.ToolCallback; import org.springframework.stereotype.Service; import org.springframework.util.StringUtils; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.LinkedHashSet; import java.util.List; @Slf4j @Service @RequiredArgsConstructor public class McpMountRuntimeFactoryImpl implements McpMountRuntimeFactory { private final BuiltinMcpToolRegistry builtinMcpToolRegistry; private final McpClientFactory mcpClientFactory; /** * 把一组 MCP 挂载记录解析成一次对话可直接使用的运行时对象。 * 该方法统一处理内置 MCP、远程 SSE MCP 和本地 STDIO MCP, * 同时收集挂载成功项、失败项以及最终暴露给模型的工具回调列表。 */ @Override public McpMountRuntime create(List mounts, Long userId) { List clients = new ArrayList<>(); List callbacks = new ArrayList<>(); List mountedNames = new ArrayList<>(); List errors = new ArrayList<>(); for (AiMcpMount mount : mounts) { try { if (AiDefaults.MCP_TRANSPORT_BUILTIN.equals(mount.getTransportType())) { callbacks.addAll(wrapMountedCallbacks( builtinMcpToolRegistry.createToolCallbacks(mount, userId), mount.getName() )); mountedNames.add(mount.getName()); continue; } McpSyncClient client = mcpClientFactory.createClient(mount); client.initialize(); client.listTools(); clients.add(client); callbacks.addAll(wrapMountedCallbacks( Arrays.asList(SyncMcpToolCallbackProvider.builder().mcpClients(List.of(client)).build().getToolCallbacks()), mount.getName() )); mountedNames.add(mount.getName()); } catch (Exception e) { String message = mount.getName() + " 挂载失败: " + e.getMessage(); errors.add(message); log.warn(message, e); } } ensureUniqueToolNames(callbacks); return new DefaultMcpMountRuntime(clients, callbacks.toArray(new ToolCallback[0]), mountedNames, errors); } private List wrapMountedCallbacks(List source, String mountName) { /** 为每个工具回调补上挂载来源,便于后续审计、观测和前端工具轨迹展示。 */ List mountedCallbacks = new ArrayList<>(); for (ToolCallback callback : source) { if (callback == null) { continue; } mountedCallbacks.add(new MountedToolCallbackImpl(callback, mountName)); } return mountedCallbacks; } private void ensureUniqueToolNames(List callbacks) { /** 确保多挂载聚合后不会出现同名工具,否则模型侧无法正确分辨工具定义。 */ LinkedHashSet duplicateNames = new LinkedHashSet<>(); LinkedHashSet seenNames = new LinkedHashSet<>(); for (ToolCallback callback : callbacks) { if (callback == null || callback.getToolDefinition() == null) { continue; } String name = callback.getToolDefinition().name(); if (!StringUtils.hasText(name)) { continue; } if (!seenNames.add(name)) { duplicateNames.add(name); } } if (!duplicateNames.isEmpty()) { throw new CoolException("MCP 工具名称重复,请调整挂载配置: " + String.join(", ", duplicateNames)); } } private static class DefaultMcpMountRuntime implements McpMountRuntime { private final List clients; private final ToolCallback[] callbacks; private final List mountedNames; private final List errors; /** 运行时对象本身只做数据封装和资源释放,不引入额外业务逻辑。 */ private DefaultMcpMountRuntime(List clients, ToolCallback[] callbacks, List mountedNames, List errors) { this.clients = clients; this.callbacks = callbacks; this.mountedNames = mountedNames; this.errors = errors; } @Override public ToolCallback[] getToolCallbacks() { return callbacks; } @Override public List getMountedNames() { return mountedNames; } @Override public List getErrors() { return errors; } @Override public int getMountedCount() { return mountedNames.size(); } @Override public void close() { /** 统一关闭本次运行时里创建的外部 MCP Client,避免连接泄漏。 */ for (McpSyncClient client : clients) { try { client.close(); } catch (Exception e) { log.warn("关闭 MCP Client 失败", e); } } } } private static class MountedToolCallbackImpl implements MountedToolCallback { private final ToolCallback delegate; private final String mountName; /** 装饰器仅补充挂载来源,不改变底层工具定义和调用行为。 */ private MountedToolCallbackImpl(ToolCallback delegate, String mountName) { this.delegate = delegate; this.mountName = mountName; } @Override public String getMountName() { return mountName; } @Override public org.springframework.ai.tool.definition.ToolDefinition getToolDefinition() { return delegate.getToolDefinition(); } @Override public org.springframework.ai.tool.metadata.ToolMetadata getToolMetadata() { return delegate.getToolMetadata(); } @Override public String call(String toolInput) { return delegate.call(toolInput); } @Override public String call(String toolInput, org.springframework.ai.chat.model.ToolContext toolContext) { return delegate.call(toolInput, toolContext); } } }