package com.vincent.rsf.server.ai.service.impl;
|
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.dto.AiMcpToolPreviewDto;
import com.vincent.rsf.server.ai.dto.AiMcpToolTestDto;
import com.vincent.rsf.server.ai.dto.AiMcpToolTestRequest;
import com.vincent.rsf.server.ai.entity.AiMcpMount;
import com.vincent.rsf.server.ai.mapper.AiMcpMountMapper;
import com.vincent.rsf.server.ai.service.AiMcpMountService;
import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry;
import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
import com.vincent.rsf.server.system.enums.StatusType;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
|
|
@Service("aiMcpMountService")
|
@RequiredArgsConstructor
|
public class AiMcpMountServiceImpl extends ServiceImpl<AiMcpMountMapper, AiMcpMount> implements AiMcpMountService {
|
|
private final BuiltinMcpToolRegistry builtinMcpToolRegistry;
|
private final McpMountRuntimeFactory mcpMountRuntimeFactory;
|
private final ObjectMapper objectMapper;
|
|
@Override
|
public List<AiMcpMount> listActiveMounts() {
|
return this.list(new LambdaQueryWrapper<AiMcpMount>()
|
.eq(AiMcpMount::getStatus, StatusType.ENABLE.val)
|
.orderByAsc(AiMcpMount::getSort)
|
.orderByAsc(AiMcpMount::getId));
|
}
|
|
@Override
|
public void validateBeforeSave(AiMcpMount aiMcpMount) {
|
fillDefaults(aiMcpMount);
|
ensureRequiredFields(aiMcpMount);
|
}
|
|
@Override
|
public void validateBeforeUpdate(AiMcpMount aiMcpMount) {
|
fillDefaults(aiMcpMount);
|
if (aiMcpMount.getId() == null) {
|
throw new CoolException("MCP 挂载 ID 不能为空");
|
}
|
ensureRequiredFields(aiMcpMount);
|
}
|
|
@Override
|
public List<AiMcpToolPreviewDto> previewTools(Long mountId, Long userId, Long tenantId) {
|
AiMcpMount mount = requireMount(mountId);
|
try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(List.of(mount), userId)) {
|
List<AiMcpToolPreviewDto> tools = new ArrayList<>();
|
for (ToolCallback callback : runtime.getToolCallbacks()) {
|
if (callback == null || callback.getToolDefinition() == null) {
|
continue;
|
}
|
tools.add(AiMcpToolPreviewDto.builder()
|
.name(callback.getToolDefinition().name())
|
.description(callback.getToolDefinition().description())
|
.inputSchema(callback.getToolDefinition().inputSchema())
|
.returnDirect(callback.getToolMetadata() == null ? null : callback.getToolMetadata().returnDirect())
|
.build());
|
}
|
return tools;
|
}
|
}
|
|
@Override
|
public AiMcpToolTestDto testTool(Long mountId, Long userId, Long tenantId, AiMcpToolTestRequest request) {
|
if (userId == null) {
|
throw new CoolException("当前登录用户不存在");
|
}
|
if (tenantId == null) {
|
throw new CoolException("当前租户不存在");
|
}
|
if (request == null) {
|
throw new CoolException("工具测试参数不能为空");
|
}
|
if (!StringUtils.hasText(request.getToolName())) {
|
throw new CoolException("工具名称不能为空");
|
}
|
if (!StringUtils.hasText(request.getInputJson())) {
|
throw new CoolException("工具输入 JSON 不能为空");
|
}
|
try {
|
objectMapper.readTree(request.getInputJson());
|
} catch (Exception e) {
|
throw new CoolException("工具输入 JSON 格式错误: " + e.getMessage());
|
}
|
AiMcpMount mount = requireMount(mountId);
|
try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(List.of(mount), userId)) {
|
ToolCallback callback = Arrays.stream(runtime.getToolCallbacks())
|
.filter(item -> item != null && item.getToolDefinition() != null)
|
.filter(item -> request.getToolName().equals(item.getToolDefinition().name()))
|
.findFirst()
|
.orElseThrow(() -> new CoolException("未找到要测试的工具: " + request.getToolName()));
|
String output = callback.call(
|
request.getInputJson(),
|
new ToolContext(Map.of("userId", userId, "tenantId", tenantId, "mountId", mountId))
|
);
|
return AiMcpToolTestDto.builder()
|
.toolName(request.getToolName())
|
.inputJson(request.getInputJson())
|
.output(output)
|
.build();
|
}
|
}
|
|
private void fillDefaults(AiMcpMount aiMcpMount) {
|
if (!StringUtils.hasText(aiMcpMount.getTransportType())) {
|
aiMcpMount.setTransportType(AiDefaults.MCP_TRANSPORT_SSE_HTTP);
|
}
|
if (aiMcpMount.getRequestTimeoutMs() == null) {
|
aiMcpMount.setRequestTimeoutMs(AiDefaults.DEFAULT_TIMEOUT_MS);
|
}
|
if (aiMcpMount.getSort() == null) {
|
aiMcpMount.setSort(0);
|
}
|
if (aiMcpMount.getStatus() == null) {
|
aiMcpMount.setStatus(StatusType.ENABLE.val);
|
}
|
}
|
|
private void ensureRequiredFields(AiMcpMount aiMcpMount) {
|
if (!StringUtils.hasText(aiMcpMount.getName())) {
|
throw new CoolException("MCP 挂载名称不能为空");
|
}
|
if (AiDefaults.MCP_TRANSPORT_BUILTIN.equals(aiMcpMount.getTransportType())) {
|
builtinMcpToolRegistry.validateBuiltinCode(aiMcpMount.getBuiltinCode());
|
ensureBuiltinConflictFree(aiMcpMount);
|
return;
|
}
|
if (AiDefaults.MCP_TRANSPORT_SSE_HTTP.equals(aiMcpMount.getTransportType())) {
|
if (!StringUtils.hasText(aiMcpMount.getServerUrl())) {
|
throw new CoolException("远程 MCP 服务地址不能为空");
|
}
|
return;
|
}
|
if (AiDefaults.MCP_TRANSPORT_STDIO.equals(aiMcpMount.getTransportType())) {
|
if (!StringUtils.hasText(aiMcpMount.getCommand())) {
|
throw new CoolException("STDIO MCP 命令不能为空");
|
}
|
return;
|
}
|
throw new CoolException("不支持的 MCP 传输类型: " + aiMcpMount.getTransportType());
|
}
|
|
private AiMcpMount requireMount(Long mountId) {
|
if (mountId == null) {
|
throw new CoolException("MCP 挂载 ID 不能为空");
|
}
|
AiMcpMount mount = this.getById(mountId);
|
if (mount == null || (mount.getDeleted() != null && mount.getDeleted() == 1)) {
|
throw new CoolException("MCP 挂载不存在");
|
}
|
return mount;
|
}
|
|
private void ensureBuiltinConflictFree(AiMcpMount aiMcpMount) {
|
if (aiMcpMount.getStatus() == null || aiMcpMount.getStatus() != StatusType.ENABLE.val) {
|
return;
|
}
|
List<String> conflictCodes = resolveConflictCodes(aiMcpMount.getBuiltinCode());
|
if (conflictCodes.isEmpty()) {
|
return;
|
}
|
LambdaQueryWrapper<AiMcpMount> queryWrapper = new LambdaQueryWrapper<AiMcpMount>()
|
.eq(AiMcpMount::getTransportType, AiDefaults.MCP_TRANSPORT_BUILTIN)
|
.eq(AiMcpMount::getStatus, StatusType.ENABLE.val)
|
.in(AiMcpMount::getBuiltinCode, conflictCodes);
|
if (aiMcpMount.getId() != null) {
|
queryWrapper.ne(AiMcpMount::getId, aiMcpMount.getId());
|
}
|
List<AiMcpMount> conflictMounts = this.list(queryWrapper);
|
if (conflictMounts.isEmpty()) {
|
return;
|
}
|
String conflictNames = String.join("、", conflictMounts.stream().map(AiMcpMount::getName).toList());
|
throw new CoolException("当前内置 MCP 与已启用挂载冲突,请关闭后再启用: " + conflictNames);
|
}
|
|
private List<String> resolveConflictCodes(String builtinCode) {
|
List<String> codes = new ArrayList<>();
|
if (AiDefaults.MCP_BUILTIN_RSF_WMS.equals(builtinCode)) {
|
codes.add(AiDefaults.MCP_BUILTIN_RSF_WMS_STOCK);
|
codes.add(AiDefaults.MCP_BUILTIN_RSF_WMS_TASK);
|
codes.add(AiDefaults.MCP_BUILTIN_RSF_WMS_BASE);
|
return codes;
|
}
|
if (AiDefaults.MCP_BUILTIN_RSF_WMS_STOCK.equals(builtinCode)
|
|| AiDefaults.MCP_BUILTIN_RSF_WMS_TASK.equals(builtinCode)
|
|| AiDefaults.MCP_BUILTIN_RSF_WMS_BASE.equals(builtinCode)) {
|
codes.add(AiDefaults.MCP_BUILTIN_RSF_WMS);
|
}
|
return codes;
|
}
|
}
|