pom.xml
@@ -150,6 +150,10 @@ <artifactId>spring-ai-openai</artifactId> </dependency> <dependency> <groupId>org.springframework.ai</groupId> <artifactId>spring-ai-starter-mcp-server-webmvc</artifactId> </dependency> <dependency> <groupId>com.google.ortools</groupId> <artifactId>ortools-java</artifactId> <version>${ortools.version}</version> src/main/java/com/zy/ai/mcp/config/McpToolsBootstrap.java
File was deleted src/main/java/com/zy/ai/mcp/config/SpringAiMcpConfig.java
New file @@ -0,0 +1,17 @@
package com.zy.ai.mcp.config;

import com.zy.ai.mcp.tool.WcsMcpTools;
import org.springframework.ai.support.ToolCallbacks;
import org.springframework.ai.tool.StaticToolCallbackProvider;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

/**
 * Publishes the {@code @Tool}-annotated methods of {@link WcsMcpTools} as a Spring AI
 * {@link ToolCallbackProvider} bean named {@code wcsMcpToolCallbackProvider}.
 */
@Configuration
public class SpringAiMcpConfig {

    /**
     * Wraps every tool callback discovered on {@code wcsMcpTools} in a static provider.
     *
     * @param wcsMcpTools component whose annotated methods are turned into tool callbacks
     * @return a provider backed by the fixed array of discovered callbacks
     */
    @Bean("wcsMcpToolCallbackProvider")
    public ToolCallbackProvider wcsMcpToolCallbackProvider(WcsMcpTools wcsMcpTools) {
        var discoveredCallbacks = ToolCallbacks.from(wcsMcpTools);
        return new StaticToolCallbackProvider(discoveredCallbacks);
    }
}
src/main/java/com/zy/ai/mcp/config/SpringAiMcpTransportConfig.java
New file @@ -0,0 +1,113 @@
package com.zy.ai.mcp.config;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.json.jackson.JacksonMcpJsonMapper;
import io.modelcontextprotocol.server.McpServer;
import io.modelcontextprotocol.server.McpServerFeatures;
import io.modelcontextprotocol.server.McpSyncServer;
import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider;
import io.modelcontextprotocol.spec.McpSchema;
import org.springframework.ai.mcp.server.common.autoconfigure.properties.McpServerChangeNotificationProperties;
import org.springframework.ai.mcp.server.common.autoconfigure.properties.McpServerProperties;
import org.springframework.ai.mcp.server.common.autoconfigure.properties.McpServerSseProperties;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.env.Environment;
import org.springframework.web.context.support.StandardServletEnvironment;
import org.springframework.web.servlet.function.RouterFunction;
import org.springframework.web.servlet.function.ServerResponse;

import java.util.Collections;
import java.util.List;

/**
 * Builds a synchronous MCP server on a WebMvc SSE transport and exposes the transport's
 * router function. Settings come from the Spring AI MCP server, SSE and change-notification
 * properties.
 */
@Configuration
@EnableConfigurationProperties(McpServerSseProperties.class)
public class SpringAiMcpTransportConfig {

    /**
     * Creates the SSE transport and the synchronous MCP server bound to it.
     *
     * @param objectMapper                 Jackson mapper used for MCP JSON (bean {@code mcpServerObjectMapper})
     * @param sseProperties                SSE base URL, SSE/message endpoints and keep-alive interval
     * @param serverProperties             server name/version, capabilities, instructions, request timeout
     * @param changeNotificationProperties whether tool-list change notifications are advertised
     * @param syncToolsProvider            tool specifications (bean {@code syncTools}); empty list when absent
     * @param environment                  used to detect a servlet environment
     * @return holder that owns the transport and the server and closes the server in {@code close()}
     */
    @Bean("wcsOfficialSseMcpSupport")
    public OfficialSseMcpSupport wcsOfficialSseMcpSupport(
            @Qualifier("mcpServerObjectMapper") ObjectMapper objectMapper,
            McpServerSseProperties sseProperties,
            McpServerProperties serverProperties,
            McpServerChangeNotificationProperties changeNotificationProperties,
            @Qualifier("syncTools") ObjectProvider<List<McpServerFeatures.SyncToolSpecification>> syncToolsProvider,
            Environment environment) {
        WebMvcSseServerTransportProvider transportProvider = WebMvcSseServerTransportProvider.builder()
                .jsonMapper(new JacksonMcpJsonMapper(objectMapper))
                .baseUrl(sseProperties.getBaseUrl())
                .sseEndpoint(sseProperties.getSseEndpoint())
                .messageEndpoint(sseProperties.getSseMessageEndpoint())
                .keepAliveInterval(sseProperties.getKeepAliveInterval())
                .build();

        List<McpServerFeatures.SyncToolSpecification> syncTools =
                syncToolsProvider.getIfAvailable(Collections::emptyList);
        McpSyncServer mcpSyncServer = buildSseSyncServer(
                transportProvider,
                serverProperties,
                changeNotificationProperties,
                syncTools,
                environment
        );
        return new OfficialSseMcpSupport(transportProvider, mcpSyncServer);
    }

