package com.vincent.rsf.server.ai.service.impl; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.vincent.rsf.framework.exception.CoolException; import com.vincent.rsf.server.ai.config.AiDefaults; import com.vincent.rsf.server.ai.entity.AiMcpMount; import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry; import com.vincent.rsf.server.ai.service.MountedToolCallback; import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory; import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; import io.modelcontextprotocol.client.transport.ServerParameters; import io.modelcontextprotocol.client.transport.StdioClientTransport; import io.modelcontextprotocol.json.jackson.JacksonMcpJsonMapper; import io.modelcontextprotocol.spec.McpSchema; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.mcp.SyncMcpToolCallbackProvider; import org.springframework.ai.tool.ToolCallback; import org.springframework.stereotype.Service; import org.springframework.util.StringUtils; import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.LinkedHashSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @Slf4j @Service @RequiredArgsConstructor public class McpMountRuntimeFactoryImpl implements McpMountRuntimeFactory { private final ObjectMapper objectMapper; private final BuiltinMcpToolRegistry builtinMcpToolRegistry; /** * 把一组 MCP 挂载记录解析成一次对话可直接使用的运行时对象。 * 该方法统一处理内置 MCP、远程 SSE MCP 和本地 STDIO MCP, * 同时收集挂载成功项、失败项以及最终暴露给模型的工具回调列表。 */ @Override public McpMountRuntime create(List mounts, Long userId) { List clients = new ArrayList<>(); List callbacks = new ArrayList<>(); List mountedNames = new ArrayList<>(); List errors = new ArrayList<>(); for (AiMcpMount mount : mounts) { try { if (AiDefaults.MCP_TRANSPORT_BUILTIN.equals(mount.getTransportType())) { callbacks.addAll(wrapMountedCallbacks( builtinMcpToolRegistry.createToolCallbacks(mount, userId), mount.getName() )); mountedNames.add(mount.getName()); continue; } McpSyncClient client = createClient(mount); client.initialize(); client.listTools(); clients.add(client); callbacks.addAll(wrapMountedCallbacks( Arrays.asList(SyncMcpToolCallbackProvider.builder().mcpClients(List.of(client)).build().getToolCallbacks()), mount.getName() )); mountedNames.add(mount.getName()); } catch (Exception e) { String message = mount.getName() + " 挂载失败: " + e.getMessage(); errors.add(message); log.warn(message, e); } } ensureUniqueToolNames(callbacks); return new DefaultMcpMountRuntime(clients, callbacks.toArray(new ToolCallback[0]), mountedNames, errors); } private List wrapMountedCallbacks(List source, String mountName) { /** 为每个工具回调补上挂载来源,便于后续审计、观测和前端工具轨迹展示。 */ List mountedCallbacks = new ArrayList<>(); for (ToolCallback callback : source) { if (callback == null) { continue; } mountedCallbacks.add(new MountedToolCallbackImpl(callback, mountName)); } return mountedCallbacks; } private void ensureUniqueToolNames(List callbacks) { /** 确保多挂载聚合后不会出现同名工具,否则模型侧无法正确分辨工具定义。 */ LinkedHashSet duplicateNames = new LinkedHashSet<>(); LinkedHashSet seenNames = new LinkedHashSet<>(); for (ToolCallback callback : callbacks) { if (callback == null || callback.getToolDefinition() == null) { continue; } String name = callback.getToolDefinition().name(); if (!StringUtils.hasText(name)) { continue; } if (!seenNames.add(name)) { duplicateNames.add(name); } } if (!duplicateNames.isEmpty()) { throw new CoolException("MCP 工具名称重复,请调整挂载配置: " + String.join(", ", duplicateNames)); } } private McpSyncClient createClient(AiMcpMount mount) { /** * 按挂载配置动态创建 MCP Client。 * 该方法只负责 transport 层初始化,不负责工具去重和错误聚合。 */ Duration timeout = Duration.ofMillis(mount.getRequestTimeoutMs() == null ? AiDefaults.DEFAULT_TIMEOUT_MS : mount.getRequestTimeoutMs()); JacksonMcpJsonMapper jsonMapper = new JacksonMcpJsonMapper(objectMapper); if (AiDefaults.MCP_TRANSPORT_STDIO.equals(mount.getTransportType())) { ServerParameters.Builder parametersBuilder = ServerParameters.builder(mount.getCommand()); List args = readStringList(mount.getArgsJson()); if (!args.isEmpty()) { parametersBuilder.args(args); } Map env = readStringMap(mount.getEnvJson()); if (!env.isEmpty()) { parametersBuilder.env(env); } StdioClientTransport transport = new StdioClientTransport(parametersBuilder.build(), jsonMapper); transport.setStdErrorHandler(message -> log.warn("MCP STDIO stderr [{}]: {}", mount.getName(), message)); return McpClient.sync(transport) .requestTimeout(timeout) .initializationTimeout(timeout) .clientInfo(new McpSchema.Implementation("rsf-ai-client", "RSF AI Client", "1.0.0")) .build(); } if (!AiDefaults.MCP_TRANSPORT_SSE_HTTP.equals(mount.getTransportType())) { throw new CoolException("不支持的 MCP 传输类型: " + mount.getTransportType()); } if (!StringUtils.hasText(mount.getServerUrl())) { throw new CoolException("MCP 服务地址不能为空"); } HttpClientSseClientTransport.Builder transportBuilder = HttpClientSseClientTransport.builder(mount.getServerUrl()) .jsonMapper(jsonMapper) .connectTimeout(timeout); if (StringUtils.hasText(mount.getEndpoint())) { transportBuilder.sseEndpoint(mount.getEndpoint()); } Map headers = readStringMap(mount.getHeadersJson()); if (!headers.isEmpty()) { transportBuilder.customizeRequest(builder -> headers.forEach(builder::header)); } return McpClient.sync(transportBuilder.build()) .requestTimeout(timeout) .initializationTimeout(timeout) .clientInfo(new McpSchema.Implementation("rsf-ai-client", "RSF AI Client", "1.0.0")) .build(); } private List readStringList(String json) { /** 解析挂载表里的 JSON 数组配置,例如 STDIO args。 */ if (!StringUtils.hasText(json)) { return Collections.emptyList(); } try { return objectMapper.readValue(json, new TypeReference>() { }); } catch (Exception e) { throw new CoolException("解析 MCP 列表配置失败: " + e.getMessage()); } } private Map readStringMap(String json) { /** 解析挂载表里的 JSON Map 配置,例如 headers 或环境变量。 */ if (!StringUtils.hasText(json)) { return Collections.emptyMap(); } try { Map result = objectMapper.readValue(json, new TypeReference>() { }); return result == null ? Collections.emptyMap() : result; } catch (Exception e) { throw new CoolException("解析 MCP Map 配置失败: " + e.getMessage()); } } private static class DefaultMcpMountRuntime implements McpMountRuntime { private final List clients; private final ToolCallback[] callbacks; private final List mountedNames; private final List errors; /** 运行时对象本身只做数据封装和资源释放,不引入额外业务逻辑。 */ private DefaultMcpMountRuntime(List clients, ToolCallback[] callbacks, List mountedNames, List errors) { this.clients = clients; this.callbacks = callbacks; this.mountedNames = mountedNames; this.errors = errors; } @Override public ToolCallback[] getToolCallbacks() { return callbacks; } @Override public List getMountedNames() { return mountedNames; } @Override public List getErrors() { return errors; } @Override public int getMountedCount() { return mountedNames.size(); } @Override public void close() { /** 统一关闭本次运行时里创建的外部 MCP Client,避免连接泄漏。 */ for (McpSyncClient client : clients) { try { client.close(); } catch (Exception e) { log.warn("关闭 MCP Client 失败", e); } } } } private static class MountedToolCallbackImpl implements MountedToolCallback { private final ToolCallback delegate; private final String mountName; /** 装饰器仅补充挂载来源,不改变底层工具定义和调用行为。 */ private MountedToolCallbackImpl(ToolCallback delegate, String mountName) { this.delegate = delegate; this.mountName = mountName; } @Override public String getMountName() { return mountName; } @Override public org.springframework.ai.tool.definition.ToolDefinition getToolDefinition() { return delegate.getToolDefinition(); } @Override public org.springframework.ai.tool.metadata.ToolMetadata getToolMetadata() { return delegate.getToolMetadata(); } @Override public String call(String toolInput) { return delegate.call(toolInput); } @Override public String call(String toolInput, org.springframework.ai.chat.model.ToolContext toolContext) { return delegate.call(toolInput, toolContext); } } }