| | |
| | | |
| | | import com.alibaba.fastjson.JSON; |
| | | import com.alibaba.fastjson.JSONObject; |
| | | import com.zy.ai.entity.AiMcpMount; |
| | | import com.zy.ai.enums.AiMcpTransportType; |
| | | import com.zy.ai.service.AiMcpMountService; |
| | | import io.modelcontextprotocol.client.McpClient; |
| | | import io.modelcontextprotocol.client.McpSyncClient; |
| | | import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; |
| | | import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; |
| | | import io.modelcontextprotocol.spec.McpSchema; |
| | | import lombok.RequiredArgsConstructor; |
| | | import lombok.extern.slf4j.Slf4j; |
| | | import org.springframework.ai.mcp.SyncMcpToolCallbackProvider; |
| | | import org.springframework.ai.tool.ToolCallback; |
| | | import org.springframework.ai.tool.ToolCallbackProvider; |
| | | import org.springframework.ai.tool.definition.DefaultToolDefinition; |
| | | import org.springframework.ai.tool.definition.ToolDefinition; |
| | | import org.springframework.beans.factory.annotation.Value; |
| | | import org.springframework.stereotype.Service; |
| | | |
| | | import jakarta.annotation.PreDestroy; |
| | | import java.net.URI; |
| | | import java.time.Duration; |
| | | import java.util.ArrayList; |
| | | import java.util.Collections; |
| | | import java.util.Comparator; |
| | | import java.util.LinkedHashMap; |
| | | import java.util.LinkedHashSet; |
| | | import java.util.List; |
| | | import java.util.Map; |
| | | |
| | | @Slf4j |
| | | @Service |
| | | @RequiredArgsConstructor |
| | | public class SpringAiMcpToolManager { |
| | | |
| | | private static final McpSchema.Implementation CLIENT_INFO = |
| | | new McpSchema.Implementation("wcs-ai-assistant", "1.0.0"); |
| | | |
| | | private final Object clientMonitor = new Object(); |
| | | |
| | | @Value("${server.port:9090}") |
| | | private Integer serverPort; |
| | | |
| | | @Value("${server.servlet.context-path:}") |
| | | private String contextPath; |
| | | |
| | | @Value("${spring.ai.mcp.server.sse-endpoint:/ai/mcp/sse}") |
| | | private String sseEndpoint; |
| | | private final AiMcpMountService aiMcpMountService; |
| | | |
| | | @Value("${spring.ai.mcp.server.request-timeout:20s}") |
| | | private Duration requestTimeout; |
| | | private Duration defaultRequestTimeout; |
| | | |
| | | @Value("${app.ai.mcp.client.base-url:}") |
| | | private String configuredBaseUrl; |
| | | |
| | | private volatile ClientSession clientSession; |
| | | private volatile ClientRegistry clientRegistry; |
| | | |
| | | public List<Map<String, Object>> listTools() { |
| | | List<Map<String, Object>> tools = new ArrayList<Map<String, Object>>(); |
| | | for (ToolCallback callback : getToolCallbacks()) { |
| | | if (callback == null || callback.getToolDefinition() == null) { |
| | | continue; |
| | | } |
| | | ToolDefinition definition = callback.getToolDefinition(); |
| | | Map<String, Object> item = new LinkedHashMap<String, Object>(); |
| | | item.put("name", definition.name()); |
| | | item.put("description", definition.description()); |
| | | item.put("inputSchema", parseSchema(definition.inputSchema())); |
| | | tools.add(item); |
| | | } |
| | | tools.sort(new Comparator<Map<String, Object>>() { |
| | | @Override |
| | | public int compare(Map<String, Object> left, Map<String, Object> right) { |
| | | return String.valueOf(left.get("name")).compareTo(String.valueOf(right.get("name"))); |
| | | } |
| | | }); |
| | | return tools; |
| | | return ensureClientRegistry().toolList; |
| | | } |
| | | |
| | | public List<Object> buildOpenAiTools() { |
| | |
| | | } |
| | | |
| | | public Object callTool(String toolName, JSONObject arguments) { |
| | | if (toolName == null || toolName.trim().isEmpty()) { |
| | | if (isBlank(toolName)) { |
| | | throw new IllegalArgumentException("missing tool name"); |
| | | } |
| | | |
| | | ToolCallback callback = findCallback(toolName); |
| | | if (callback == null) { |
| | | MountedTool mountedTool = ensureClientRegistry().toolMap.get(toolName); |
| | | if (mountedTool == null) { |
| | | throw new IllegalArgumentException("tool not found: " + toolName); |
| | | } |
| | | |
| | | String rawResult = callback.call(arguments == null ? "{}" : arguments.toJSONString()); |
| | | return parseToolResult(rawResult); |
| | | try { |
| | | String rawResult = mountedTool.callback.call(arguments == null ? "{}" : arguments.toJSONString()); |
| | | return parseToolResult(rawResult); |
| | | } catch (Exception e) { |
| | | evictCache(); |
| | | throw e; |
| | | } |
| | | } |
| | | |
| | | private ToolCallback findCallback(String toolName) { |
| | | for (ToolCallback callback : getToolCallbacks()) { |
| | | if (callback == null || callback.getToolDefinition() == null) { |
| | | continue; |
| | | } |
| | | if (toolName.equals(callback.getToolDefinition().name())) { |
| | | return callback; |
| | | public Map<String, Object> testMount(AiMcpMount mount) { |
| | | if (mount == null) { |
| | | throw new IllegalArgumentException("参数不能为空"); |
| | | } |
| | | MountSession session = null; |
| | | long start = System.currentTimeMillis(); |
| | | try { |
| | | session = openSession(mount); |
| | | LinkedHashMap<String, Object> result = new LinkedHashMap<String, Object>(); |
| | | result.put("ok", true); |
| | | result.put("message", "连接成功,已发现 " + session.callbacks.length + " 个工具"); |
| | | result.put("latencyMs", System.currentTimeMillis() - start); |
| | | result.put("url", session.url); |
| | | result.put("transportType", mount.getTransportType()); |
| | | result.put("toolCount", session.callbacks.length); |
| | | result.put("toolNames", collectToolNames(session.callbacks, mount, reservedToolNames(mount.getMountCode()))); |
| | | return result; |
| | | } catch (Exception e) { |
| | | TransportTarget target = resolveTransportTarget(mount); |
| | | LinkedHashMap<String, Object> result = new LinkedHashMap<String, Object>(); |
| | | result.put("ok", false); |
| | | result.put("message", safeMessage(e)); |
| | | result.put("latencyMs", System.currentTimeMillis() - start); |
| | | result.put("url", target.url); |
| | | result.put("transportType", mount.getTransportType()); |
| | | result.put("toolCount", 0); |
| | | result.put("toolNames", new ArrayList<String>()); |
| | | return result; |
| | | } finally { |
| | | if (session != null) { |
| | | session.close(); |
| | | } |
| | | } |
| | | return null; |
| | | } |
| | | |
| | | private ToolCallback[] getToolCallbacks() { |
| | | public void evictCache() { |
| | | resetClientRegistry(); |
| | | } |
| | | |
| | | private ClientRegistry ensureClientRegistry() { |
| | | ClientRegistry current = clientRegistry; |
| | | if (current != null) { |
| | | return current; |
| | | } |
| | | |
| | | synchronized (clientMonitor) { |
| | | current = clientRegistry; |
| | | if (current != null) { |
| | | return current; |
| | | } |
| | | |
| | | current = buildClientRegistry(); |
| | | clientRegistry = current; |
| | | return current; |
| | | } |
| | | } |
| | | |
| | | private ClientRegistry buildClientRegistry() { |
| | | List<AiMcpMount> mounts = loadEnabledMounts(); |
| | | List<MountSession> sessions = new ArrayList<MountSession>(); |
| | | LinkedHashMap<String, MountedTool> toolMap = new LinkedHashMap<String, MountedTool>(); |
| | | List<Map<String, Object>> toolList = new ArrayList<Map<String, Object>>(); |
| | | |
| | | for (AiMcpMount mount : mounts) { |
| | | if (mount == null) { |
| | | continue; |
| | | } |
| | | MountSession session = null; |
| | | try { |
| | | session = openSession(mount); |
| | | sessions.add(session); |
| | | for (ToolCallback callback : session.callbacks) { |
| | | MountedTool mountedTool = buildMountedTool(mount, callback, toolMap.keySet()); |
| | | toolMap.put(mountedTool.toolName, mountedTool); |
| | | toolList.add(toToolDescriptor(mountedTool)); |
| | | } |
| | | log.info("MCP mount loaded, mountCode={}, transport={}, url={}, tools={}", |
| | | mount.getMountCode(), mount.getTransportType(), session.url, session.callbacks.length); |
| | | } catch (Exception e) { |
| | | log.warn("Failed to load MCP mount, mountCode={}, transport={}, url={}", |
| | | mount.getMountCode(), mount.getTransportType(), resolveUrl(mount), e); |
| | | if (session != null) { |
| | | session.close(); |
| | | } |
| | | } |
| | | } |
| | | |
| | | toolList.sort(new Comparator<Map<String, Object>>() { |
| | | @Override |
| | | public int compare(Map<String, Object> left, Map<String, Object> right) { |
| | | return String.valueOf(left.get("name")).compareTo(String.valueOf(right.get("name"))); |
| | | } |
| | | }); |
| | | return new ClientRegistry(sessions, toolMap, toolList); |
| | | } |
| | | |
| | | private List<AiMcpMount> loadEnabledMounts() { |
| | | try { |
| | | ToolCallback[] callbacks = ensureToolCallbackProvider().getToolCallbacks(); |
| | | return callbacks == null ? new ToolCallback[0] : callbacks; |
| | | List<AiMcpMount> mounts = aiMcpMountService.listEnabledOrdered(); |
| | | if (mounts == null || mounts.isEmpty()) { |
| | | aiMcpMountService.initDefaultsIfMissing(); |
| | | mounts = aiMcpMountService.listEnabledOrdered(); |
| | | } |
| | | return mounts == null ? Collections.<AiMcpMount>emptyList() : mounts; |
| | | } catch (Exception e) { |
| | | log.warn("Failed to load MCP tools through SSE client, baseUrl={}, sseEndpoint={}", |
| | | resolveBaseUrl(), resolveClientSseEndpoint(), e); |
| | | resetClientSession(); |
| | | return new ToolCallback[0]; |
| | | log.warn("Failed to query MCP mount configuration", e); |
| | | return Collections.emptyList(); |
| | | } |
| | | } |
| | | |
| | | private MountedTool buildMountedTool(AiMcpMount mount, ToolCallback callback, java.util.Set<String> usedNames) { |
| | | ToolDefinition definition = callback == null ? null : callback.getToolDefinition(); |
| | | if (definition == null || isBlank(definition.name())) { |
| | | throw new IllegalArgumentException("invalid tool definition"); |
| | | } |
| | | |
| | | String originalName = definition.name(); |
| | | String preferredName = mount.getMountCode() + "_" + originalName; |
| | | String finalName = ensureUniqueToolName(preferredName, mount, originalName, usedNames); |
| | | ToolCallback effectiveCallback = originalName.equals(finalName) |
| | | ? callback |
| | | : new MountedToolCallback(finalName, callback); |
| | | |
| | | return new MountedTool(mount, originalName, finalName, effectiveCallback); |
| | | } |
| | | |
| | | private String ensureUniqueToolName(String preferredName, |
| | | AiMcpMount mount, |
| | | String originalName, |
| | | java.util.Set<String> usedNames) { |
| | | if (!usedNames.contains(preferredName)) { |
| | | return preferredName; |
| | | } |
| | | |
| | | String fallbackBase = mount.getMountCode() + "_" + originalName; |
| | | if (!usedNames.contains(fallbackBase)) { |
| | | log.warn("Duplicate MCP tool name detected, fallback rename applied, mountCode={}, originalName={}, finalName={}", |
| | | mount.getMountCode(), originalName, fallbackBase); |
| | | return fallbackBase; |
| | | } |
| | | |
| | | int index = 2; |
| | | String candidate = fallbackBase + "_" + index; |
| | | while (usedNames.contains(candidate)) { |
| | | index++; |
| | | candidate = fallbackBase + "_" + index; |
| | | } |
| | | log.warn("Duplicate MCP tool name detected, numbered rename applied, mountCode={}, originalName={}, finalName={}", |
| | | mount.getMountCode(), originalName, candidate); |
| | | return candidate; |
| | | } |
| | | |
| | | private Map<String, Object> toToolDescriptor(MountedTool mountedTool) { |
| | | ToolDefinition definition = mountedTool.callback.getToolDefinition(); |
| | | Map<String, Object> item = new LinkedHashMap<String, Object>(); |
| | | item.put("name", definition.name()); |
| | | item.put("originalName", mountedTool.originalName); |
| | | item.put("mountCode", mountedTool.mount.getMountCode()); |
| | | item.put("mountName", mountedTool.mount.getName()); |
| | | item.put("transportType", mountedTool.mount.getTransportType()); |
| | | item.put("description", definition.description()); |
| | | item.put("inputSchema", parseSchema(definition.inputSchema())); |
| | | return item; |
| | | } |
| | | |
| | | private List<String> collectToolNames(ToolCallback[] callbacks, AiMcpMount mount, java.util.Set<String> reservedNames) { |
| | | List<String> names = new ArrayList<String>(); |
| | | java.util.Set<String> used = new LinkedHashSet<String>(); |
| | | if (reservedNames != null && !reservedNames.isEmpty()) { |
| | | used.addAll(reservedNames); |
| | | } |
| | | if (callbacks == null) { |
| | | return names; |
| | | } |
| | | for (ToolCallback callback : callbacks) { |
| | | ToolDefinition definition = callback == null ? null : callback.getToolDefinition(); |
| | | if (definition == null || isBlank(definition.name())) { |
| | | continue; |
| | | } |
| | | String preferred = mount.getMountCode() + "_" + definition.name(); |
| | | String finalName = ensureUniqueToolName(preferred, mount, definition.name(), used); |
| | | used.add(finalName); |
| | | names.add(finalName); |
| | | } |
| | | return names; |
| | | } |
| | | |
| | | private java.util.Set<String> reservedToolNames(String mountCode) { |
| | | LinkedHashSet<String> reserved = new LinkedHashSet<String>(); |
| | | ClientRegistry current = clientRegistry; |
| | | if (current == null || current.toolMap == null || current.toolMap.isEmpty()) { |
| | | return reserved; |
| | | } |
| | | for (MountedTool mountedTool : current.toolMap.values()) { |
| | | if (mountedTool == null || mountedTool.mount == null) { |
| | | continue; |
| | | } |
| | | if (mountCode != null && mountCode.equals(mountedTool.mount.getMountCode())) { |
| | | continue; |
| | | } |
| | | reserved.add(mountedTool.toolName); |
| | | } |
| | | return reserved; |
| | | } |
| | | |
| | | private MountSession openSession(AiMcpMount mount) { |
| | | Duration timeout = resolveTimeout(mount); |
| | | TransportTarget target = resolveTransportTarget(mount); |
| | | String baseUrl = target.baseUrl; |
| | | String endpoint = target.endpoint; |
| | | AiMcpTransportType transportType = AiMcpTransportType.ofCode(mount.getTransportType()); |
| | | McpSyncClient syncClient; |
| | | |
| | | if (transportType == AiMcpTransportType.STREAMABLE_HTTP) { |
| | | HttpClientStreamableHttpTransport transport = HttpClientStreamableHttpTransport.builder(baseUrl) |
| | | .endpoint(endpoint) |
| | | .connectTimeout(timeout) |
| | | .build(); |
| | | syncClient = McpClient.sync(transport) |
| | | .clientInfo(CLIENT_INFO) |
| | | .requestTimeout(timeout) |
| | | .initializationTimeout(timeout) |
| | | .build(); |
| | | } else { |
| | | HttpClientSseClientTransport transport = HttpClientSseClientTransport.builder(baseUrl) |
| | | .sseEndpoint(endpoint) |
| | | .connectTimeout(timeout) |
| | | .build(); |
| | | syncClient = McpClient.sync(transport) |
| | | .clientInfo(CLIENT_INFO) |
| | | .requestTimeout(timeout) |
| | | .initializationTimeout(timeout) |
| | | .build(); |
| | | } |
| | | |
| | | syncClient.initialize(); |
| | | SyncMcpToolCallbackProvider callbackProvider = new SyncMcpToolCallbackProvider(syncClient); |
| | | ToolCallback[] callbacks = callbackProvider.getToolCallbacks(); |
| | | return new MountSession(mount, syncClient, target.url, callbacks == null ? new ToolCallback[0] : callbacks); |
| | | } |
| | | |
| | | private Duration resolveTimeout(AiMcpMount mount) { |
| | | if (mount != null && mount.getRequestTimeoutMs() != null && mount.getRequestTimeoutMs() > 0) { |
| | | return Duration.ofMillis(mount.getRequestTimeoutMs()); |
| | | } |
| | | return defaultRequestTimeout == null ? Duration.ofSeconds(20) : defaultRequestTimeout; |
| | | } |
| | | |
| | | private String resolveUrl(AiMcpMount mount) { |
| | | return trim(mount == null ? null : mount.getUrl()); |
| | | } |
| | | |
| | | private TransportTarget resolveTransportTarget(AiMcpMount mount) { |
| | | String rawUrl = resolveUrl(mount); |
| | | if (isBlank(rawUrl)) { |
| | | throw new IllegalArgumentException("missing url"); |
| | | } |
| | | try { |
| | | URI finalUri = URI.create(rawUrl); |
| | | String authority = finalUri.getRawAuthority(); |
| | | String scheme = finalUri.getScheme(); |
| | | if (isBlank(scheme) || isBlank(authority)) { |
| | | throw new IllegalArgumentException("invalid MCP url"); |
| | | } |
| | | String baseUrl = scheme + "://" + authority; |
| | | String endpoint = finalUri.getRawPath(); |
| | | if (isBlank(endpoint)) { |
| | | throw new IllegalArgumentException("missing MCP path"); |
| | | } |
| | | if (finalUri.getRawQuery() != null && !finalUri.getRawQuery().isEmpty()) { |
| | | endpoint = endpoint + "?" + finalUri.getRawQuery(); |
| | | } |
| | | return new TransportTarget(rawUrl, baseUrl, endpoint); |
| | | } catch (IllegalArgumentException e) { |
| | | throw e; |
| | | } catch (Exception e) { |
| | | throw new IllegalArgumentException("invalid url: " + rawUrl, e); |
| | | } |
| | | } |
| | | |
| | |
| | | return fallback; |
| | | } |
| | | |
| | | private ToolCallbackProvider ensureToolCallbackProvider() { |
| | | return ensureClientSession().toolCallbackProvider; |
| | | } |
| | | |
| | | private ClientSession ensureClientSession() { |
| | | ClientSession current = clientSession; |
| | | if (current != null) { |
| | | return current; |
| | | } |
| | | |
| | | synchronized (clientMonitor) { |
| | | current = clientSession; |
| | | if (current != null) { |
| | | return current; |
| | | } |
| | | |
| | | String baseUrl = resolveBaseUrl(); |
| | | String clientSseEndpoint = resolveClientSseEndpoint(); |
| | | HttpClientSseClientTransport transport = HttpClientSseClientTransport.builder(baseUrl) |
| | | .sseEndpoint(clientSseEndpoint) |
| | | .connectTimeout(requestTimeout) |
| | | .build(); |
| | | |
| | | McpSyncClient syncClient = McpClient.sync(transport) |
| | | .clientInfo(CLIENT_INFO) |
| | | .requestTimeout(requestTimeout) |
| | | .initializationTimeout(requestTimeout) |
| | | .build(); |
| | | syncClient.initialize(); |
| | | |
| | | SyncMcpToolCallbackProvider callbackProvider = new SyncMcpToolCallbackProvider(syncClient); |
| | | current = new ClientSession(syncClient, callbackProvider, baseUrl); |
| | | clientSession = current; |
| | | log.info("Spring AI MCP SSE client initialized, baseUrl={}, sseEndpoint={}, tools={}", |
| | | baseUrl, clientSseEndpoint, current.toolCallbackProvider.getToolCallbacks().length); |
| | | return current; |
| | | } |
| | | } |
| | | |
| | | private void resetClientSession() { |
| | | synchronized (clientMonitor) { |
| | | ClientSession current = clientSession; |
| | | clientSession = null; |
| | | if (current != null) { |
| | | current.close(); |
| | | } |
| | | } |
| | | } |
| | | |
| | | private String resolveBaseUrl() { |
| | | if (configuredBaseUrl != null && !configuredBaseUrl.trim().isEmpty()) { |
| | | return trimTrailingSlash(configuredBaseUrl.trim()); |
| | | } |
| | | StringBuilder url = new StringBuilder("http://127.0.0.1:"); |
| | | url.append(serverPort == null ? 9090 : serverPort); |
| | | return trimTrailingSlash(url.toString()); |
| | | } |
| | | |
| | | private String resolveClientSseEndpoint() { |
| | | String endpoint = normalizePath(sseEndpoint); |
| | | if (configuredBaseUrl != null && !configuredBaseUrl.trim().isEmpty()) { |
| | | return endpoint; |
| | | } |
| | | String context = normalizeContextPath(contextPath); |
| | | if (context.isEmpty()) { |
| | | return endpoint; |
| | | } |
| | | return context + endpoint; |
| | | } |
| | | |
| | | private String normalizeContextPath(String path) { |
| | | if (path == null || path.trim().isEmpty() || "/".equals(path.trim())) { |
| | | return ""; |
| | | } |
| | | String value = path.trim(); |
| | | if (!value.startsWith("/")) { |
| | | value = "/" + value; |
| | | } |
| | | return trimTrailingSlash(value); |
| | | } |
| | | |
| | | private String normalizePath(String path) { |
| | | if (path == null || path.trim().isEmpty()) { |
| | | if (isBlank(path)) { |
| | | return "/"; |
| | | } |
| | | String value = path.trim(); |
| | |
| | | return value; |
| | | } |
| | | |
| | | private String trimTrailingSlash(String value) { |
| | | if (value == null || value.isEmpty()) { |
| | | return ""; |
| | | private String safeMessage(Throwable throwable) { |
| | | if (throwable == null) { |
| | | return "unknown error"; |
| | | } |
| | | return value.endsWith("/") && value.length() > 1 ? value.substring(0, value.length() - 1) : value; |
| | | if (!isBlank(throwable.getMessage())) { |
| | | return throwable.getMessage(); |
| | | } |
| | | return throwable.getClass().getSimpleName(); |
| | | } |
| | | |
| | | private String trim(String text) { |
| | | return text == null ? null : text.trim(); |
| | | } |
| | | |
| | | private boolean isBlank(String text) { |
| | | return text == null || text.trim().isEmpty(); |
| | | } |
| | | |
| | | private void resetClientRegistry() { |
| | | synchronized (clientMonitor) { |
| | | ClientRegistry current = clientRegistry; |
| | | clientRegistry = null; |
| | | if (current != null) { |
| | | current.close(); |
| | | } |
| | | } |
| | | } |
| | | |
| | | @PreDestroy |
| | | public void destroy() { |
| | | resetClientSession(); |
| | | resetClientRegistry(); |
| | | } |
| | | |
| | | private static final class ClientSession implements AutoCloseable { |
| | | private static final class ClientRegistry implements AutoCloseable { |
| | | |
| | | private final List<MountSession> sessions; |
| | | private final Map<String, MountedTool> toolMap; |
| | | private final List<Map<String, Object>> toolList; |
| | | |
| | | private ClientRegistry(List<MountSession> sessions, |
| | | Map<String, MountedTool> toolMap, |
| | | List<Map<String, Object>> toolList) { |
| | | this.sessions = sessions; |
| | | this.toolMap = toolMap; |
| | | this.toolList = toolList; |
| | | } |
| | | |
| | | @Override |
| | | public void close() { |
| | | if (sessions == null) { |
| | | return; |
| | | } |
| | | for (MountSession session : sessions) { |
| | | if (session != null) { |
| | | session.close(); |
| | | } |
| | | } |
| | | } |
| | | } |
| | | |
| | | private static final class MountSession implements AutoCloseable { |
| | | |
| | | private final AiMcpMount mount; |
| | | private final McpSyncClient syncClient; |
| | | private final ToolCallbackProvider toolCallbackProvider; |
| | | private final String baseUrl; |
| | | private final String url; |
| | | private final ToolCallback[] callbacks; |
| | | |
| | | private ClientSession(McpSyncClient syncClient, ToolCallbackProvider toolCallbackProvider, String baseUrl) { |
| | | private MountSession(AiMcpMount mount, |
| | | McpSyncClient syncClient, |
| | | String url, |
| | | ToolCallback[] callbacks) { |
| | | this.mount = mount; |
| | | this.syncClient = syncClient; |
| | | this.toolCallbackProvider = toolCallbackProvider; |
| | | this.baseUrl = baseUrl; |
| | | this.url = url; |
| | | this.callbacks = callbacks; |
| | | } |
| | | |
| | | @Override |
| | |
| | | try { |
| | | syncClient.closeGracefully(); |
| | | } catch (Exception e) { |
| | | log.debug("Close MCP SSE client gracefully failed, baseUrl={}", baseUrl, e); |
| | | log.debug("Close MCP client gracefully failed, mountCode={}", mount == null ? null : mount.getMountCode(), e); |
| | | } |
| | | try { |
| | | syncClient.close(); |
| | | } catch (Exception e) { |
| | | log.debug("Close MCP SSE client failed, baseUrl={}", baseUrl, e); |
| | | log.debug("Close MCP client failed, mountCode={}", mount == null ? null : mount.getMountCode(), e); |
| | | } |
| | | } |
| | | } |
| | | |
| | | private static final class MountedTool { |
| | | |
| | | private final AiMcpMount mount; |
| | | private final String originalName; |
| | | private final String toolName; |
| | | private final ToolCallback callback; |
| | | |
| | | private MountedTool(AiMcpMount mount, String originalName, String toolName, ToolCallback callback) { |
| | | this.mount = mount; |
| | | this.originalName = originalName; |
| | | this.toolName = toolName; |
| | | this.callback = callback; |
| | | } |
| | | } |
| | | |
| | | private static final class MountedToolCallback implements ToolCallback { |
| | | |
| | | private final ToolCallback delegate; |
| | | private final ToolDefinition definition; |
| | | |
| | | private MountedToolCallback(String name, ToolCallback delegate) { |
| | | this.delegate = delegate; |
| | | ToolDefinition source = delegate.getToolDefinition(); |
| | | this.definition = DefaultToolDefinition.builder() |
| | | .name(name) |
| | | .description(source == null ? "" : source.description()) |
| | | .inputSchema(source == null ? "{}" : source.inputSchema()) |
| | | .build(); |
| | | } |
| | | |
| | | @Override |
| | | public ToolDefinition getToolDefinition() { |
| | | return definition; |
| | | } |
| | | |
| | | @Override |
| | | public org.springframework.ai.tool.metadata.ToolMetadata getToolMetadata() { |
| | | return delegate.getToolMetadata(); |
| | | } |
| | | |
| | | @Override |
| | | public String call(String toolInput) { |
| | | return delegate.call(toolInput); |
| | | } |
| | | |
| | | @Override |
| | | public String call(String toolInput, org.springframework.ai.chat.model.ToolContext toolContext) { |
| | | return delegate.call(toolInput, toolContext); |
| | | } |
| | | } |
| | | |
| | | private static final class TransportTarget { |
| | | |
| | | private final String url; |
| | | private final String baseUrl; |
| | | private final String endpoint; |
| | | |
| | | private TransportTarget(String url, String baseUrl, String endpoint) { |
| | | this.url = url; |
| | | this.baseUrl = baseUrl; |
| | | this.endpoint = endpoint; |
| | | } |
| | | } |
| | | } |