    /**
     * Exposes the SSE transport's routes (SSE endpoint and message endpoint) to Spring MVC.
     *
     * @param support holder created by {@link #wcsOfficialSseMcpSupport}
     * @return the transport's router function
     */
    @Bean("webMvcSseServerRouterFunction")
    public RouterFunction<ServerResponse> webMvcSseServerRouterFunction(
            @Qualifier("wcsOfficialSseMcpSupport") OfficialSseMcpSupport support) {
        return support.routerFunction();
    }

    /**
     * Builds the synchronous MCP server on top of the given transport.
     *
     * @param transportProvider            SSE transport the server is bound to
     * @param serverProperties             source of server info, capabilities, instructions and timeout
     * @param changeNotificationProperties source of the tool-list change-notification flag
     * @param syncTools                    tools to register; ignored when null, empty or tools are disabled
     * @param environment                  servlet environments enable immediate execution
     * @return the built server
     */
    private McpSyncServer buildSseSyncServer(
            WebMvcSseServerTransportProvider transportProvider,
            McpServerProperties serverProperties,
            McpServerChangeNotificationProperties changeNotificationProperties,
            List<McpServerFeatures.SyncToolSpecification> syncTools,
            Environment environment) {
        McpServer.SingleSessionSyncSpecification specification = McpServer.sync(transportProvider);
        specification.serverInfo(new McpSchema.Implementation(serverProperties.getName(), serverProperties.getVersion()));

        McpSchema.ServerCapabilities.Builder capabilitiesBuilder = McpSchema.ServerCapabilities.builder();
        if (serverProperties.getCapabilities().isTool()) {
            capabilitiesBuilder.tools(changeNotificationProperties.isToolChangeNotification());
            if (syncTools != null && !syncTools.isEmpty()) {
                specification.tools(syncTools);
            }
        }

        specification.capabilities(capabilitiesBuilder.build());
        specification.instructions(serverProperties.getInstructions());
        specification.requestTimeout(serverProperties.getRequestTimeout());
        // Only enabled for servlet environments; presumably runs handlers on the calling
        // thread instead of offloading them — confirm against the MCP SDK documentation.
        if (environment instanceof StandardServletEnvironment) {
            specification.immediateExecution(true);
        }
        return specification.build();
    }

    /**
     * Owns the SSE transport and the MCP server built on it, and shuts the server down on close.
     */
    public static final class OfficialSseMcpSupport implements AutoCloseable {

        private final WebMvcSseServerTransportProvider transportProvider;
        private final McpSyncServer mcpSyncServer;

        public OfficialSseMcpSupport(WebMvcSseServerTransportProvider transportProvider, McpSyncServer mcpSyncServer) {
            this.transportProvider = transportProvider;
            this.mcpSyncServer = mcpSyncServer;
        }

        /**
         * @return the router function serving the transport's endpoints
         */
        public RouterFunction<ServerResponse> routerFunction() {
            return transportProvider.getRouterFunction();
        }

        /**
         * Attempts a graceful shutdown, then closes the server. The hard {@code close()} runs
         * even when the graceful step throws, so the server is never left open on shutdown.
         */
        @Override
        public void close() {
            try {
                mcpSyncServer.closeGracefully();
            } finally {
                mcpSyncServer.close();
            }
        }
    }
}
src/main/java/com/zy/ai/mcp/controller/McpController.java
File was deleted src/main/java/com/zy/ai/mcp/dto/JsonRpcError.java
File was deleted src/main/java/com/zy/ai/mcp/dto/JsonRpcRequest.java
File was deleted src/main/java/com/zy/ai/mcp/dto/JsonRpcResponse.java
File was deleted src/main/java/com/zy/ai/mcp/dto/McpToolHandler.java
File was deleted src/main/java/com/zy/ai/mcp/dto/ToolDefinition.java
File was deleted src/main/java/com/zy/ai/mcp/dto/ToolRegistry.java
File was deleted src/main/java/com/zy/ai/mcp/service/SpringAiMcpToolManager.java
New file @@ -0,0 +1,296 @@
package com.zy.ai.mcp.service;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
import io.modelcontextprotocol.spec.McpSchema;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

