package com.vincent.rsf.server.ai.service.impl;

import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.entity.AiMcpMount;
import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry;
import com.vincent.rsf.server.ai.tool.RsfWmsBaseTools;
import com.vincent.rsf.server.ai.tool.RsfWmsStockTools;
import com.vincent.rsf.server.ai.tool.RsfWmsTaskTools;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.support.ToolCallbacks;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

/**
 * Resolves built-in MCP codes (the {@code AiDefaults.MCP_BUILTIN_*} constants) to Spring AI
 * {@link ToolCallback}s backed by the injected RSF WMS tool beans.
 *
 * <p>Supported codes:
 * <ul>
 *   <li>{@code MCP_BUILTIN_RSF_WMS} — stock, task and base tools combined, in that order;</li>
 *   <li>{@code MCP_BUILTIN_RSF_WMS_STOCK} — stock tools only;</li>
 *   <li>{@code MCP_BUILTIN_RSF_WMS_TASK} — task tools only;</li>
 *   <li>{@code MCP_BUILTIN_RSF_WMS_BASE} — base tools only.</li>
 * </ul>
 */
@Service
@RequiredArgsConstructor
public class BuiltinMcpToolRegistryImpl implements BuiltinMcpToolRegistry {

    /** Every built-in code this registry accepts; built once instead of on each validation. */
    private static final List<String> SUPPORTED_BUILTIN_CODES = List.of(
            AiDefaults.MCP_BUILTIN_RSF_WMS,
            AiDefaults.MCP_BUILTIN_RSF_WMS_STOCK,
            AiDefaults.MCP_BUILTIN_RSF_WMS_TASK,
            AiDefaults.MCP_BUILTIN_RSF_WMS_BASE
    );

    private final RsfWmsStockTools rsfWmsStockTools;
    private final RsfWmsTaskTools rsfWmsTaskTools;
    private final RsfWmsBaseTools rsfWmsBaseTools;

    /**
     * Checks that {@code builtinCode} is a non-blank, supported built-in MCP code.
     *
     * @param builtinCode the code to validate
     * @throws CoolException if the code is null/blank, or is not one of the supported codes
     */
    @Override
    public void validateBuiltinCode(String builtinCode) {
        if (!StringUtils.hasText(builtinCode)) {
            throw new CoolException("内置 MCP 编码不能为空");
        }
        if (!SUPPORTED_BUILTIN_CODES.contains(builtinCode)) {
            throw new CoolException("不支持的内置 MCP 编码: " + builtinCode);
        }
    }

    /**
     * Builds the tool callbacks for the built-in tool set selected by {@code mount.getBuiltinCode()}.
     *
     * @param mount  the MCP mount whose built-in code selects the tool set; must not be null
     * @param userId the requesting user; currently not used when building built-in tool sets
     * @return a new, mutable list of callbacks (callers may modify it freely)
     * @throws NullPointerException if {@code mount} is null
     * @throws CoolException        if the mount's built-in code is blank or unsupported
     */
    @Override
    public List<ToolCallback> createToolCallbacks(AiMcpMount mount, Long userId) {
        Objects.requireNonNull(mount, "mount");
        String builtinCode = mount.getBuiltinCode();
        validateBuiltinCode(builtinCode);

        List<ToolCallback> callbacks = new ArrayList<>();
        // Convert each bean separately (as before) rather than in a single ToolCallbacks.from call,
        // so tool sets are processed independently of each other.
        for (Object toolBean : resolveToolBeans(builtinCode)) {
            callbacks.addAll(Arrays.asList(ToolCallbacks.from(toolBean)));
        }
        return callbacks;
    }

    /**
     * Maps a built-in code to the tool beans that make up its tool set.
     *
     * @param builtinCode a code already accepted by {@link #validateBuiltinCode(String)}
     * @return the tool beans for the code, in callback order
     * @throws CoolException if the code has no mapping (guards against the supported list and
     *                       this mapping drifting apart)
     */
    private List<Object> resolveToolBeans(String builtinCode) {
        if (AiDefaults.MCP_BUILTIN_RSF_WMS.equals(builtinCode)) {
            return List.of(rsfWmsStockTools, rsfWmsTaskTools, rsfWmsBaseTools);
        }
        if (AiDefaults.MCP_BUILTIN_RSF_WMS_STOCK.equals(builtinCode)) {
            return List.of(rsfWmsStockTools);
        }
        if (AiDefaults.MCP_BUILTIN_RSF_WMS_TASK.equals(builtinCode)) {
            return List.of(rsfWmsTaskTools);
        }
        if (AiDefaults.MCP_BUILTIN_RSF_WMS_BASE.equals(builtinCode)) {
            return List.of(rsfWmsBaseTools);
        }
        throw new CoolException("不支持的内置 MCP 编码: " + builtinCode);
    }
}