package com.vincent.rsf.server.ai.service.impl;
|
|
import com.fasterxml.jackson.core.type.TypeReference;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.vincent.rsf.framework.exception.CoolException;
|
import com.vincent.rsf.server.ai.config.AiDefaults;
|
import com.vincent.rsf.server.ai.entity.AiMcpMount;
|
import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry;
|
import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
|
import io.modelcontextprotocol.client.McpClient;
|
import io.modelcontextprotocol.client.McpSyncClient;
|
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
|
import io.modelcontextprotocol.client.transport.ServerParameters;
|
import io.modelcontextprotocol.client.transport.StdioClientTransport;
|
import io.modelcontextprotocol.json.jackson.JacksonMcpJsonMapper;
|
import io.modelcontextprotocol.spec.McpSchema;
|
import lombok.RequiredArgsConstructor;
|
import lombok.extern.slf4j.Slf4j;
|
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
|
import org.springframework.ai.tool.ToolCallback;
|
import org.springframework.stereotype.Service;
|
import org.springframework.util.StringUtils;
|
|
import java.time.Duration;
|
import java.util.ArrayList;
|
import java.util.Arrays;
|
import java.util.Collections;
|
import java.util.LinkedHashSet;
|
import java.util.LinkedHashMap;
|
import java.util.List;
|
import java.util.Map;
|
|
@Slf4j
|
@Service
|
@RequiredArgsConstructor
|
public class McpMountRuntimeFactoryImpl implements McpMountRuntimeFactory {
|
|
private final ObjectMapper objectMapper;
|
private final BuiltinMcpToolRegistry builtinMcpToolRegistry;
|
|
@Override
|
public McpMountRuntime create(List<AiMcpMount> mounts, Long userId) {
|
List<McpSyncClient> clients = new ArrayList<>();
|
List<ToolCallback> callbacks = new ArrayList<>();
|
List<String> mountedNames = new ArrayList<>();
|
List<String> errors = new ArrayList<>();
|
for (AiMcpMount mount : mounts) {
|
try {
|
if (AiDefaults.MCP_TRANSPORT_BUILTIN.equals(mount.getTransportType())) {
|
callbacks.addAll(builtinMcpToolRegistry.createToolCallbacks(mount, userId));
|
mountedNames.add(mount.getName());
|
continue;
|
}
|
McpSyncClient client = createClient(mount);
|
client.initialize();
|
client.listTools();
|
clients.add(client);
|
mountedNames.add(mount.getName());
|
} catch (Exception e) {
|
String message = mount.getName() + " 挂载失败: " + e.getMessage();
|
errors.add(message);
|
log.warn(message, e);
|
}
|
}
|
if (!clients.isEmpty()) {
|
callbacks.addAll(Arrays.asList(
|
SyncMcpToolCallbackProvider.builder().mcpClients(clients).build().getToolCallbacks()
|
));
|
}
|
ensureUniqueToolNames(callbacks);
|
return new DefaultMcpMountRuntime(clients, callbacks.toArray(new ToolCallback[0]), mountedNames, errors);
|
}
|
|
private void ensureUniqueToolNames(List<ToolCallback> callbacks) {
|
LinkedHashSet<String> duplicateNames = new LinkedHashSet<>();
|
LinkedHashSet<String> seenNames = new LinkedHashSet<>();
|
for (ToolCallback callback : callbacks) {
|
if (callback == null || callback.getToolDefinition() == null) {
|
continue;
|
}
|
String name = callback.getToolDefinition().name();
|
if (!StringUtils.hasText(name)) {
|
continue;
|
}
|
if (!seenNames.add(name)) {
|
duplicateNames.add(name);
|
}
|
}
|
if (!duplicateNames.isEmpty()) {
|
throw new CoolException("MCP 工具名称重复,请调整挂载配置: " + String.join(", ", duplicateNames));
|
}
|
}
|
|
private McpSyncClient createClient(AiMcpMount mount) {
|
Duration timeout = Duration.ofMillis(mount.getRequestTimeoutMs() == null
|
? AiDefaults.DEFAULT_TIMEOUT_MS
|
: mount.getRequestTimeoutMs());
|
JacksonMcpJsonMapper jsonMapper = new JacksonMcpJsonMapper(objectMapper);
|
if (AiDefaults.MCP_TRANSPORT_STDIO.equals(mount.getTransportType())) {
|
ServerParameters.Builder parametersBuilder = ServerParameters.builder(mount.getCommand());
|
List<String> args = readStringList(mount.getArgsJson());
|
if (!args.isEmpty()) {
|
parametersBuilder.args(args);
|
}
|
Map<String, String> env = readStringMap(mount.getEnvJson());
|
if (!env.isEmpty()) {
|
parametersBuilder.env(env);
|
}
|
StdioClientTransport transport = new StdioClientTransport(parametersBuilder.build(), jsonMapper);
|
transport.setStdErrorHandler(message -> log.warn("MCP STDIO stderr [{}]: {}", mount.getName(), message));
|
return McpClient.sync(transport)
|
.requestTimeout(timeout)
|
.initializationTimeout(timeout)
|
.clientInfo(new McpSchema.Implementation("rsf-ai-client", "RSF AI Client", "1.0.0"))
|
.build();
|
}
|
if (!AiDefaults.MCP_TRANSPORT_SSE_HTTP.equals(mount.getTransportType())) {
|
throw new CoolException("不支持的 MCP 传输类型: " + mount.getTransportType());
|
}
|
|
if (!StringUtils.hasText(mount.getServerUrl())) {
|
throw new CoolException("MCP 服务地址不能为空");
|
}
|
HttpClientSseClientTransport.Builder transportBuilder = HttpClientSseClientTransport.builder(mount.getServerUrl())
|
.jsonMapper(jsonMapper)
|
.connectTimeout(timeout);
|
if (StringUtils.hasText(mount.getEndpoint())) {
|
transportBuilder.sseEndpoint(mount.getEndpoint());
|
}
|
Map<String, String> headers = readStringMap(mount.getHeadersJson());
|
if (!headers.isEmpty()) {
|
transportBuilder.customizeRequest(builder -> headers.forEach(builder::header));
|
}
|
return McpClient.sync(transportBuilder.build())
|
.requestTimeout(timeout)
|
.initializationTimeout(timeout)
|
.clientInfo(new McpSchema.Implementation("rsf-ai-client", "RSF AI Client", "1.0.0"))
|
.build();
|
}
|
|
private List<String> readStringList(String json) {
|
if (!StringUtils.hasText(json)) {
|
return Collections.emptyList();
|
}
|
try {
|
return objectMapper.readValue(json, new TypeReference<List<String>>() {
|
});
|
} catch (Exception e) {
|
throw new CoolException("解析 MCP 列表配置失败: " + e.getMessage());
|
}
|
}
|
|
private Map<String, String> readStringMap(String json) {
|
if (!StringUtils.hasText(json)) {
|
return Collections.emptyMap();
|
}
|
try {
|
Map<String, String> result = objectMapper.readValue(json, new TypeReference<LinkedHashMap<String, String>>() {
|
});
|
return result == null ? Collections.emptyMap() : result;
|
} catch (Exception e) {
|
throw new CoolException("解析 MCP Map 配置失败: " + e.getMessage());
|
}
|
}
|
|
private static class DefaultMcpMountRuntime implements McpMountRuntime {
|
|
private final List<McpSyncClient> clients;
|
private final ToolCallback[] callbacks;
|
private final List<String> mountedNames;
|
private final List<String> errors;
|
|
private DefaultMcpMountRuntime(List<McpSyncClient> clients, ToolCallback[] callbacks, List<String> mountedNames, List<String> errors) {
|
this.clients = clients;
|
this.callbacks = callbacks;
|
this.mountedNames = mountedNames;
|
this.errors = errors;
|
}
|
|
@Override
|
public ToolCallback[] getToolCallbacks() {
|
return callbacks;
|
}
|
|
@Override
|
public List<String> getMountedNames() {
|
return mountedNames;
|
}
|
|
@Override
|
public List<String> getErrors() {
|
return errors;
|
}
|
|
@Override
|
public int getMountedCount() {
|
return mountedNames.size();
|
}
|
|
@Override
|
public void close() {
|
for (McpSyncClient client : clients) {
|
try {
|
client.close();
|
} catch (Exception e) {
|
log.warn("关闭 MCP Client 失败", e);
|
}
|
}
|
}
|
}
|
}
|