import jakarta.annotation.PreDestroy;
import java.net.URI;
import java.net.URISyntaxException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * MCP client that connects over SSE to an MCP server and exposes the tools it advertises.
 * By default it targets this application's own server at {@code http://127.0.0.1:<server.port><context-path>};
 * the target can be overridden with {@code app.ai.mcp.client.base-url}.
 * Tools can be listed as plain maps or as OpenAI function definitions, and invoked by name.
 * The client session is created lazily on first use and discarded whenever loading tools fails.
 */
@Slf4j
@Service
public class SpringAiMcpToolManager {

    private static final McpSchema.Implementation CLIENT_INFO =
            new McpSchema.Implementation("wcs-ai-assistant", "1.0.0");

    /** Guards creation and teardown of {@link #clientSession}. */
    private final Object clientMonitor = new Object();

    @Value("${server.port:9090}")
    private Integer serverPort;

    @Value("${server.servlet.context-path:}")
    private String contextPath;

    @Value("${spring.ai.mcp.server.sse-endpoint:/ai/mcp/sse}")
    private String sseEndpoint;

    /** Used as connect, request and initialization timeout of the client. */
    @Value("${spring.ai.mcp.server.request-timeout:20s}")
    private Duration requestTimeout;

    /** Optional server location override; may contain a path such as a servlet context path. */
    @Value("${app.ai.mcp.client.base-url:}")
    private String configuredBaseUrl;

    /** Lazily created session; read without locking, written only while holding {@link #clientMonitor}. */
    private volatile ClientSession clientSession;

    /**
     * Lists the tools advertised by the MCP server, sorted by name.
     *
     * @return one map per tool with keys {@code name}, {@code description} and {@code inputSchema};
     *         empty when the tools cannot be loaded
     */
    public List<Map<String, Object>> listTools() {
        List<Map<String, Object>> tools = new ArrayList<Map<String, Object>>();
        for (ToolCallback callback : getToolCallbacks()) {
            if (callback == null || callback.getToolDefinition() == null) {
                continue;
            }
            ToolDefinition definition = callback.getToolDefinition();
            Map<String, Object> item = new LinkedHashMap<String, Object>();
            item.put("name", definition.name());
            item.put("description", definition.description());
            item.put("inputSchema", parseSchema(definition.inputSchema()));
            tools.add(item);
        }
        tools.sort(Comparator.comparing(entry -> String.valueOf(entry.get("name"))));
        return tools;
    }

    /**
     * Converts {@link #listTools()} into OpenAI-style function tools of the form
     * {@code {"type":"function","function":{"name","description","parameters"}}}.
     *
     * @return one entry per named tool; empty when no tools are available
     */
    public List<Object> buildOpenAiTools() {
        List<Object> tools = new ArrayList<Object>();
        for (Map<String, Object> item : listTools()) {
            Object name = item.get("name");
            if (name == null) {
                continue;
            }
            Map<String, Object> function = new LinkedHashMap<String, Object>();
            function.put("name", String.valueOf(name));
            Object description = item.get("description");
            if (description != null) {
                function.put("description", String.valueOf(description));
            }
            Object inputSchema = item.get("inputSchema");
            function.put("parameters", inputSchema == null ? new LinkedHashMap<String, Object>() : inputSchema);

            Map<String, Object> tool = new LinkedHashMap<String, Object>();
            tool.put("type", "function");
            tool.put("function", function);
            tools.add(tool);
        }
        return tools;
    }

    /**
     * Invokes an MCP tool by name.
     *
     * @param toolName  tool name as advertised by the server
     * @param arguments tool arguments; {@code null} is sent as {@code {}}
     * @return the tool output parsed as JSON when possible, otherwise the raw string
     * @throws IllegalArgumentException if {@code toolName} is blank or no such tool is advertised
     */
    public Object callTool(String toolName, JSONObject arguments) {
        if (toolName == null || toolName.trim().isEmpty()) {
            throw new IllegalArgumentException("missing tool name");
        }
        ToolCallback callback = findCallback(toolName);
        if (callback == null) {
            throw new IllegalArgumentException("tool not found: " + toolName);
        }
        String rawResult = callback.call(arguments == null ? "{}" : arguments.toJSONString());
        return parseToolResult(rawResult);
    }

