zhou zhou
8 小时以前 3d81df739dc45599c257d8cdefe0996f66ccdeae
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
package com.vincent.rsf.server.ai.service.impl;
 
import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.entity.AiMcpMount;
import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry;
import com.vincent.rsf.server.ai.tool.RsfWmsBaseTools;
import com.vincent.rsf.server.ai.tool.RsfWmsStockTools;
import com.vincent.rsf.server.ai.tool.RsfWmsTaskTools;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.support.ToolCallbacks;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
 
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
 
@Service
@RequiredArgsConstructor
public class BuiltinMcpToolRegistryImpl implements BuiltinMcpToolRegistry {
 
    private final RsfWmsStockTools rsfWmsStockTools;
    private final RsfWmsTaskTools rsfWmsTaskTools;
    private final RsfWmsBaseTools rsfWmsBaseTools;
 
    @Override
    public void validateBuiltinCode(String builtinCode) {
        if (!StringUtils.hasText(builtinCode)) {
            throw new CoolException("内置 MCP 编码不能为空");
        }
        if (!supportedBuiltinCodes().contains(builtinCode)) {
            throw new CoolException("不支持的内置 MCP 编码: " + builtinCode);
        }
    }
 
    @Override
    public List<ToolCallback> createToolCallbacks(AiMcpMount mount, Long userId) {
        String builtinCode = mount.getBuiltinCode();
        validateBuiltinCode(builtinCode);
        if (AiDefaults.MCP_BUILTIN_RSF_WMS.equals(builtinCode)) {
            List<ToolCallback> callbacks = new ArrayList<>();
            callbacks.addAll(Arrays.asList(ToolCallbacks.from(rsfWmsStockTools)));
            callbacks.addAll(Arrays.asList(ToolCallbacks.from(rsfWmsTaskTools)));
            callbacks.addAll(Arrays.asList(ToolCallbacks.from(rsfWmsBaseTools)));
            return callbacks;
        }
        if (AiDefaults.MCP_BUILTIN_RSF_WMS_STOCK.equals(builtinCode)) {
            return Arrays.asList(ToolCallbacks.from(rsfWmsStockTools));
        }
        if (AiDefaults.MCP_BUILTIN_RSF_WMS_TASK.equals(builtinCode)) {
            return Arrays.asList(ToolCallbacks.from(rsfWmsTaskTools));
        }
        if (AiDefaults.MCP_BUILTIN_RSF_WMS_BASE.equals(builtinCode)) {
            return Arrays.asList(ToolCallbacks.from(rsfWmsBaseTools));
        }
        throw new CoolException("不支持的内置 MCP 编码: " + builtinCode);
    }
 
    private List<String> supportedBuiltinCodes() {
        return List.of(
                AiDefaults.MCP_BUILTIN_RSF_WMS,
                AiDefaults.MCP_BUILTIN_RSF_WMS_STOCK,
                AiDefaults.MCP_BUILTIN_RSF_WMS_TASK,
                AiDefaults.MCP_BUILTIN_RSF_WMS_BASE
        );
    }
}