package com.zy.ai.mcp.service;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
import io.modelcontextprotocol.spec.McpSchema;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

import jakarta.annotation.PreDestroy;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Connects to an MCP server over an SSE client transport and exposes the server's tools as
 * plain maps (for listing), OpenAI-style function-tool descriptors, and direct tool calls.
 *
 * <p>By default the client targets this application itself ({@code http://127.0.0.1:<server.port>}
 * plus the servlet context path); {@code app.ai.mcp.client.base-url} overrides the base URL.
 * The client session is created lazily on first use, shared between callers, and discarded and
 * closed when loading the tool list fails so that a later call reconnects.
 */
@Slf4j
@Service
public class SpringAiMcpToolManager {

    private static final McpSchema.Implementation CLIENT_INFO =
            new McpSchema.Implementation("wcs-ai-assistant", "1.0.0");

    /** Zero-length arrays are immutable, so a single instance can be shared. */
    private static final ToolCallback[] NO_TOOLS = new ToolCallback[0];

    /** Guards creation and replacement of {@link #clientSession}. */
    private final Object clientMonitor = new Object();

    @Value("${server.port:9090}")
    private Integer serverPort;

    @Value("${server.servlet.context-path:}")
    private String contextPath;

    @Value("${spring.ai.mcp.server.sse-endpoint:/ai/mcp/sse}")
    private String sseEndpoint;

    @Value("${spring.ai.mcp.server.request-timeout:20s}")
    private Duration requestTimeout;

    @Value("${app.ai.mcp.client.base-url:}")
    private String configuredBaseUrl;

    /**
     * Lazily created shared session. Volatile so the lock-free fast path in
     * {@link #ensureClientSession()} only ever observes a fully initialized session.
     */
    private volatile ClientSession clientSession;

    /**
     * Lists the tools currently published by the MCP server, sorted by name.
     *
     * @return one map per tool with keys {@code name}, {@code description} and
     *     {@code inputSchema} (the parsed JSON schema); empty if the tools could not be loaded
     */
    public List<Map<String, Object>> listTools() {
        List<Map<String, Object>> tools = new ArrayList<>();
        for (ToolCallback callback : getToolCallbacks()) {
            ToolDefinition definition = callback == null ? null : callback.getToolDefinition();
            if (definition == null) {
                continue;
            }
            Map<String, Object> item = new LinkedHashMap<>();
            item.put("name", definition.name());
            item.put("description", definition.description());
            item.put("inputSchema", parseSchema(definition.inputSchema()));
            tools.add(item);
        }
        tools.sort(Comparator.comparing(item -> String.valueOf(item.get("name"))));
        return tools;
    }

    /**
     * Converts {@link #listTools()} into OpenAI chat-completions "function" tool descriptors of
     * the form {@code {"type":"function","function":{"name","description","parameters"}}}.
     *
     * @return one descriptor per named tool; tools without a name are skipped
     */
    public List<Object> buildOpenAiTools() {
        List<Object> tools = new ArrayList<>();
        for (Map<String, Object> item : listTools()) {
            Object name = item.get("name");
            if (name == null) {
                continue;
            }
            Map<String, Object> function = new LinkedHashMap<>();
            function.put("name", String.valueOf(name));
            Object description = item.get("description");
            if (description != null) {
                function.put("description", String.valueOf(description));
            }
            Object inputSchema = item.get("inputSchema");
            function.put("parameters", inputSchema == null ? new LinkedHashMap<String, Object>() : inputSchema);

            Map<String, Object> tool = new LinkedHashMap<>();
            tool.put("type", "function");
            tool.put("function", function);
            tools.add(tool);
        }
        return tools;
    }

    /**
     * Invokes an MCP tool by exact name.
     *
     * @param toolName  the tool's name as reported by {@link #listTools()}
     * @param arguments tool arguments; {@code null} is sent as an empty JSON object
     * @return the tool result parsed as JSON when possible, otherwise the raw string
     * @throws IllegalArgumentException if the name is blank or no tool with that name exists
     */
    public Object callTool(String toolName, JSONObject arguments) {
        if (toolName == null || toolName.trim().isEmpty()) {
            throw new IllegalArgumentException("missing tool name");
        }
        ToolCallback callback = findCallback(toolName);
        if (callback == null) {
            throw new IllegalArgumentException("tool not found: " + toolName);
        }
        String rawResult = callback.call(arguments == null ? "{}" : arguments.toJSONString());
        return parseToolResult(rawResult);
    }

    /** Returns the callback whose tool definition name equals {@code toolName}, or {@code null}. */
    private ToolCallback findCallback(String toolName) {
        for (ToolCallback callback : getToolCallbacks()) {
            if (callback == null || callback.getToolDefinition() == null) {
                continue;
            }
            if (toolName.equals(callback.getToolDefinition().name())) {
                return callback;
            }
        }
        return null;
    }

    /**
     * Loads the tool callbacks from the shared session.
     *
     * <p>Never throws: on failure it logs, invalidates the session that failed (so the next call
     * reconnects) and returns an empty array.
     */
    private ToolCallback[] getToolCallbacks() {
        ClientSession session = null;
        try {
            session = ensureClientSession();
            ToolCallback[] callbacks = session.toolCallbackProvider.getToolCallbacks();
            return callbacks == null ? NO_TOOLS : callbacks;
        } catch (Exception e) {
            log.warn("Failed to load MCP tools through SSE client, baseUrl={}, sseEndpoint={}",
                    resolveBaseUrl(), resolveClientSseEndpoint(), e);
            // Only drop the session that actually failed; another thread may already have
            // replaced it with a fresh, healthy one that must not be closed here.
            if (session != null) {
                invalidateClientSession(session);
            }
            return NO_TOOLS;
        }
    }

    /** Parses a tool result as JSON, falling back to the raw string when it is not valid JSON. */
    private Object parseToolResult(String rawResult) {
        if (rawResult == null || rawResult.trim().isEmpty()) {
            return rawResult;
        }
        try {
            return JSON.parse(rawResult);
        } catch (Exception ignored) {
            // Plain-text tool output is legitimate; return it unchanged.
            return rawResult;
        }
    }

    /**
     * Parses a tool's JSON input schema into a mutable map.
     *
     * @return an empty map for a blank schema, the parsed object when it is a JSON object, and
     *     {@code {"type":"object"}} when the schema is unparseable or not a JSON object
     */
    @SuppressWarnings("unchecked")
    private Map<String, Object> parseSchema(String inputSchema) {
        if (inputSchema == null || inputSchema.trim().isEmpty()) {
            return Collections.emptyMap();
        }
        try {
            Object parsed = JSON.parse(inputSchema);
            if (parsed instanceof Map) {
                // A parsed JSON object is keyed by strings, so the unchecked cast is safe.
                return new LinkedHashMap<>((Map<String, Object>) parsed);
            }
        } catch (Exception e) {
            log.warn("Failed to parse MCP tool schema: {}", inputSchema, e);
        }
        Map<String, Object> fallback = new LinkedHashMap<>();
        fallback.put("type", "object");
        return fallback;
    }

    private ToolCallbackProvider ensureToolCallbackProvider() {
        return ensureClientSession().toolCallbackProvider;
    }

    /**
     * Returns the shared session, creating and initializing it on first use.
     *
     * <p>Uses double-checked locking on the volatile {@link #clientSession} field. A session is
     * published only after initialization and the initial tool listing succeed; if either fails,
     * the half-built client is closed so its SSE connection does not leak, and the exception
     * propagates.
     */
    private ClientSession ensureClientSession() {
        ClientSession current = clientSession;
        if (current != null) {
            return current;
        }
        synchronized (clientMonitor) {
            current = clientSession;
            if (current != null) {
                return current;
            }
            String baseUrl = resolveBaseUrl();
            String clientSseEndpoint = resolveClientSseEndpoint();
            HttpClientSseClientTransport transport = HttpClientSseClientTransport.builder(baseUrl)
                    .sseEndpoint(clientSseEndpoint)
                    .connectTimeout(requestTimeout)
                    .build();
            McpSyncClient syncClient = McpClient.sync(transport)
                    .clientInfo(CLIENT_INFO)
                    .requestTimeout(requestTimeout)
                    .initializationTimeout(requestTimeout)
                    .build();

            ClientSession created;
            int toolCount;
            try {
                syncClient.initialize();
                SyncMcpToolCallbackProvider callbackProvider = new SyncMcpToolCallbackProvider(syncClient);
                created = new ClientSession(syncClient, callbackProvider, baseUrl);
                ToolCallback[] initialTools = callbackProvider.getToolCallbacks();
                toolCount = initialTools == null ? 0 : initialTools.length;
            } catch (RuntimeException e) {
                new ClientSession(syncClient, null, baseUrl).close();
                throw e;
            }

            clientSession = created;
            log.info("Spring AI MCP SSE client initialized, baseUrl={}, sseEndpoint={}, tools={}",
                    baseUrl, clientSseEndpoint, toolCount);
            return created;
        }
    }

    /** Detaches and closes the current session, if any. The close happens outside the monitor. */
    private void resetClientSession() {
        ClientSession stale;
        synchronized (clientMonitor) {
            stale = clientSession;
            clientSession = null;
        }
        if (stale != null) {
            stale.close();
        }
    }

    /**
     * Detaches and closes {@code failed} only if it is still the current session, so a session
     * created concurrently by another thread is left untouched.
     */
    private void invalidateClientSession(ClientSession failed) {
        boolean detached = false;
        synchronized (clientMonitor) {
            if (clientSession == failed) {
                clientSession = null;
                detached = true;
            }
        }
        // Closing may block on a graceful shutdown; do it without holding the monitor.
        if (detached) {
            failed.close();
        }
    }

    /**
     * Returns the configured base URL when set, otherwise loopback on {@code server.port}
     * (9090 if unset), with trailing slashes removed.
     */
    private String resolveBaseUrl() {
        if (configuredBaseUrl != null && !configuredBaseUrl.trim().isEmpty()) {
            return trimTrailingSlash(configuredBaseUrl.trim());
        }
        return trimTrailingSlash("http://127.0.0.1:" + (serverPort == null ? 9090 : serverPort));
    }

    /**
     * Returns the SSE endpoint path the client should request.
     *
     * <p>When an explicit base URL is configured it is used as-is, so only the endpoint is
     * returned. Otherwise the local servlet context path is prepended.
     */
    private String resolveClientSseEndpoint() {
        String endpoint = normalizePath(sseEndpoint);
        if (configuredBaseUrl != null && !configuredBaseUrl.trim().isEmpty()) {
            return endpoint;
        }
        String context = normalizeContextPath(contextPath);
        if (context.isEmpty()) {
            return endpoint;
        }
        return context + endpoint;
    }

    /** Normalizes a context path to {@code /segment} form; blank or {@code /} becomes {@code ""}. */
    private String normalizeContextPath(String path) {
        if (path == null || path.trim().isEmpty() || "/".equals(path.trim())) {
            return "";
        }
        String value = path.trim();
        if (!value.startsWith("/")) {
            value = "/" + value;
        }
        return trimTrailingSlash(value);
    }

    /** Ensures a path starts with {@code /}; blank input becomes {@code /}. */
    private String normalizePath(String path) {
        if (path == null || path.trim().isEmpty()) {
            return "/";
        }
        String value = path.trim();
        if (!value.startsWith("/")) {
            value = "/" + value;
        }
        return value;
    }

    /**
     * Removes all trailing slashes but always keeps the first character, so {@code "/"} stays
     * {@code "/"}. Null or empty input yields {@code ""}.
     */
    private String trimTrailingSlash(String value) {
        if (value == null || value.isEmpty()) {
            return "";
        }
        int end = value.length();
        while (end > 1 && value.charAt(end - 1) == '/') {
            end--;
        }
        return value.substring(0, end);
    }

    /** Closes the shared MCP client when the Spring context shuts down. */
    @PreDestroy
    public void destroy() {
        resetClientSession();
    }

    /** An initialized MCP client together with the Spring AI tool provider built on top of it. */
    private static final class ClientSession implements AutoCloseable {

        private final McpSyncClient syncClient;
        private final ToolCallbackProvider toolCallbackProvider;
        private final String baseUrl;

        private ClientSession(McpSyncClient syncClient, ToolCallbackProvider toolCallbackProvider, String baseUrl) {
            this.syncClient = syncClient;
            this.toolCallbackProvider = toolCallbackProvider;
            this.baseUrl = baseUrl;
        }

        /**
         * Best-effort shutdown: attempts a graceful close, then a hard close. Failures are
         * logged at debug level and never propagated.
         */
        @Override
        public void close() {
            try {
                syncClient.closeGracefully();
            } catch (Exception e) {
                log.debug("Close MCP SSE client gracefully failed, baseUrl={}", baseUrl, e);
            }
            try {
                syncClient.close();
            } catch (Exception e) {
                log.debug("Close MCP SSE client failed, baseUrl={}", baseUrl, e);
            }
        }
    }
}