    /** Returns the callback whose tool definition name equals {@code toolName}, or {@code null}. */
    private ToolCallback findCallback(String toolName) {
        for (ToolCallback callback : getToolCallbacks()) {
            if (callback == null || callback.getToolDefinition() == null) {
                continue;
            }
            if (toolName.equals(callback.getToolDefinition().name())) {
                return callback;
            }
        }
        return null;
    }

    /**
     * Loads tool callbacks through the (lazily created) client session. On any failure the
     * session is discarded so the next call reconnects, and an empty array is returned.
     */
    private ToolCallback[] getToolCallbacks() {
        try {
            ToolCallback[] callbacks = ensureToolCallbackProvider().getToolCallbacks();
            return callbacks == null ? new ToolCallback[0] : callbacks;
        } catch (Exception e) {
            log.warn("Failed to load MCP tools through SSE client, baseUrl={}, sseEndpoint={}",
                    resolveBaseUrl(), resolveClientSseEndpoint(), e);
            resetClientSession();
            return new ToolCallback[0];
        }
    }

    /** Parses {@code rawResult} as JSON; returns it unchanged when blank or not valid JSON. */
    private Object parseToolResult(String rawResult) {
        if (rawResult == null || rawResult.trim().isEmpty()) {
            return rawResult;
        }
        try {
            return JSON.parse(rawResult);
        } catch (Exception ignore) {
            // Not JSON: plain-text tool output is returned as-is.
            return rawResult;
        }
    }

    /**
     * Parses a tool input schema into a mutable map.
     * A blank schema yields an (immutable) empty map. A schema that is not a JSON object,
     * or cannot be parsed, yields {@code {"type":"object"}}.
     */
    @SuppressWarnings("unchecked")
    private Map<String, Object> parseSchema(String inputSchema) {
        if (inputSchema == null || inputSchema.trim().isEmpty()) {
            return Collections.emptyMap();
        }
        try {
            Object parsed = JSON.parse(inputSchema);
            if (parsed instanceof Map) {
                return new LinkedHashMap<String, Object>((Map<String, Object>) parsed);
            }
        } catch (Exception e) {
            log.warn("Failed to parse MCP tool schema: {}", inputSchema, e);
        }
        Map<String, Object> fallback = new LinkedHashMap<String, Object>();
        fallback.put("type", "object");
        return fallback;
    }

    private ToolCallbackProvider ensureToolCallbackProvider() {
        return ensureClientSession().toolCallbackProvider;
    }

    /**
     * Returns the current session, creating and initializing one on first use.
     * The volatile field is checked without locking, then re-checked under {@link #clientMonitor}.
     * If initialization fails, the half-built client is closed before the exception propagates.
     */
    private ClientSession ensureClientSession() {
        ClientSession current = clientSession;
        if (current != null) {
            return current;
        }
        synchronized (clientMonitor) {
            current = clientSession;
            if (current != null) {
                return current;
            }
            String baseUrl = resolveBaseUrl();
            String clientSseEndpoint = resolveClientSseEndpoint();
            HttpClientSseClientTransport transport = HttpClientSseClientTransport.builder(baseUrl)
                    .sseEndpoint(clientSseEndpoint)
                    .connectTimeout(requestTimeout)
                    .build();
            McpSyncClient syncClient = McpClient.sync(transport)
                    .clientInfo(CLIENT_INFO)
                    .requestTimeout(requestTimeout)
                    .initializationTimeout(requestTimeout)
                    .build();
            SyncMcpToolCallbackProvider callbackProvider;
            try {
                syncClient.initialize();
                callbackProvider = new SyncMcpToolCallbackProvider(syncClient);
            } catch (RuntimeException e) {
                // The client may already hold an SSE connection; release it so repeated
                // failed attempts (each getToolCallbacks() call retries) do not leak connections.
                closeQuietly(syncClient, baseUrl);
                throw e;
            }
            current = new ClientSession(syncClient, callbackProvider, baseUrl);
            clientSession = current;
            log.info("Spring AI MCP SSE client initialized, baseUrl={}, sseEndpoint={}, tools={}",
                    baseUrl, clientSseEndpoint, current.toolCallbackProvider.getToolCallbacks().length);
            return current;
        }
    }

