#
Junjie
3 天以前 1b8a4677f362d234d834120deac4880d7ae89a50
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
package com.zy.ai.service.impl;
 
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.zy.ai.entity.AiMcpMount;
import com.zy.ai.enums.AiMcpTransportType;
import com.zy.ai.mapper.AiMcpMountMapper;
import com.zy.ai.service.AiMcpMountService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.regex.Pattern;
 
@Slf4j
@Service("aiMcpMountService")
public class AiMcpMountServiceImpl extends ServiceImpl<AiMcpMountMapper, AiMcpMount> implements AiMcpMountService {
 
    private static final int DEFAULT_TIMEOUT_MS = 20000;
    private static final int DEFAULT_PRIORITY = 100;
    private static final String DEFAULT_LOCAL_MOUNT_CODE = "wcs_local";
 
    @Value("${spring.ai.mcp.server.sse-endpoint:/ai/mcp/sse}")
    private String defaultSseEndpoint;
 
    @Value("${spring.ai.mcp.server.streamable-http.mcp-endpoint:/ai/mcp}")
    private String defaultStreamableEndpoint;
 
    @Value("${app.ai.mcp.server.public-base-url:http://127.0.0.1:${server.port:9090}${server.servlet.context-path:}}")
    private String defaultLocalBaseUrl;
 
    @Override
    public List<AiMcpMount> listOrdered() {
        return this.list(new QueryWrapper<AiMcpMount>()
                .orderByAsc("priority")
                .orderByAsc("id"));
    }
 
    @Override
    public List<AiMcpMount> listEnabledOrdered() {
        return this.list(new QueryWrapper<AiMcpMount>()
                .eq("status", 1)
                .orderByAsc("priority")
                .orderByAsc("id"));
    }
 
    @Override
    @Transactional(rollbackFor = Exception.class)
    public AiMcpMount saveMount(AiMcpMount mount) {
        AiMcpMount candidate = prepareMountDraft(mount);
        Date now = new Date();
        if (candidate.getId() == null) {
            candidate.setCreateTime(now);
            candidate.setUpdateTime(now);
            this.save(candidate);
            return candidate;
        }
 
        AiMcpMount db = this.getById(candidate.getId());
        if (db == null) {
            throw new IllegalArgumentException("MCP挂载不存在");
        }
        candidate.setCreateTime(db.getCreateTime());
        candidate.setLastTestOk(db.getLastTestOk());
        candidate.setLastTestTime(db.getLastTestTime());
        candidate.setLastTestSummary(db.getLastTestSummary());
        candidate.setUpdateTime(now);
        this.updateById(candidate);
        return candidate;
    }
 
    @Override
    @Transactional(rollbackFor = Exception.class)
    public boolean deleteMount(Long id) {
        if (id == null) {
            return false;
        }
        AiMcpMount db = this.getById(id);
        if (db == null) {
            return false;
        }
        return this.removeById(id);
    }
 
    @Override
    @Transactional(rollbackFor = Exception.class)
    public int initDefaultsIfMissing() {
        AiMcpMount existing = this.getOne(new QueryWrapper<AiMcpMount>()
                .eq("mount_code", DEFAULT_LOCAL_MOUNT_CODE)
                .last("limit 1"));
        if (existing != null) {
            boolean changed = false;
            AiMcpTransportType transportType = AiMcpTransportType.ofCode(existing.getTransportType());
            if (transportType == null) {
                existing.setTransportType(AiMcpTransportType.SSE.getCode());
                changed = true;
            }
            String expectedEndpoint = defaultEndpoint(AiMcpTransportType.ofCode(existing.getTransportType()));
            String expectedUrl = defaultUrl(AiMcpTransportType.ofCode(existing.getTransportType()));
            if (isBlank(existing.getUrl()) || isLegacyLocalUrl(existing.getUrl(), expectedEndpoint)) {
                existing.setUrl(expectedUrl);
                changed = true;
            }
            if (changed) {
                existing.setUpdateTime(new Date());
                this.updateById(existing);
                return 1;
            }
            return 0;
        }
 
        AiMcpMount mount = new AiMcpMount();
        mount.setName("WCS默认MCP");
        mount.setMountCode(DEFAULT_LOCAL_MOUNT_CODE);
        mount.setTransportType(AiMcpTransportType.SSE.getCode());
        mount.setUrl(defaultUrl(AiMcpTransportType.SSE));
        mount.setRequestTimeoutMs(DEFAULT_TIMEOUT_MS);
        mount.setPriority(0);
        mount.setStatus((short) 1);
        mount.setMemo("默认挂载当前WCS自身的MCP服务,AI助手也通过挂载配置访问本系统工具");
        Date now = new Date();
        mount.setCreateTime(now);
        mount.setUpdateTime(now);
        this.save(mount);
        return 1;
    }
 
