package com.vincent.rsf.server.ai.service.impl;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.entity.AiMcpMount;
import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry;
import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
import io.modelcontextprotocol.client.transport.ServerParameters;
import io.modelcontextprotocol.client.transport.StdioClientTransport;
import io.modelcontextprotocol.json.jackson.JacksonMcpJsonMapper;
import io.modelcontextprotocol.spec.McpSchema;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Builds a {@link McpMountRuntime} from a list of MCP mount configurations.
 *
 * <p>Built-in mounts are resolved through {@link BuiltinMcpToolRegistry}; every other mount is
 * connected through an {@link McpSyncClient} over STDIO or SSE/HTTP. A mount that fails to
 * connect is recorded in the runtime's error list and does not abort the remaining mounts.
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class McpMountRuntimeFactoryImpl implements McpMountRuntimeFactory {

    /** Client identity sent to every remote MCP server during initialization. */
    private static final McpSchema.Implementation CLIENT_INFO =
            new McpSchema.Implementation("rsf-ai-client", "RSF AI Client", "1.0.0");

    private final ObjectMapper objectMapper;
    private final BuiltinMcpToolRegistry builtinMcpToolRegistry;

    /**
     * Mounts every configured MCP server and collects the resulting tool callbacks.
     *
     * @param mounts mount configurations to attach; {@code null} is treated as an empty list
     * @param userId user id forwarded to the built-in tool registry
     * @return a runtime holding the connected clients, the tool callbacks, the names of the
     *         mounts that succeeded and the error messages of the mounts that failed
     * @throws CoolException if two mounted tools share the same name
     */
    @Override
    public McpMountRuntime create(List<AiMcpMount> mounts, Long userId) {
        List<AiMcpMount> safeMounts = mounts == null ? Collections.emptyList() : mounts;
        List<McpSyncClient> clients = new ArrayList<>();
        List<ToolCallback> callbacks = new ArrayList<>();
        List<String> mountedNames = new ArrayList<>();
        List<String> errors = new ArrayList<>();

        for (AiMcpMount mount : safeMounts) {
            McpSyncClient client = null;
            try {
                if (AiDefaults.MCP_TRANSPORT_BUILTIN.equals(mount.getTransportType())) {
                    callbacks.addAll(builtinMcpToolRegistry.createToolCallbacks(mount, userId));
                    mountedNames.add(mount.getName());
                    continue;
                }
                client = createClient(mount);
                client.initialize();
                // Probe the tool list so an unusable server is reported as a mount failure here.
                client.listTools();
                clients.add(client);
                mountedNames.add(mount.getName());
            } catch (Exception e) {
                // A client that was built but failed to initialize must still be released,
                // otherwise its transport (e.g. a spawned STDIO process) is leaked.
                closeQuietly(client);
                String message = mount.getName() + " 挂载失败: " + e.getMessage();
                errors.add(message);
                log.warn(message, e);
            }
        }

        try {
            if (!clients.isEmpty()) {
                callbacks.addAll(Arrays.asList(
                        SyncMcpToolCallbackProvider.builder().mcpClients(clients).build().getToolCallbacks()
                ));
            }
            ensureUniqueToolNames(callbacks);
        } catch (RuntimeException e) {
            // No runtime is handed back on this path, so nobody else can close these clients.
            for (McpSyncClient client : clients) {
                closeQuietly(client);
            }
            throw e;
        }
        return new DefaultMcpMountRuntime(clients, callbacks.toArray(new ToolCallback[0]), mountedNames, errors);
    }

    /**
     * Rejects the callback set when two callbacks expose the same non-blank tool name.
     * Callbacks that are {@code null}, have no definition or have a blank name are ignored.
     *
     * @param callbacks tool callbacks gathered from all mounts
     * @throws CoolException listing every duplicated name, in first-seen order
     */
    private void ensureUniqueToolNames(List<ToolCallback> callbacks) {
        LinkedHashSet<String> duplicateNames = new LinkedHashSet<>();
        LinkedHashSet<String> seenNames = new LinkedHashSet<>();
        for (ToolCallback callback : callbacks) {
            if (callback == null || callback.getToolDefinition() == null) {
                continue;
            }
            String name = callback.getToolDefinition().name();
            if (!StringUtils.hasText(name)) {
                continue;
            }
            if (!seenNames.add(name)) {
                duplicateNames.add(name);
            }
        }
        if (!duplicateNames.isEmpty()) {
            throw new CoolException("MCP 工具名称重复,请调整挂载配置: " + String.join(", ", duplicateNames));
        }
    }

    /**
     * Builds (but does not initialize) a sync MCP client for a non-built-in mount.
     *
     * @param mount mount configuration
     * @return an un-initialized client for the mount's transport
     * @throws CoolException if the transport type is unsupported, the server URL is blank for
     *                       an SSE/HTTP mount, or a JSON configuration field cannot be parsed
     */
    private McpSyncClient createClient(AiMcpMount mount) {
        // The same timeout is used for connect, request and initialization.
        Duration timeout = Duration.ofMillis(mount.getRequestTimeoutMs() == null
                ? AiDefaults.DEFAULT_TIMEOUT_MS
                : mount.getRequestTimeoutMs());
        JacksonMcpJsonMapper jsonMapper = new JacksonMcpJsonMapper(objectMapper);

        if (AiDefaults.MCP_TRANSPORT_STDIO.equals(mount.getTransportType())) {
            return createStdioClient(mount, timeout, jsonMapper);
        }
        if (!AiDefaults.MCP_TRANSPORT_SSE_HTTP.equals(mount.getTransportType())) {
            throw new CoolException("不支持的 MCP 传输类型: " + mount.getTransportType());
        }
        return createSseClient(mount, timeout, jsonMapper);
    }

    /** Builds a client that talks to a child process over STDIO using the mount's command, args and env. */
    private McpSyncClient createStdioClient(AiMcpMount mount, Duration timeout, JacksonMcpJsonMapper jsonMapper) {
        ServerParameters.Builder parametersBuilder = ServerParameters.builder(mount.getCommand());
        List<String> args = readStringList(mount.getArgsJson());
        if (!args.isEmpty()) {
            parametersBuilder.args(args);
        }
        Map<String, String> env = readStringMap(mount.getEnvJson());
        if (!env.isEmpty()) {
            parametersBuilder.env(env);
        }
        StdioClientTransport transport = new StdioClientTransport(parametersBuilder.build(), jsonMapper);
        transport.setStdErrorHandler(message -> log.warn("MCP STDIO stderr [{}]: {}", mount.getName(), message));
        return McpClient.sync(transport)
                .requestTimeout(timeout)
                .initializationTimeout(timeout)
                .clientInfo(CLIENT_INFO)
                .build();
    }

    /** Builds a client that talks to a remote server over SSE/HTTP using the mount's URL, endpoint and headers. */
    private McpSyncClient createSseClient(AiMcpMount mount, Duration timeout, JacksonMcpJsonMapper jsonMapper) {
        if (!StringUtils.hasText(mount.getServerUrl())) {
            throw new CoolException("MCP 服务地址不能为空");
        }
        HttpClientSseClientTransport.Builder transportBuilder =
                HttpClientSseClientTransport.builder(mount.getServerUrl())
                        .jsonMapper(jsonMapper)
                        .connectTimeout(timeout);
        if (StringUtils.hasText(mount.getEndpoint())) {
            transportBuilder.sseEndpoint(mount.getEndpoint());
        }
        Map<String, String> headers = readStringMap(mount.getHeadersJson());
        if (!headers.isEmpty()) {
            transportBuilder.customizeRequest(builder -> headers.forEach(builder::header));
        }
        return McpClient.sync(transportBuilder.build())
                .requestTimeout(timeout)
                .initializationTimeout(timeout)
                .clientInfo(CLIENT_INFO)
                .build();
    }

    /**
     * Parses a JSON array of strings.
     *
     * @param json JSON text; blank or {@code null} yields an empty list
     * @return the parsed list, or an empty list when the JSON value is the literal {@code null}
     * @throws CoolException if the text is not valid JSON for a string list
     */
    private List<String> readStringList(String json) {
        if (!StringUtils.hasText(json)) {
            return Collections.emptyList();
        }
        try {
            List<String> result = objectMapper.readValue(json, new TypeReference<List<String>>() {
            });
            // Jackson returns null for the JSON literal "null"; mirror readStringMap's handling.
            return result == null ? Collections.emptyList() : result;
        } catch (Exception e) {
            throw new CoolException("解析 MCP 列表配置失败: " + e.getMessage());
        }
    }

    /**
     * Parses a JSON object whose values are strings.
     *
     * @param json JSON text; blank or {@code null} yields an empty map
     * @return the parsed map, or an empty map when the JSON value is the literal {@code null}
     * @throws CoolException if the text is not valid JSON for a string map
     */
    private Map<String, String> readStringMap(String json) {
        if (!StringUtils.hasText(json)) {
            return Collections.emptyMap();
        }
        try {
            Map<String, String> result = objectMapper.readValue(json, new TypeReference<Map<String, String>>() {
            });
            return result == null ? Collections.emptyMap() : result;
        } catch (Exception e) {
            throw new CoolException("解析 MCP Map 配置失败: " + e.getMessage());
        }
    }

    /**
     * Closes a client, logging instead of propagating any failure.
     *
     * @param client client to close; {@code null} is ignored
     */
    private static void closeQuietly(McpSyncClient client) {
        if (client == null) {
            return;
        }
        try {
            client.close();
        } catch (Exception e) {
            log.warn("关闭 MCP Client 失败", e);
        }
    }

    /**
     * Plain holder for the outcome of {@link #create(List, Long)}; closing it closes every
     * connected client.
     */
    private static class DefaultMcpMountRuntime implements McpMountRuntime {

        private final List<McpSyncClient> clients;
        private final ToolCallback[] callbacks;
        private final List<String> mountedNames;
        private final List<String> errors;

        private DefaultMcpMountRuntime(List<McpSyncClient> clients, ToolCallback[] callbacks,
                                       List<String> mountedNames, List<String> errors) {
            this.clients = clients;
            this.callbacks = callbacks;
            this.mountedNames = mountedNames;
            this.errors = errors;
        }

        @Override
        public ToolCallback[] getToolCallbacks() {
            return callbacks;
        }

        @Override
        public List<String> getMountedNames() {
            return mountedNames;
        }

        @Override
        public List<String> getErrors() {
            return errors;
        }

        @Override
        public int getMountedCount() {
            return mountedNames.size();
        }

        /** Closes every client; a failure on one client does not prevent closing the others. */
        @Override
        public void close() {
            for (McpSyncClient client : clients) {
                closeQuietly(client);
            }
        }
    }
}