    /** Discards and closes the current session, if any; the next use creates a new one. */
    private void resetClientSession() {
        synchronized (clientMonitor) {
            ClientSession current = clientSession;
            clientSession = null;
            if (current != null) {
                current.close();
            }
        }
    }

    /**
     * Returns the origin ({@code scheme://authority}) the SSE transport is built with.
     * Any path in the configured base URL is moved into the endpoint by
     * {@link #resolveClientSseEndpoint()}. The transport resolves the endpoint against this base,
     * and a leading-{@code /} endpoint would otherwise replace the base URL's path.
     */
    private String resolveBaseUrl() {
        if (hasConfiguredBaseUrl()) {
            URI configured = parseConfiguredBaseUrl();
            if (configured != null) {
                return configured.getScheme() + "://" + configured.getRawAuthority();
            }
            return trimTrailingSlash(configuredBaseUrl.trim());
        }
        return "http://127.0.0.1:" + (serverPort == null ? 9090 : serverPort);
    }

    /**
     * Returns the absolute SSE endpoint path. The endpoint is prefixed with the configured base
     * URL's path when a base URL is set, and with the servlet context path otherwise.
     */
    private String resolveClientSseEndpoint() {
        String endpoint = normalizePath(sseEndpoint);
        String prefix;
        if (hasConfiguredBaseUrl()) {
            URI configured = parseConfiguredBaseUrl();
            prefix = configured == null ? "" : normalizeContextPath(configured.getRawPath());
        } else {
            prefix = normalizeContextPath(contextPath);
        }
        return prefix.isEmpty() ? endpoint : prefix + endpoint;
    }

    private boolean hasConfiguredBaseUrl() {
        return configuredBaseUrl != null && !configuredBaseUrl.trim().isEmpty();
    }

    /**
     * Parses {@code app.ai.mcp.client.base-url}.
     *
     * @return the URI when it is absolute with an authority; {@code null} when the property is
     *         unset or unparseable (callers then fall back to using the raw value verbatim)
     */
    private URI parseConfiguredBaseUrl() {
        if (!hasConfiguredBaseUrl()) {
            return null;
        }
        String raw = configuredBaseUrl.trim();
        try {
            URI uri = new URI(raw);
            if (uri.getScheme() != null && uri.getRawAuthority() != null) {
                return uri;
            }
        } catch (URISyntaxException e) {
            log.warn("Invalid MCP client base URL, using it verbatim: {}", raw, e);
        }
        return null;
    }

    /** Normalizes a context path to {@code ""} or {@code /segment...} without a trailing slash. */
    private String normalizeContextPath(String path) {
        if (path == null || path.trim().isEmpty() || "/".equals(path.trim())) {
            return "";
        }
        String value = path.trim();
        if (!value.startsWith("/")) {
            value = "/" + value;
        }
        return trimTrailingSlash(value);
    }

    /** Ensures a leading slash; blank input becomes {@code "/"}. */
    private String normalizePath(String path) {
        if (path == null || path.trim().isEmpty()) {
            return "/";
        }
        String value = path.trim();
        if (!value.startsWith("/")) {
            value = "/" + value;
        }
        return value;
    }

    /** Removes one trailing slash (but keeps a lone {@code "/"}); null/empty becomes {@code ""}. */
    private String trimTrailingSlash(String value) {
        if (value == null || value.isEmpty()) {
            return "";
        }
        return value.endsWith("/") && value.length() > 1 ? value.substring(0, value.length() - 1) : value;
    }

    @PreDestroy
    public void destroy() {
        resetClientSession();
    }

    /** Closes the client gracefully, then forcibly; failures are logged at debug level and ignored. */
    private static void closeQuietly(McpSyncClient syncClient, String baseUrl) {
        try {
            syncClient.closeGracefully();
        } catch (Exception e) {
            log.debug("Close MCP SSE client gracefully failed, baseUrl={}", baseUrl, e);
        }
        try {
            syncClient.close();
        } catch (Exception e) {
            log.debug("Close MCP SSE client failed, baseUrl={}", baseUrl, e);
        }
    }

    /** An initialized MCP client together with the tool callback provider built on it. */
    private static final class ClientSession implements AutoCloseable {