    @Override
    public void recordTestResult(Long id, boolean ok, String summary) {
        if (id == null) {
            return;
        }
        AiMcpMount db = this.getById(id);
        if (db == null) {
            return;
        }
        db.setLastTestOk(ok ? (short) 1 : (short) 0);
        db.setLastTestTime(new Date());
        db.setLastTestSummary(cut(trim(summary), 1000));
        db.setUpdateTime(new Date());
        this.updateById(db);
    }
 
    @Override
    public AiMcpMount prepareMountDraft(AiMcpMount mount) {
        if (mount == null) {
            throw new IllegalArgumentException("参数不能为空");
        }
        AiMcpMount candidate = new AiMcpMount();
        candidate.setId(mount.getId());
        candidate.setName(trim(mount.getName()));
        candidate.setMountCode(normalizeIdentifier(mount.getMountCode()));
        candidate.setTransportType(normalizeTransportType(mount.getTransportType()).getCode());
        candidate.setUrl(normalizeUrl(mount.getUrl(), AiMcpTransportType.ofCode(candidate.getTransportType())));
        candidate.setRequestTimeoutMs(normalizeTimeout(mount.getRequestTimeoutMs()));
        candidate.setPriority(mount.getPriority() == null ? DEFAULT_PRIORITY : Math.max(0, mount.getPriority()));
        candidate.setStatus(normalizeShort(mount.getStatus(), (short) 1));
        candidate.setMemo(cut(trim(mount.getMemo()), 1000));
 
        if (isBlank(candidate.getMountCode())) {
            throw new IllegalArgumentException("必须填写挂载编码");
        }
        if (isBlank(candidate.getName())) {
            candidate.setName(candidate.getMountCode());
        }
        if (isBlank(candidate.getUrl())) {
            throw new IllegalArgumentException("必须填写URL");
        }
 
        AiMcpMount duplicate = this.getOne(new QueryWrapper<AiMcpMount>()
                .eq("mount_code", candidate.getMountCode())
                .ne(candidate.getId() != null, "id", candidate.getId())
                .last("limit 1"));
        if (duplicate != null) {
            throw new IllegalArgumentException("挂载编码已存在:" + candidate.getMountCode());
        }
        return candidate;
    }
 
    @Override
    public List<java.util.Map<String, Object>> listSupportedTransportTypes() {
        List<java.util.Map<String, Object>> result = new ArrayList<java.util.Map<String, Object>>();
        for (AiMcpTransportType item : AiMcpTransportType.values()) {
            LinkedHashMap<String, Object> row = new LinkedHashMap<String, Object>();
            row.put("code", item.getCode());
            row.put("label", item.getLabel());
            row.put("defaultUrl", defaultUrl(item));
            result.add(row);
        }
        return result;
    }
 
    private AiMcpTransportType normalizeTransportType(String raw) {
        AiMcpTransportType transportType = AiMcpTransportType.ofCode(raw);
        if (transportType == null) {
            throw new IllegalArgumentException("不支持的MCP传输类型:" + raw);
        }
        return transportType;
    }
 
    private Short normalizeShort(Short value, Short defaultValue) {
        return value == null ? defaultValue : value;
    }
 
    private int normalizeTimeout(Integer requestTimeoutMs) {
        int timeout = requestTimeoutMs == null ? DEFAULT_TIMEOUT_MS : requestTimeoutMs;
        if (timeout < 1000) {
            timeout = 1000;
        }
        if (timeout > 300000) {
            timeout = 300000;
        }
        return timeout;
    }
 
    private String normalizeUrl(String url, AiMcpTransportType transportType) {
        String value = trim(url);
        if (isBlank(value)) {
            return defaultUrl(transportType);
        }
        while (value.endsWith("/") && value.length() > "http://x".length()) {
            value = value.substring(0, value.length() - 1);
        }
        if (!value.startsWith("http://") && !value.startsWith("https://")) {
            throw new IllegalArgumentException("URL必须以 http:// 或 https:// 开头");
        }
        try {
            URI uri = URI.create(value);
            if (isBlank(uri.getScheme()) || isBlank(uri.getHost())) {
                throw new IllegalArgumentException("URL格式不正确");
            }
            if (isBlank(uri.getPath()) || "/".equals(uri.getPath())) {
                throw new IllegalArgumentException("URL必须包含完整的MCP路径");
            }
            return value;
        } catch (IllegalArgumentException e) {
            throw e;
        } catch (Exception e) {
            throw new IllegalArgumentException("URL格式不正确");
        }
    }
 
