package com.zy.ai.service.impl;
|
|
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.zy.ai.entity.AiMcpMount;
import com.zy.ai.enums.AiMcpTransportType;
import com.zy.ai.mapper.AiMcpMountMapper;
import com.zy.ai.service.AiMcpMountService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.regex.Pattern;
|
|
@Slf4j
|
@Service("aiMcpMountService")
|
public class AiMcpMountServiceImpl extends ServiceImpl<AiMcpMountMapper, AiMcpMount> implements AiMcpMountService {
|
|
private static final int DEFAULT_TIMEOUT_MS = 20000;
|
private static final int DEFAULT_PRIORITY = 100;
|
private static final String DEFAULT_LOCAL_MOUNT_CODE = "wcs_local";
|
|
@Value("${spring.ai.mcp.server.sse-endpoint:/ai/mcp/sse}")
|
private String defaultSseEndpoint;
|
|
@Value("${spring.ai.mcp.server.streamable-http.mcp-endpoint:/ai/mcp}")
|
private String defaultStreamableEndpoint;
|
|
@Value("${app.ai.mcp.server.public-base-url:http://127.0.0.1:${server.port:9090}${server.servlet.context-path:}}")
|
private String defaultLocalBaseUrl;
|
|
@Override
|
public List<AiMcpMount> listOrdered() {
|
return this.list(new QueryWrapper<AiMcpMount>()
|
.orderByAsc("priority")
|
.orderByAsc("id"));
|
}
|
|
@Override
|
public List<AiMcpMount> listEnabledOrdered() {
|
return this.list(new QueryWrapper<AiMcpMount>()
|
.eq("status", 1)
|
.orderByAsc("priority")
|
.orderByAsc("id"));
|
}
|
|
@Override
|
@Transactional(rollbackFor = Exception.class)
|
public AiMcpMount saveMount(AiMcpMount mount) {
|
AiMcpMount candidate = prepareMountDraft(mount);
|
Date now = new Date();
|
if (candidate.getId() == null) {
|
candidate.setCreateTime(now);
|
candidate.setUpdateTime(now);
|
this.save(candidate);
|
return candidate;
|
}
|
|
AiMcpMount db = this.getById(candidate.getId());
|
if (db == null) {
|
throw new IllegalArgumentException("MCP挂载不存在");
|
}
|
candidate.setCreateTime(db.getCreateTime());
|
candidate.setLastTestOk(db.getLastTestOk());
|
candidate.setLastTestTime(db.getLastTestTime());
|
candidate.setLastTestSummary(db.getLastTestSummary());
|
candidate.setUpdateTime(now);
|
this.updateById(candidate);
|
return candidate;
|
}
|
|
@Override
|
@Transactional(rollbackFor = Exception.class)
|
public boolean deleteMount(Long id) {
|
if (id == null) {
|
return false;
|
}
|
AiMcpMount db = this.getById(id);
|
if (db == null) {
|
return false;
|
}
|
return this.removeById(id);
|
}
|
|
@Override
|
@Transactional(rollbackFor = Exception.class)
|
public int initDefaultsIfMissing() {
|
AiMcpMount existing = this.getOne(new QueryWrapper<AiMcpMount>()
|
.eq("mount_code", DEFAULT_LOCAL_MOUNT_CODE)
|
.last("limit 1"));
|
if (existing != null) {
|
boolean changed = false;
|
AiMcpTransportType transportType = AiMcpTransportType.ofCode(existing.getTransportType());
|
if (transportType == null) {
|
existing.setTransportType(AiMcpTransportType.SSE.getCode());
|
changed = true;
|
}
|
String expectedEndpoint = defaultEndpoint(AiMcpTransportType.ofCode(existing.getTransportType()));
|
String expectedUrl = defaultUrl(AiMcpTransportType.ofCode(existing.getTransportType()));
|
if (isBlank(existing.getUrl()) || isLegacyLocalUrl(existing.getUrl(), expectedEndpoint)) {
|
existing.setUrl(expectedUrl);
|
changed = true;
|
}
|
if (changed) {
|
existing.setUpdateTime(new Date());
|
this.updateById(existing);
|
return 1;
|
}
|
return 0;
|
}
|
|
AiMcpMount mount = new AiMcpMount();
|
mount.setName("WCS默认MCP");
|
mount.setMountCode(DEFAULT_LOCAL_MOUNT_CODE);
|
mount.setTransportType(AiMcpTransportType.SSE.getCode());
|
mount.setUrl(defaultUrl(AiMcpTransportType.SSE));
|
mount.setRequestTimeoutMs(DEFAULT_TIMEOUT_MS);
|
mount.setPriority(0);
|
mount.setStatus((short) 1);
|
mount.setMemo("默认挂载当前WCS自身的MCP服务,AI助手也通过挂载配置访问本系统工具");
|
Date now = new Date();
|
mount.setCreateTime(now);
|
mount.setUpdateTime(now);
|
this.save(mount);
|
return 1;
|
}
|
|
@Override
|
public void recordTestResult(Long id, boolean ok, String summary) {
|
if (id == null) {
|
return;
|
}
|
AiMcpMount db = this.getById(id);
|
if (db == null) {
|
return;
|
}
|
db.setLastTestOk(ok ? (short) 1 : (short) 0);
|
db.setLastTestTime(new Date());
|
db.setLastTestSummary(cut(trim(summary), 1000));
|
db.setUpdateTime(new Date());
|
this.updateById(db);
|
}
|
|
@Override
|
public AiMcpMount prepareMountDraft(AiMcpMount mount) {
|
if (mount == null) {
|
throw new IllegalArgumentException("参数不能为空");
|
}
|
AiMcpMount candidate = new AiMcpMount();
|
candidate.setId(mount.getId());
|
candidate.setName(trim(mount.getName()));
|
candidate.setMountCode(normalizeIdentifier(mount.getMountCode()));
|
candidate.setTransportType(normalizeTransportType(mount.getTransportType()).getCode());
|
candidate.setUrl(normalizeUrl(mount.getUrl(), AiMcpTransportType.ofCode(candidate.getTransportType())));
|
candidate.setRequestTimeoutMs(normalizeTimeout(mount.getRequestTimeoutMs()));
|
candidate.setPriority(mount.getPriority() == null ? DEFAULT_PRIORITY : Math.max(0, mount.getPriority()));
|
candidate.setStatus(normalizeShort(mount.getStatus(), (short) 1));
|
candidate.setMemo(cut(trim(mount.getMemo()), 1000));
|
|
if (isBlank(candidate.getMountCode())) {
|
throw new IllegalArgumentException("必须填写挂载编码");
|
}
|
if (isBlank(candidate.getName())) {
|
candidate.setName(candidate.getMountCode());
|
}
|
if (isBlank(candidate.getUrl())) {
|
throw new IllegalArgumentException("必须填写URL");
|
}
|
|
AiMcpMount duplicate = this.getOne(new QueryWrapper<AiMcpMount>()
|
.eq("mount_code", candidate.getMountCode())
|
.ne(candidate.getId() != null, "id", candidate.getId())
|
.last("limit 1"));
|
if (duplicate != null) {
|
throw new IllegalArgumentException("挂载编码已存在:" + candidate.getMountCode());
|
}
|
return candidate;
|
}
|
|
@Override
|
public List<java.util.Map<String, Object>> listSupportedTransportTypes() {
|
List<java.util.Map<String, Object>> result = new ArrayList<java.util.Map<String, Object>>();
|
for (AiMcpTransportType item : AiMcpTransportType.values()) {
|
LinkedHashMap<String, Object> row = new LinkedHashMap<String, Object>();
|
row.put("code", item.getCode());
|
row.put("label", item.getLabel());
|
row.put("defaultUrl", defaultUrl(item));
|
result.add(row);
|
}
|
return result;
|
}
|
|
private AiMcpTransportType normalizeTransportType(String raw) {
|
AiMcpTransportType transportType = AiMcpTransportType.ofCode(raw);
|
if (transportType == null) {
|
throw new IllegalArgumentException("不支持的MCP传输类型:" + raw);
|
}
|
return transportType;
|
}
|
|
private Short normalizeShort(Short value, Short defaultValue) {
|
return value == null ? defaultValue : value;
|
}
|
|
private int normalizeTimeout(Integer requestTimeoutMs) {
|
int timeout = requestTimeoutMs == null ? DEFAULT_TIMEOUT_MS : requestTimeoutMs;
|
if (timeout < 1000) {
|
timeout = 1000;
|
}
|
if (timeout > 300000) {
|
timeout = 300000;
|
}
|
return timeout;
|
}
|
|
private String normalizeUrl(String url, AiMcpTransportType transportType) {
|
String value = trim(url);
|
if (isBlank(value)) {
|
return defaultUrl(transportType);
|
}
|
while (value.endsWith("/") && value.length() > "http://x".length()) {
|
value = value.substring(0, value.length() - 1);
|
}
|
if (!value.startsWith("http://") && !value.startsWith("https://")) {
|
throw new IllegalArgumentException("URL必须以 http:// 或 https:// 开头");
|
}
|
try {
|
URI uri = URI.create(value);
|
if (isBlank(uri.getScheme()) || isBlank(uri.getHost())) {
|
throw new IllegalArgumentException("URL格式不正确");
|
}
|
if (isBlank(uri.getPath()) || "/".equals(uri.getPath())) {
|
throw new IllegalArgumentException("URL必须包含完整的MCP路径");
|
}
|
return value;
|
} catch (IllegalArgumentException e) {
|
throw e;
|
} catch (Exception e) {
|
throw new IllegalArgumentException("URL格式不正确");
|
}
|
}
|
|
private String defaultEndpoint(AiMcpTransportType transportType) {
|
if (transportType == AiMcpTransportType.STREAMABLE_HTTP) {
|
String endpoint = trim(defaultStreamableEndpoint);
|
if (isBlank(endpoint)) {
|
endpoint = AiMcpTransportType.STREAMABLE_HTTP.getDefaultEndpoint();
|
}
|
if (!endpoint.startsWith("/")) {
|
endpoint = "/" + endpoint;
|
}
|
return endpoint;
|
}
|
String endpoint = trim(defaultSseEndpoint);
|
if (isBlank(endpoint)) {
|
endpoint = AiMcpTransportType.SSE.getDefaultEndpoint();
|
}
|
if (!endpoint.startsWith("/")) {
|
endpoint = "/" + endpoint;
|
}
|
return endpoint;
|
}
|
|
private String resolveDefaultLocalBaseUrl() {
|
String value = trim(defaultLocalBaseUrl);
|
if (isBlank(value)) {
|
return null;
|
}
|
while (value.endsWith("/")) {
|
value = value.substring(0, value.length() - 1);
|
}
|
return value;
|
}
|
|
private String defaultUrl(AiMcpTransportType transportType) {
|
String baseUrl = resolveDefaultLocalBaseUrl();
|
String endpoint = defaultEndpoint(transportType);
|
if (isBlank(baseUrl)) {
|
return endpoint;
|
}
|
return baseUrl + endpoint;
|
}
|
|
private boolean isLegacyLocalUrl(String url, String expectedEndpoint) {
|
String current = trim(url);
|
String targetBase = resolveDefaultLocalBaseUrl();
|
if (isBlank(current) || isBlank(targetBase) || isBlank(expectedEndpoint)) {
|
return false;
|
}
|
try {
|
URI currentUri = URI.create(current);
|
URI targetBaseUri = URI.create(targetBase);
|
String currentPath = trim(currentUri.getPath());
|
if (isBlank(currentPath) || "/".equals(currentPath)) {
|
return true;
|
}
|
String expectedPath = expectedEndpoint.startsWith("/") ? expectedEndpoint : ("/" + expectedEndpoint);
|
if (currentPath.equals(expectedPath)) {
|
return sameOrigin(currentUri, targetBaseUri);
|
}
|
String targetPath = trim(targetBaseUri.getPath());
|
if (isBlank(targetPath)) {
|
return false;
|
}
|
String expectedFullPath = targetPath + expectedPath;
|
return sameOrigin(currentUri, targetBaseUri) && currentPath.equals(expectedFullPath);
|
} catch (Exception e) {
|
log.warn("Failed to inspect MCP mount url for legacy migration, url={}", current, e);
|
return false;
|
}
|
}
|
|
private boolean sameOrigin(URI left, URI right) {
|
return equalsIgnoreCase(left.getScheme(), right.getScheme())
|
&& equalsIgnoreCase(left.getHost(), right.getHost())
|
&& effectivePort(left) == effectivePort(right);
|
}
|
|
private boolean equalsIgnoreCase(String left, String right) {
|
if (left == null) {
|
return right == null;
|
}
|
return left.equalsIgnoreCase(right);
|
}
|
|
private int effectivePort(URI uri) {
|
if (uri == null) {
|
return -1;
|
}
|
if (uri.getPort() > 0) {
|
return uri.getPort();
|
}
|
String scheme = uri.getScheme();
|
if ("https".equalsIgnoreCase(scheme)) {
|
return 443;
|
}
|
if ("http".equalsIgnoreCase(scheme)) {
|
return 80;
|
}
|
return -1;
|
}
|
|
private String normalizeIdentifier(String text) {
|
String value = trim(text);
|
if (isBlank(value)) {
|
return null;
|
}
|
value = value.replaceAll("[^0-9A-Za-z_]+", "_");
|
value = value.replaceAll("_+", "_");
|
value = value.replaceAll("^_+", "").replaceAll("_+$", "");
|
return value.toLowerCase();
|
}
|
|
private String trim(String text) {
|
return text == null ? null : text.trim();
|
}
|
|
private String cut(String text, int maxLen) {
|
if (text == null) {
|
return null;
|
}
|
return text.length() > maxLen ? text.substring(0, maxLen) : text;
|
}
|
|
private boolean isBlank(String text) {
|
return text == null || text.trim().isEmpty();
|
}
|
|
}
|