        private final McpSyncClient syncClient;
        private final ToolCallbackProvider toolCallbackProvider;
        private final String baseUrl;

        private ClientSession(McpSyncClient syncClient, ToolCallbackProvider toolCallbackProvider, String baseUrl) {
            this.syncClient = syncClient;
            this.toolCallbackProvider = toolCallbackProvider;
            this.baseUrl = baseUrl;
        }

        @Override
        public void close() {
            closeQuietly(syncClient, baseUrl);
        }
    }
}
src/main/java/com/zy/ai/mcp/tool/WcsMcpTools.java
New file @@ -0,0 +1,73 @@
package com.zy.ai.mcp.tool;

import com.alibaba.fastjson.JSONObject;
import com.zy.ai.mcp.service.WcsDataFacade;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.stereotype.Component;

import java.util.List;

/**
 * Spring AI {@code @Tool} adapters over {@link WcsDataFacade}.
 * Each method packs its (possibly null) arguments into a fresh {@link JSONObject}
 * and forwards it to the matching facade method.
 */
@Component
@RequiredArgsConstructor
public class WcsMcpTools {

    private final WcsDataFacade wcsDataFacade;

    @Tool(name = "device_get_crn_status", description = "通过堆垛机编号查询堆垛机设备实时数据")
    public Object getCrnDeviceStatus(
            @ToolParam(description = "堆垛机编号列表,不传则查询全部堆垛机", required = false) List<Integer> crnNos) {
        JSONObject request = new JSONObject();
        request.put("crnNos", crnNos);
        return wcsDataFacade.getCrnDeviceStatus(request);
    }

    @Tool(name = "device_get_station_status", description = "查询输送线站点设备实时数据")
    public Object getStationDeviceStatus() {
        JSONObject request = new JSONObject();
        return wcsDataFacade.getStationDeviceStatus(request);
    }

    @Tool(name = "device_get_rgv_status", description = "通过RGV编号查询RGV设备实时数据")
    public Object getRgvDeviceStatus(
            @ToolParam(description = "RGV编号列表,不传则查询全部RGV", required = false) List<Integer> rgvNos) {
        JSONObject request = new JSONObject();
        request.put("rgvNos", rgvNos);
        return wcsDataFacade.getRgvDeviceStatus(request);
    }

    @Tool(name = "task_list", description = "通过筛选条件查询任务数据")
    public Object getTasks(
            @ToolParam(description = "堆垛机编号", required = false) Integer crnNo,
            @ToolParam(description = "RGV编号", required = false) Integer rgvNo,
            @ToolParam(description = "任务单号列表", required = false) List<Integer> taskNos,
            @ToolParam(description = "返回条数上限,默认 200", required = false) Integer limit) {
        JSONObject request = new JSONObject();
        request.put("crnNo", crnNo);
        request.put("rgvNo", rgvNo);
        request.put("taskNos", taskNos);
        request.put("limit", limit);
        return wcsDataFacade.getTasks(request);
    }

    @Tool(name = "log_query", description = "通过筛选条件查询 AI 日志数据")
    public Object getLogs(
            @ToolParam(description = "返回日志行数上限,默认 500", required = false) Integer limit) {
        JSONObject request = new JSONObject();
        request.put("limit", limit);
        return wcsDataFacade.getLogs(request);
    }

    @Tool(name = "config_get_device_config", description = "通过设备编号查询设备配置数据")
    public Object getDeviceConfig(
            @ToolParam(description = "堆垛机编号列表", required = false) List<Integer> crnNos,
            @ToolParam(description = "RGV编号列表", required = false) List<Integer> rgvNos,
            @ToolParam(description = "输送线编号列表", required = false) List<Integer> devpNos) {
        JSONObject request = new JSONObject();
        request.put("crnNos", crnNos);
        request.put("rgvNos", rgvNos);
        request.put("devpNos", devpNos);
        return wcsDataFacade.getDeviceConfig(request);
    }