    private String defaultEndpoint(AiMcpTransportType transportType) {
        if (transportType == AiMcpTransportType.STREAMABLE_HTTP) {
            String endpoint = trim(defaultStreamableEndpoint);
            if (isBlank(endpoint)) {
                endpoint = AiMcpTransportType.STREAMABLE_HTTP.getDefaultEndpoint();
            }
            if (!endpoint.startsWith("/")) {
                endpoint = "/" + endpoint;
            }
            return endpoint;
        }
        String endpoint = trim(defaultSseEndpoint);
        if (isBlank(endpoint)) {
            endpoint = AiMcpTransportType.SSE.getDefaultEndpoint();
        }
        if (!endpoint.startsWith("/")) {
            endpoint = "/" + endpoint;
        }
        return endpoint;
    }
 
    private String resolveDefaultLocalBaseUrl() {
        String value = trim(defaultLocalBaseUrl);
        if (isBlank(value)) {
            return null;
        }
        while (value.endsWith("/")) {
            value = value.substring(0, value.length() - 1);
        }
        return value;
    }
 
    private String defaultUrl(AiMcpTransportType transportType) {
        String baseUrl = resolveDefaultLocalBaseUrl();
        String endpoint = defaultEndpoint(transportType);
        if (isBlank(baseUrl)) {
            return endpoint;
        }
        return baseUrl + endpoint;
    }
 
    private boolean isLegacyLocalUrl(String url, String expectedEndpoint) {
        String current = trim(url);
        String targetBase = resolveDefaultLocalBaseUrl();
        if (isBlank(current) || isBlank(targetBase) || isBlank(expectedEndpoint)) {
            return false;
        }
        try {
            URI currentUri = URI.create(current);
            URI targetBaseUri = URI.create(targetBase);
            String currentPath = trim(currentUri.getPath());
            if (isBlank(currentPath) || "/".equals(currentPath)) {
                return true;
            }
            String expectedPath = expectedEndpoint.startsWith("/") ? expectedEndpoint : ("/" + expectedEndpoint);
            if (currentPath.equals(expectedPath)) {
                return sameOrigin(currentUri, targetBaseUri);
            }
            String targetPath = trim(targetBaseUri.getPath());
            if (isBlank(targetPath)) {
                return false;
            }
            String expectedFullPath = targetPath + expectedPath;
            return sameOrigin(currentUri, targetBaseUri) && currentPath.equals(expectedFullPath);
        } catch (Exception e) {
            log.warn("Failed to inspect MCP mount url for legacy migration, url={}", current, e);
            return false;
        }
    }
 
    private boolean sameOrigin(URI left, URI right) {
        return equalsIgnoreCase(left.getScheme(), right.getScheme())
                && equalsIgnoreCase(left.getHost(), right.getHost())
                && effectivePort(left) == effectivePort(right);
    }
 
    private boolean equalsIgnoreCase(String left, String right) {
        if (left == null) {
            return right == null;
        }
        return left.equalsIgnoreCase(right);
    }
 
    private int effectivePort(URI uri) {
        if (uri == null) {
            return -1;
        }
        if (uri.getPort() > 0) {
            return uri.getPort();
        }
        String scheme = uri.getScheme();
        if ("https".equalsIgnoreCase(scheme)) {
            return 443;
        }
        if ("http".equalsIgnoreCase(scheme)) {
            return 80;
        }
        return -1;
    }
 
    private String normalizeIdentifier(String text) {
        String value = trim(text);
        if (isBlank(value)) {
            return null;
        }
        value = value.replaceAll("[^0-9A-Za-z_]+", "_");
        value = value.replaceAll("_+", "_");
        value = value.replaceAll("^_+", "").replaceAll("_+$", "");
        return value.toLowerCase();
    }
 
    private String trim(String text) {
        return text == null ? null : text.trim();
    }
 
    private String cut(String text, int maxLen) {
        if (text == null) {
            return null;
        }
        return text.length() > maxLen ? text.substring(0, maxLen) : text;
    }
 
    private boolean isBlank(String text) {
        return text == null || text.trim().isEmpty();
    }
 
}