package com.zy.ai.service.impl; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.zy.ai.entity.AiMcpMount; import com.zy.ai.enums.AiMcpTransportType; import com.zy.ai.mapper.AiMcpMountMapper; import com.zy.ai.service.AiMcpMountService; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import java.net.URI; import java.util.ArrayList; import java.util.Date; import java.util.LinkedHashMap; import java.util.List; @Slf4j @Service("aiMcpMountService") public class AiMcpMountServiceImpl extends ServiceImpl implements AiMcpMountService { private static final int DEFAULT_TIMEOUT_MS = 20000; private static final int DEFAULT_PRIORITY = 100; private static final String DEFAULT_LOCAL_MOUNT_CODE = "wcs_local"; @Value("${spring.ai.mcp.server.sse-endpoint:/ai/mcp/sse}") private String defaultSseEndpoint; @Value("${spring.ai.mcp.server.streamable-http.mcp-endpoint:/ai/mcp}") private String defaultStreamableEndpoint; @Value("${app.ai.mcp.server.public-base-url:http://127.0.0.1:${server.port:9090}${server.servlet.context-path:}}") private String defaultLocalBaseUrl; @Override public List listOrdered() { return this.list(new QueryWrapper() .orderByAsc("priority") .orderByAsc("id")); } @Override public List listEnabledOrdered() { return this.list(new QueryWrapper() .eq("status", 1) .orderByAsc("priority") .orderByAsc("id")); } @Override @Transactional(rollbackFor = Exception.class) public AiMcpMount saveMount(AiMcpMount mount) { AiMcpMount candidate = prepareMountDraft(mount); Date now = new Date(); if (candidate.getId() == null) { candidate.setCreateTime(now); candidate.setUpdateTime(now); this.save(candidate); return candidate; } AiMcpMount db = this.getById(candidate.getId()); if (db == null) { throw new IllegalArgumentException("MCP挂载不存在"); } candidate.setCreateTime(db.getCreateTime()); candidate.setLastTestOk(db.getLastTestOk()); candidate.setLastTestTime(db.getLastTestTime()); candidate.setLastTestSummary(db.getLastTestSummary()); candidate.setUpdateTime(now); this.updateById(candidate); return candidate; } @Override @Transactional(rollbackFor = Exception.class) public boolean deleteMount(Long id) { if (id == null) { return false; } AiMcpMount db = this.getById(id); if (db == null) { return false; } return this.removeById(id); } @Override @Transactional(rollbackFor = Exception.class) public int initDefaultsIfMissing() { AiMcpMount existing = this.getOne(new QueryWrapper() .eq("mount_code", DEFAULT_LOCAL_MOUNT_CODE) .last("limit 1")); if (existing != null) { boolean changed = false; AiMcpTransportType transportType = AiMcpTransportType.ofCode(existing.getTransportType()); if (transportType == null) { existing.setTransportType(AiMcpTransportType.SSE.getCode()); changed = true; } String expectedEndpoint = defaultEndpoint(AiMcpTransportType.ofCode(existing.getTransportType())); String expectedUrl = defaultUrl(AiMcpTransportType.ofCode(existing.getTransportType())); if (isBlank(existing.getUrl()) || isLegacyLocalUrl(existing.getUrl(), expectedEndpoint)) { existing.setUrl(expectedUrl); changed = true; } if (changed) { existing.setUpdateTime(new Date()); this.updateById(existing); return 1; } return 0; } AiMcpMount mount = new AiMcpMount(); mount.setName("WCS默认MCP"); mount.setMountCode(DEFAULT_LOCAL_MOUNT_CODE); mount.setTransportType(AiMcpTransportType.SSE.getCode()); mount.setUrl(defaultUrl(AiMcpTransportType.SSE)); mount.setRequestTimeoutMs(DEFAULT_TIMEOUT_MS); mount.setPriority(0); mount.setStatus((short) 1); mount.setMemo("默认挂载当前WCS自身的MCP服务,AI助手也通过挂载配置访问本系统工具"); Date now = new Date(); mount.setCreateTime(now); mount.setUpdateTime(now); this.save(mount); return 1; } @Override public void recordTestResult(Long id, boolean ok, String summary) { if (id == null) { return; } AiMcpMount db = this.getById(id); if (db == null) { return; } db.setLastTestOk(ok ? (short) 1 : (short) 0); db.setLastTestTime(new Date()); db.setLastTestSummary(cut(trim(summary), 1000)); db.setUpdateTime(new Date()); this.updateById(db); } @Override public AiMcpMount prepareMountDraft(AiMcpMount mount) { if (mount == null) { throw new IllegalArgumentException("参数不能为空"); } AiMcpMount candidate = new AiMcpMount(); candidate.setId(mount.getId()); candidate.setName(trim(mount.getName())); candidate.setMountCode(normalizeIdentifier(mount.getMountCode())); candidate.setTransportType(normalizeTransportType(mount.getTransportType()).getCode()); candidate.setUrl(normalizeUrl(mount.getUrl(), AiMcpTransportType.ofCode(candidate.getTransportType()))); candidate.setRequestTimeoutMs(normalizeTimeout(mount.getRequestTimeoutMs())); candidate.setPriority(mount.getPriority() == null ? DEFAULT_PRIORITY : Math.max(0, mount.getPriority())); candidate.setStatus(normalizeShort(mount.getStatus(), (short) 1)); candidate.setMemo(cut(trim(mount.getMemo()), 1000)); if (isBlank(candidate.getMountCode())) { throw new IllegalArgumentException("必须填写挂载编码"); } if (isBlank(candidate.getName())) { candidate.setName(candidate.getMountCode()); } if (isBlank(candidate.getUrl())) { throw new IllegalArgumentException("必须填写URL"); } AiMcpMount duplicate = this.getOne(new QueryWrapper() .eq("mount_code", candidate.getMountCode()) .ne(candidate.getId() != null, "id", candidate.getId()) .last("limit 1")); if (duplicate != null) { throw new IllegalArgumentException("挂载编码已存在:" + candidate.getMountCode()); } return candidate; } @Override public List> listSupportedTransportTypes() { List> result = new ArrayList>(); for (AiMcpTransportType item : AiMcpTransportType.values()) { LinkedHashMap row = new LinkedHashMap(); row.put("code", item.getCode()); row.put("label", item.getLabel()); row.put("defaultUrl", defaultUrl(item)); result.add(row); } return result; } private AiMcpTransportType normalizeTransportType(String raw) { AiMcpTransportType transportType = AiMcpTransportType.ofCode(raw); if (transportType == null) { throw new IllegalArgumentException("不支持的MCP传输类型:" + raw); } return transportType; } private Short normalizeShort(Short value, Short defaultValue) { return value == null ? defaultValue : value; } private int normalizeTimeout(Integer requestTimeoutMs) { int timeout = requestTimeoutMs == null ? DEFAULT_TIMEOUT_MS : requestTimeoutMs; if (timeout < 1000) { timeout = 1000; } if (timeout > 300000) { timeout = 300000; } return timeout; } private String normalizeUrl(String url, AiMcpTransportType transportType) { String value = trim(url); if (isBlank(value)) { return defaultUrl(transportType); } while (value.endsWith("/") && value.length() > "http://x".length()) { value = value.substring(0, value.length() - 1); } if (!value.startsWith("http://") && !value.startsWith("https://")) { throw new IllegalArgumentException("URL必须以 http:// 或 https:// 开头"); } try { URI uri = URI.create(value); if (isBlank(uri.getScheme()) || isBlank(uri.getHost())) { throw new IllegalArgumentException("URL格式不正确"); } if (isBlank(uri.getPath()) || "/".equals(uri.getPath())) { throw new IllegalArgumentException("URL必须包含完整的MCP路径"); } return value; } catch (IllegalArgumentException e) { throw e; } catch (Exception e) { throw new IllegalArgumentException("URL格式不正确"); } } private String defaultEndpoint(AiMcpTransportType transportType) { if (transportType == AiMcpTransportType.STREAMABLE_HTTP) { String endpoint = trim(defaultStreamableEndpoint); if (isBlank(endpoint)) { endpoint = AiMcpTransportType.STREAMABLE_HTTP.getDefaultEndpoint(); } if (!endpoint.startsWith("/")) { endpoint = "/" + endpoint; } return endpoint; } String endpoint = trim(defaultSseEndpoint); if (isBlank(endpoint)) { endpoint = AiMcpTransportType.SSE.getDefaultEndpoint(); } if (!endpoint.startsWith("/")) { endpoint = "/" + endpoint; } return endpoint; } private String resolveDefaultLocalBaseUrl() { String value = trim(defaultLocalBaseUrl); if (isBlank(value)) { return null; } while (value.endsWith("/")) { value = value.substring(0, value.length() - 1); } return value; } private String defaultUrl(AiMcpTransportType transportType) { String baseUrl = resolveDefaultLocalBaseUrl(); String endpoint = defaultEndpoint(transportType); if (isBlank(baseUrl)) { return endpoint; } return baseUrl + endpoint; } private boolean isLegacyLocalUrl(String url, String expectedEndpoint) { String current = trim(url); String targetBase = resolveDefaultLocalBaseUrl(); if (isBlank(current) || isBlank(targetBase) || isBlank(expectedEndpoint)) { return false; } try { URI currentUri = URI.create(current); URI targetBaseUri = URI.create(targetBase); String currentPath = trim(currentUri.getPath()); if (isBlank(currentPath) || "/".equals(currentPath)) { return true; } String expectedPath = expectedEndpoint.startsWith("/") ? expectedEndpoint : ("/" + expectedEndpoint); if (currentPath.equals(expectedPath)) { return sameOrigin(currentUri, targetBaseUri); } String targetPath = trim(targetBaseUri.getPath()); if (isBlank(targetPath)) { return false; } String expectedFullPath = targetPath + expectedPath; return sameOrigin(currentUri, targetBaseUri) && currentPath.equals(expectedFullPath); } catch (Exception e) { log.warn("Failed to inspect MCP mount url for legacy migration, url={}", current, e); return false; } } private boolean sameOrigin(URI left, URI right) { return equalsIgnoreCase(left.getScheme(), right.getScheme()) && equalsIgnoreCase(left.getHost(), right.getHost()) && effectivePort(left) == effectivePort(right); } private boolean equalsIgnoreCase(String left, String right) { if (left == null) { return right == null; } return left.equalsIgnoreCase(right); } private int effectivePort(URI uri) { if (uri == null) { return -1; } if (uri.getPort() > 0) { return uri.getPort(); } String scheme = uri.getScheme(); if ("https".equalsIgnoreCase(scheme)) { return 443; } if ("http".equalsIgnoreCase(scheme)) { return 80; } return -1; } private String normalizeIdentifier(String text) { String value = trim(text); if (isBlank(value)) { return null; } value = value.replaceAll("[^0-9A-Za-z_]+", "_"); value = value.replaceAll("_+", "_"); value = value.replaceAll("^_+", "").replaceAll("_+$", ""); return value.toLowerCase(); } private String trim(String text) { return text == null ? null : text.trim(); } private String cut(String text, int maxLen) { if (text == null) { return null; } return text.length() > maxLen ? text.substring(0, maxLen) : text; } private boolean isBlank(String text) { return text == null || text.trim().isEmpty(); } }