    @Tool(name = "config_get_system_config", description = "查询系统配置数据")
    public Object getSystemConfig() {
        JSONObject request = new JSONObject();
        return wcsDataFacade.getSystemConfig(request);
    }
}
src/main/java/com/zy/ai/service/WcsDiagnosisService.java
@@ -5,7 +5,7 @@ import com.zy.ai.entity.ChatCompletionRequest; import com.zy.ai.entity.ChatCompletionResponse; import com.zy.ai.entity.WcsDiagnosisRequest; import com.zy.ai.mcp.controller.McpController; import com.zy.ai.mcp.service.SpringAiMcpToolManager; import com.zy.ai.utils.AiPromptUtils; import com.zy.ai.utils.AiUtils; import com.zy.common.utils.RedisUtil; @@ -36,7 +36,7 @@ @Autowired private AiUtils aiUtils; @Autowired(required = false) private McpController mcpController; private SpringAiMcpToolManager mcpToolManager; public void diagnoseStream(WcsDiagnosisRequest request, SseEmitter emitter) { List<ChatCompletionRequest.Message> messages = new ArrayList<>(); @@ -257,8 +257,8 @@ SseEmitter emitter, String chatId) { try { if (mcpController == null) return false; List<Object> tools = buildOpenAiTools(); if (mcpToolManager == null) return false; List<Object> tools = mcpToolManager.buildOpenAiTools(); if (tools.isEmpty()) return false; baseMessages.add(systemPrompt); @@ -303,7 +303,7 @@ } Object output; try { output = mcpController.callTool(toolName, args); output = mcpToolManager.callTool(toolName, args); } catch (Exception e) { java.util.LinkedHashMap<String, Object> err = new java.util.LinkedHashMap<String, Object>(); err.put("tool", toolName); @@ -386,31 +386,6 @@ } catch (Exception e) { log.warn("SSE send failed", e); } } private List<Object> buildOpenAiTools() { if (mcpController == null) return java.util.Collections.emptyList(); List<Map<String, Object>> mcpTools = mcpController.listTools(); if (mcpTools == null || mcpTools.isEmpty()) return java.util.Collections.emptyList(); List<Object> tools = new ArrayList<>(); for (Map<String, Object> t : mcpTools) { if (t == null) continue; Object name = t.get("name"); if (name == null) continue; Object inputSchema = t.get("inputSchema"); java.util.LinkedHashMap<String, Object> function = new java.util.LinkedHashMap<String, Object>(); function.put("name", String.valueOf(name)); Object desc = t.get("description"); if 
(desc != null) function.put("description", String.valueOf(desc)); function.put("parameters", inputSchema == null ? new java.util.LinkedHashMap<String, Object>() : inputSchema); java.util.LinkedHashMap<String, Object> tool = new java.util.LinkedHashMap<String, Object>(); tool.put("type", "function"); tool.put("function", function); tools.add(tool); } return tools; } private void sendLargeText(SseEmitter emitter, String text) { src/main/java/com/zy/common/config/AdminInterceptor.java
@@ -64,6 +64,9 @@ } // 跨域设置 // response.setHeader("Access-Control-Allow-Origin", "*"); if (!(handler instanceof HandlerMethod)) { return true; } HandlerMethod handlerMethod = (HandlerMethod) handler; Method method = handlerMethod.getMethod(); if (method.isAnnotationPresent(ManagerAuth.class)){ src/main/java/com/zy/common/config/CoolExceptionHandler.java
@@ -7,7 +7,6 @@ import org.springframework.web.HttpRequestMethodNotSupportedException; import org.springframework.web.bind.annotation.ExceptionHandler; import org.springframework.web.bind.annotation.RestControllerAdvice; import org.springframework.web.method.HandlerMethod; /** * Created by vincent on 2019-06-09 @@ -19,7 +18,7 @@ private I18nMessageService i18nMessageService; @ExceptionHandler(Exception.class) public R handlerException(HandlerMethod handler, Exception e) { public R handlerException(Exception e) { e.printStackTrace(); return R.error(i18nMessageService.getMessage("response.common.systemError")); } src/main/resources/application.yml
@@ -51,6 +51,26 @@ await-termination-period: 30s lifecycle: timeout-per-shutdown-phase: 20s ai: mcp: server: base-url: "${app.ai.mcp.server.public-base-url:http://127.0.0.1:${server.port:9090}${server.servlet.context-path:}}" name: wcs-mcp version: 1.0.0 protocol: STREAMABLE type: SYNC sse-endpoint: /ai/mcp/sse sse-message-endpoint: /ai/mcp/message streamable-http: mcp-endpoint: /ai/mcp instructions: 提供 WCS 设备状态、任务、日志和配置查询能力 annotation-scanner: enabled: false capabilities: tool: true resource: false prompt: false completion: false mybatis-plus: mapper-locations: classpath:mapper/*.xml