package com.vincent.rsf.server.ai.service.impl;
|
|
import com.vincent.rsf.framework.exception.CoolException;
|
import com.vincent.rsf.server.ai.config.AiDefaults;
|
import com.vincent.rsf.server.ai.entity.AiMcpMount;
|
import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry;
|
import com.vincent.rsf.server.ai.tool.RsfWmsBaseTools;
|
import com.vincent.rsf.server.ai.tool.RsfWmsStockTools;
|
import com.vincent.rsf.server.ai.tool.RsfWmsTaskTools;
|
import lombok.RequiredArgsConstructor;
|
import org.springframework.ai.support.ToolCallbacks;
|
import org.springframework.ai.tool.ToolCallback;
|
import org.springframework.stereotype.Service;
|
import org.springframework.util.StringUtils;
|
|
import java.util.ArrayList;
|
import java.util.Arrays;
|
import java.util.List;
|
|
@Service
|
@RequiredArgsConstructor
|
public class BuiltinMcpToolRegistryImpl implements BuiltinMcpToolRegistry {
|
|
private final RsfWmsStockTools rsfWmsStockTools;
|
private final RsfWmsTaskTools rsfWmsTaskTools;
|
private final RsfWmsBaseTools rsfWmsBaseTools;
|
|
@Override
|
public void validateBuiltinCode(String builtinCode) {
|
if (!StringUtils.hasText(builtinCode)) {
|
throw new CoolException("内置 MCP 编码不能为空");
|
}
|
if (!supportedBuiltinCodes().contains(builtinCode)) {
|
throw new CoolException("不支持的内置 MCP 编码: " + builtinCode);
|
}
|
}
|
|
@Override
|
public List<ToolCallback> createToolCallbacks(AiMcpMount mount, Long userId) {
|
String builtinCode = mount.getBuiltinCode();
|
validateBuiltinCode(builtinCode);
|
if (AiDefaults.MCP_BUILTIN_RSF_WMS.equals(builtinCode)) {
|
List<ToolCallback> callbacks = new ArrayList<>();
|
callbacks.addAll(Arrays.asList(ToolCallbacks.from(rsfWmsStockTools)));
|
callbacks.addAll(Arrays.asList(ToolCallbacks.from(rsfWmsTaskTools)));
|
callbacks.addAll(Arrays.asList(ToolCallbacks.from(rsfWmsBaseTools)));
|
return callbacks;
|
}
|
if (AiDefaults.MCP_BUILTIN_RSF_WMS_STOCK.equals(builtinCode)) {
|
return Arrays.asList(ToolCallbacks.from(rsfWmsStockTools));
|
}
|
if (AiDefaults.MCP_BUILTIN_RSF_WMS_TASK.equals(builtinCode)) {
|
return Arrays.asList(ToolCallbacks.from(rsfWmsTaskTools));
|
}
|
if (AiDefaults.MCP_BUILTIN_RSF_WMS_BASE.equals(builtinCode)) {
|
return Arrays.asList(ToolCallbacks.from(rsfWmsBaseTools));
|
}
|
throw new CoolException("不支持的内置 MCP 编码: " + builtinCode);
|
}
|
|
private List<String> supportedBuiltinCodes() {
|
return List.of(
|
AiDefaults.MCP_BUILTIN_RSF_WMS,
|
AiDefaults.MCP_BUILTIN_RSF_WMS_STOCK,
|
AiDefaults.MCP_BUILTIN_RSF_WMS_TASK,
|
AiDefaults.MCP_BUILTIN_RSF_WMS_BASE
|
);
|
}
|
}
|