package com.zy.ai.mcp.service;
|
|
import com.alibaba.fastjson.JSON;
|
import com.alibaba.fastjson.JSONObject;
|
import io.modelcontextprotocol.client.McpClient;
|
import io.modelcontextprotocol.client.McpSyncClient;
|
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
|
import io.modelcontextprotocol.spec.McpSchema;
|
import lombok.extern.slf4j.Slf4j;
|
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
|
import org.springframework.ai.tool.ToolCallback;
|
import org.springframework.ai.tool.ToolCallbackProvider;
|
import org.springframework.ai.tool.definition.ToolDefinition;
|
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.stereotype.Service;
|
|
import jakarta.annotation.PreDestroy;
|
import java.time.Duration;
|
import java.util.ArrayList;
|
import java.util.Collections;
|
import java.util.Comparator;
|
import java.util.LinkedHashMap;
|
import java.util.List;
|
import java.util.Map;
|
|
@Slf4j
|
@Service
|
public class SpringAiMcpToolManager {
|
|
private static final McpSchema.Implementation CLIENT_INFO =
|
new McpSchema.Implementation("wcs-ai-assistant", "1.0.0");
|
|
private final Object clientMonitor = new Object();
|
|
@Value("${server.port:9090}")
|
private Integer serverPort;
|
|
@Value("${server.servlet.context-path:}")
|
private String contextPath;
|
|
@Value("${spring.ai.mcp.server.sse-endpoint:/ai/mcp/sse}")
|
private String sseEndpoint;
|
|
@Value("${spring.ai.mcp.server.request-timeout:20s}")
|
private Duration requestTimeout;
|
|
@Value("${app.ai.mcp.client.base-url:}")
|
private String configuredBaseUrl;
|
|
private volatile ClientSession clientSession;
|
|
public List<Map<String, Object>> listTools() {
|
List<Map<String, Object>> tools = new ArrayList<Map<String, Object>>();
|
for (ToolCallback callback : getToolCallbacks()) {
|
if (callback == null || callback.getToolDefinition() == null) {
|
continue;
|
}
|
ToolDefinition definition = callback.getToolDefinition();
|
Map<String, Object> item = new LinkedHashMap<String, Object>();
|
item.put("name", definition.name());
|
item.put("description", definition.description());
|
item.put("inputSchema", parseSchema(definition.inputSchema()));
|
tools.add(item);
|
}
|
tools.sort(new Comparator<Map<String, Object>>() {
|
@Override
|
public int compare(Map<String, Object> left, Map<String, Object> right) {
|
return String.valueOf(left.get("name")).compareTo(String.valueOf(right.get("name")));
|
}
|
});
|
return tools;
|
}
|
|
public List<Object> buildOpenAiTools() {
|
List<Object> tools = new ArrayList<Object>();
|
for (Map<String, Object> item : listTools()) {
|
Object name = item.get("name");
|
if (name == null) {
|
continue;
|
}
|
|
Map<String, Object> function = new LinkedHashMap<String, Object>();
|
function.put("name", String.valueOf(name));
|
Object description = item.get("description");
|
if (description != null) {
|
function.put("description", String.valueOf(description));
|
}
|
Object inputSchema = item.get("inputSchema");
|
function.put("parameters", inputSchema == null ? new LinkedHashMap<String, Object>() : inputSchema);
|
|
Map<String, Object> tool = new LinkedHashMap<String, Object>();
|
tool.put("type", "function");
|
tool.put("function", function);
|
tools.add(tool);
|
}
|
return tools;
|
}
|
|
public Object callTool(String toolName, JSONObject arguments) {
|
if (toolName == null || toolName.trim().isEmpty()) {
|
throw new IllegalArgumentException("missing tool name");
|
}
|
|
ToolCallback callback = findCallback(toolName);
|
if (callback == null) {
|
throw new IllegalArgumentException("tool not found: " + toolName);
|
}
|
|
String rawResult = callback.call(arguments == null ? "{}" : arguments.toJSONString());
|
return parseToolResult(rawResult);
|
}
|
|
private ToolCallback findCallback(String toolName) {
|
for (ToolCallback callback : getToolCallbacks()) {
|
if (callback == null || callback.getToolDefinition() == null) {
|
continue;
|
}
|
if (toolName.equals(callback.getToolDefinition().name())) {
|
return callback;
|
}
|
}
|
return null;
|
}
|
|
private ToolCallback[] getToolCallbacks() {
|
try {
|
ToolCallback[] callbacks = ensureToolCallbackProvider().getToolCallbacks();
|
return callbacks == null ? new ToolCallback[0] : callbacks;
|
} catch (Exception e) {
|
log.warn("Failed to load MCP tools through SSE client, baseUrl={}, sseEndpoint={}",
|
resolveBaseUrl(), resolveClientSseEndpoint(), e);
|
resetClientSession();
|
return new ToolCallback[0];
|
}
|
}
|
|
private Object parseToolResult(String rawResult) {
|
if (rawResult == null || rawResult.trim().isEmpty()) {
|
return rawResult;
|
}
|
try {
|
return JSON.parse(rawResult);
|
} catch (Exception ignore) {
|
return rawResult;
|
}
|
}
|
|
@SuppressWarnings("unchecked")
|
private Map<String, Object> parseSchema(String inputSchema) {
|
if (inputSchema == null || inputSchema.trim().isEmpty()) {
|
return Collections.emptyMap();
|
}
|
try {
|
Object parsed = JSON.parse(inputSchema);
|
if (parsed instanceof Map) {
|
return new LinkedHashMap<String, Object>((Map<String, Object>) parsed);
|
}
|
} catch (Exception e) {
|
log.warn("Failed to parse MCP tool schema: {}", inputSchema, e);
|
}
|
Map<String, Object> fallback = new LinkedHashMap<String, Object>();
|
fallback.put("type", "object");
|
return fallback;
|
}
|
|
private ToolCallbackProvider ensureToolCallbackProvider() {
|
return ensureClientSession().toolCallbackProvider;
|
}
|
|
private ClientSession ensureClientSession() {
|
ClientSession current = clientSession;
|
if (current != null) {
|
return current;
|
}
|
|
synchronized (clientMonitor) {
|
current = clientSession;
|
if (current != null) {
|
return current;
|
}
|
|
String baseUrl = resolveBaseUrl();
|
String clientSseEndpoint = resolveClientSseEndpoint();
|
HttpClientSseClientTransport transport = HttpClientSseClientTransport.builder(baseUrl)
|
.sseEndpoint(clientSseEndpoint)
|
.connectTimeout(requestTimeout)
|
.build();
|
|
McpSyncClient syncClient = McpClient.sync(transport)
|
.clientInfo(CLIENT_INFO)
|
.requestTimeout(requestTimeout)
|
.initializationTimeout(requestTimeout)
|
.build();
|
syncClient.initialize();
|
|
SyncMcpToolCallbackProvider callbackProvider = new SyncMcpToolCallbackProvider(syncClient);
|
current = new ClientSession(syncClient, callbackProvider, baseUrl);
|
clientSession = current;
|
log.info("Spring AI MCP SSE client initialized, baseUrl={}, sseEndpoint={}, tools={}",
|
baseUrl, clientSseEndpoint, current.toolCallbackProvider.getToolCallbacks().length);
|
return current;
|
}
|
}
|
|
private void resetClientSession() {
|
synchronized (clientMonitor) {
|
ClientSession current = clientSession;
|
clientSession = null;
|
if (current != null) {
|
current.close();
|
}
|
}
|
}
|
|
private String resolveBaseUrl() {
|
if (configuredBaseUrl != null && !configuredBaseUrl.trim().isEmpty()) {
|
return trimTrailingSlash(configuredBaseUrl.trim());
|
}
|
StringBuilder url = new StringBuilder("http://127.0.0.1:");
|
url.append(serverPort == null ? 9090 : serverPort);
|
return trimTrailingSlash(url.toString());
|
}
|
|
private String resolveClientSseEndpoint() {
|
String endpoint = normalizePath(sseEndpoint);
|
if (configuredBaseUrl != null && !configuredBaseUrl.trim().isEmpty()) {
|
return endpoint;
|
}
|
String context = normalizeContextPath(contextPath);
|
if (context.isEmpty()) {
|
return endpoint;
|
}
|
return context + endpoint;
|
}
|
|
private String normalizeContextPath(String path) {
|
if (path == null || path.trim().isEmpty() || "/".equals(path.trim())) {
|
return "";
|
}
|
String value = path.trim();
|
if (!value.startsWith("/")) {
|
value = "/" + value;
|
}
|
return trimTrailingSlash(value);
|
}
|
|
private String normalizePath(String path) {
|
if (path == null || path.trim().isEmpty()) {
|
return "/";
|
}
|
String value = path.trim();
|
if (!value.startsWith("/")) {
|
value = "/" + value;
|
}
|
return value;
|
}
|
|
private String trimTrailingSlash(String value) {
|
if (value == null || value.isEmpty()) {
|
return "";
|
}
|
return value.endsWith("/") && value.length() > 1 ? value.substring(0, value.length() - 1) : value;
|
}
|
|
@PreDestroy
|
public void destroy() {
|
resetClientSession();
|
}
|
|
private static final class ClientSession implements AutoCloseable {
|
|
private final McpSyncClient syncClient;
|
private final ToolCallbackProvider toolCallbackProvider;
|
private final String baseUrl;
|
|
private ClientSession(McpSyncClient syncClient, ToolCallbackProvider toolCallbackProvider, String baseUrl) {
|
this.syncClient = syncClient;
|
this.toolCallbackProvider = toolCallbackProvider;
|
this.baseUrl = baseUrl;
|
}
|
|
@Override
|
public void close() {
|
try {
|
syncClient.closeGracefully();
|
} catch (Exception e) {
|
log.debug("Close MCP SSE client gracefully failed, baseUrl={}", baseUrl, e);
|
}
|
try {
|
syncClient.close();
|
} catch (Exception e) {
|
log.debug("Close MCP SSE client failed, baseUrl={}", baseUrl, e);
|
}
|
}
|
}
|
}
|