package com.vincent.rsf.server.ai.service.impl;

import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.dto.AiMcpConnectivityTestDto;
import com.vincent.rsf.server.ai.dto.AiMcpToolPreviewDto;
import com.vincent.rsf.server.ai.dto.AiMcpToolTestDto;
import com.vincent.rsf.server.ai.dto.AiMcpToolTestRequest;
import com.vincent.rsf.server.ai.entity.AiMcpMount;
import com.vincent.rsf.server.ai.mapper.AiMcpMountMapper;
import com.vincent.rsf.server.ai.service.AiMcpMountService;
import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry;
import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
import com.vincent.rsf.server.system.enums.StatusType;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * Tenant-scoped service for MCP mount configurations.
 *
 * <p>Besides the persistence inherited from {@link ServiceImpl}, this service:
 * <ul>
 *   <li>validates mounts before save/update;</li>
 *   <li>offers three diagnostics: tool preview, connectivity test and single-tool invocation.</li>
 * </ul>
 * Each diagnostic opens a runtime through {@link McpMountRuntimeFactory} (closed at the end of the
 * call) and records the outcome in the mount's health columns.
 */
@Service("aiMcpMountService")
@RequiredArgsConstructor
public class AiMcpMountServiceImpl extends ServiceImpl<AiMcpMountMapper, AiMcpMount> implements AiMcpMountService {

    private final BuiltinMcpToolRegistry builtinMcpToolRegistry;

    private final McpMountRuntimeFactory mcpMountRuntimeFactory;

    private final ObjectMapper objectMapper;

    /**
     * Lists the tenant's enabled, non-deleted mounts, ordered by {@code sort} and then {@code id}.
     *
     * @param tenantId current tenant
     * @return the active mounts of the tenant
     * @throws CoolException if {@code tenantId} is {@code null}
     */
    @Override
    public List<AiMcpMount> listActiveMounts(Long tenantId) {
        ensureTenantId(tenantId);
        return this.list(new LambdaQueryWrapper<AiMcpMount>()
                .eq(AiMcpMount::getTenantId, tenantId)
                .eq(AiMcpMount::getStatus, StatusType.ENABLE.val)
                .eq(AiMcpMount::getDeleted, 0)
                .orderByAsc(AiMcpMount::getSort)
                .orderByAsc(AiMcpMount::getId));
    }

    /**
     * Prepares a new mount for insertion.
     *
     * <p>Binds the mount to the tenant, fills default values and validates the fields required by
     * its transport type.
     *
     * @param aiMcpMount mount to validate; mutated in place
     * @param tenantId   current tenant
     * @throws CoolException on a missing tenant, missing or invalid fields, or a builtin conflict
     */
    @Override
    public void validateBeforeSave(AiMcpMount aiMcpMount, Long tenantId) {
        ensureTenantId(tenantId);
        aiMcpMount.setTenantId(tenantId);
        fillDefaults(aiMcpMount);
        ensureRequiredFields(aiMcpMount, tenantId);
    }

    /**
     * Prepares an existing mount for update.
     *
     * <p>The stored row must exist for this tenant. Its tenant id is copied onto the incoming
     * entity, so the request cannot change which tenant owns the mount.
     *
     * @param aiMcpMount mount to validate; mutated in place
     * @param tenantId   current tenant
     * @throws CoolException on a missing tenant or id, an unknown mount, invalid fields, or a
     *                       builtin conflict
     */
    @Override
    public void validateBeforeUpdate(AiMcpMount aiMcpMount, Long tenantId) {
        ensureTenantId(tenantId);
        // Reject the request before mutating the entity with defaults.
        if (aiMcpMount.getId() == null) {
            throw new CoolException("MCP 挂载 ID 不能为空");
        }
        fillDefaults(aiMcpMount);
        AiMcpMount current = requireMount(aiMcpMount.getId(), tenantId);
        aiMcpMount.setTenantId(current.getTenantId());
        ensureRequiredFields(aiMcpMount, tenantId);
    }

    /**
     * Opens the mount and lists the tools it exposes, recording the result as the mount's health.
     *
     * @param mountId  mount to inspect
     * @param userId   user on whose behalf the runtime is created
     * @param tenantId current tenant
     * @return one preview per tool with a non-null definition
     * @throws CoolException if the mount is unknown, or the runtime reports errors or fails
     */
    @Override
    public List<AiMcpToolPreviewDto> previewTools(Long mountId, Long userId, Long tenantId) {
        AiMcpMount mount = requireMount(mountId, tenantId);
        long startedAt = System.currentTimeMillis();
        try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(List.of(mount), userId)) {
            List<AiMcpToolPreviewDto> tools = buildToolPreviewDtos(runtime.getToolCallbacks());
            if (!runtime.getErrors().isEmpty()) {
                String message = String.join(";", runtime.getErrors());
                updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY, message, elapsedSince(startedAt));
                throw new CoolException(message);
            }
            updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_HEALTHY,
                    "工具解析成功,共 " + tools.size() + " 个工具", elapsedSince(startedAt));
            return tools;
        } catch (CoolException e) {
            // Already recorded (or a validation failure); propagate unchanged.
            throw e;
        } catch (Exception e) {
            updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY,
                    "工具解析失败: " + describe(e), elapsedSince(startedAt));
            throw new CoolException("获取工具列表失败: " + describe(e));
        }
    }

    /**
     * Tries to open the mount and reports the result as a DTO instead of throwing.
     *
     * <p>The outcome is persisted as the mount's health before the DTO is built.
     *
     * @param mountId  mount to test
     * @param userId   user on whose behalf the runtime is created
     * @param tenantId current tenant
     * @return the test result, built from the reloaded mount
     * @throws CoolException only if the mount cannot be found for this tenant
     */
    @Override
    public AiMcpConnectivityTestDto testConnectivity(Long mountId, Long userId, Long tenantId) {
        AiMcpMount mount = requireMount(mountId, tenantId);
        long startedAt = System.currentTimeMillis();
        try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(List.of(mount), userId)) {
            // Measured inside the try body, i.e. before the runtime is closed.
            long elapsedMs = elapsedSince(startedAt);
            int toolCount = runtime.getToolCallbacks().length;
            if (!runtime.getErrors().isEmpty()) {
                String message = String.join(";", runtime.getErrors());
                return recordConnectivityResult(mount.getId(), tenantId, AiDefaults.MCP_HEALTH_UNHEALTHY,
                        message, elapsedMs, toolCount);
            }
            String message = "连通性测试成功,解析出 " + toolCount + " 个工具";
            return recordConnectivityResult(mount.getId(), tenantId, AiDefaults.MCP_HEALTH_HEALTHY,
                    message, elapsedMs, toolCount);
        } catch (CoolException e) {
            throw e;
        } catch (Exception e) {
            long elapsedMs = elapsedSince(startedAt);
            String message = "连通性测试失败: " + describe(e);
            return recordConnectivityResult(mount.getId(), tenantId, AiDefaults.MCP_HEALTH_UNHEALTHY,
                    message, elapsedMs, 0);
        }
    }

    /**
     * Invokes a single tool of the mount with the given JSON input.
     *
     * <p>The outcome is recorded as the mount's health.
     *
     * @param mountId  mount that exposes the tool
     * @param userId   current user; required
     * @param tenantId current tenant; required
     * @param request  tool name and JSON input; both required, and the input must parse as JSON
     * @return the tool name, the input and the tool's raw output
     * @throws CoolException on invalid arguments, an unknown mount or tool, or a failing call
     */
    @Override
    public AiMcpToolTestDto testTool(Long mountId, Long userId, Long tenantId, AiMcpToolTestRequest request) {
        if (userId == null) {
            throw new CoolException("当前登录用户不存在");
        }
        ensureTenantId(tenantId);
        if (request == null) {
            throw new CoolException("工具测试参数不能为空");
        }
        if (!StringUtils.hasText(request.getToolName())) {
            throw new CoolException("工具名称不能为空");
        }
        if (!StringUtils.hasText(request.getInputJson())) {
            throw new CoolException("工具输入 JSON 不能为空");
        }
        try {
            // Syntax check only; the parsed tree is not used.
            objectMapper.readTree(request.getInputJson());
        } catch (Exception e) {
            throw new CoolException("工具输入 JSON 格式错误: " + describe(e));
        }
        AiMcpMount mount = requireMount(mountId, tenantId);
        String toolName = request.getToolName();
        long startedAt = System.currentTimeMillis();
        try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(List.of(mount), userId)) {
            ToolCallback callback = Arrays.stream(runtime.getToolCallbacks())
                    .filter(item -> item != null && item.getToolDefinition() != null)
                    .filter(item -> toolName.equals(item.getToolDefinition().name()))
                    .findFirst()
                    .orElseThrow(() -> new CoolException("未找到要测试的工具: " + toolName));
            String output = callback.call(
                    request.getInputJson(),
                    new ToolContext(Map.of("userId", userId, "tenantId", tenantId, "mountId", mountId))
            );
            updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_HEALTHY,
                    "工具测试成功: " + toolName, elapsedSince(startedAt));
            return AiMcpToolTestDto.builder()
                    .toolName(toolName)
                    .inputJson(request.getInputJson())
                    .output(output)
                    .build();
        } catch (CoolException e) {
            // Includes "tool not found"; still recorded as an unhealthy test.
            updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY,
                    "工具测试失败: " + describe(e), elapsedSince(startedAt));
            throw e;
        } catch (Exception e) {
            String message = "工具测试失败: " + describe(e);
            updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY, message, elapsedSince(startedAt));
            throw new CoolException(message);
        }
    }

    /** Fills transport type, timeout, sort, status and health status when they are unset. */
    private void fillDefaults(AiMcpMount aiMcpMount) {
        if (!StringUtils.hasText(aiMcpMount.getTransportType())) {
            aiMcpMount.setTransportType(AiDefaults.MCP_TRANSPORT_SSE_HTTP);
        }
        if (aiMcpMount.getRequestTimeoutMs() == null) {
            aiMcpMount.setRequestTimeoutMs(AiDefaults.DEFAULT_TIMEOUT_MS);
        }
        if (aiMcpMount.getSort() == null) {
            aiMcpMount.setSort(0);
        }
        if (aiMcpMount.getStatus() == null) {
            aiMcpMount.setStatus(StatusType.ENABLE.val);
        }
        if (!StringUtils.hasText(aiMcpMount.getHealthStatus())) {
            aiMcpMount.setHealthStatus(AiDefaults.MCP_HEALTH_NOT_TESTED);
        }
    }

    /**
     * Checks the name and the transport-specific required fields.
     *
     * <ul>
     *   <li>builtin: a valid code that does not conflict with other mounts;</li>
     *   <li>SSE/HTTP: a server URL;</li>
     *   <li>STDIO: a command.</li>
     * </ul>
     * Any other transport type is rejected.
     */
    private void ensureRequiredFields(AiMcpMount aiMcpMount, Long tenantId) {
        if (!StringUtils.hasText(aiMcpMount.getName())) {
            throw new CoolException("MCP 挂载名称不能为空");
        }
        if (AiDefaults.MCP_TRANSPORT_BUILTIN.equals(aiMcpMount.getTransportType())) {
            builtinMcpToolRegistry.validateBuiltinCode(aiMcpMount.getBuiltinCode());
            ensureBuiltinConflictFree(aiMcpMount, tenantId);
            return;
        }
        if (AiDefaults.MCP_TRANSPORT_SSE_HTTP.equals(aiMcpMount.getTransportType())) {
            if (!StringUtils.hasText(aiMcpMount.getServerUrl())) {
                throw new CoolException("远程 MCP 服务地址不能为空");
            }
            return;
        }
        if (AiDefaults.MCP_TRANSPORT_STDIO.equals(aiMcpMount.getTransportType())) {
            if (!StringUtils.hasText(aiMcpMount.getCommand())) {
                throw new CoolException("STDIO MCP 命令不能为空");
            }
            return;
        }
        throw new CoolException("不支持的 MCP 传输类型: " + aiMcpMount.getTransportType());
    }

    /**
     * Loads a non-deleted mount belonging to the tenant.
     *
     * @throws CoolException if the tenant or id is missing, or no such mount exists
     */
    private AiMcpMount requireMount(Long mountId, Long tenantId) {
        ensureTenantId(tenantId);
        if (mountId == null) {
            throw new CoolException("MCP 挂载 ID 不能为空");
        }
        AiMcpMount mount = this.getOne(new LambdaQueryWrapper<AiMcpMount>()
                .eq(AiMcpMount::getId, mountId)
                .eq(AiMcpMount::getTenantId, tenantId)
                .eq(AiMcpMount::getDeleted, 0)
                .last("limit 1"));
        if (mount == null) {
            throw new CoolException("MCP 挂载不存在");
        }
        return mount;
    }

    /**
     * Rejects enabling a builtin mount when a conflicting builtin mount of the same tenant is
     * already enabled.
     *
     * <p>The mount itself is excluded from the check. Disabled mounts are never checked.
     */
    private void ensureBuiltinConflictFree(AiMcpMount aiMcpMount, Long tenantId) {
        // Objects.equals instead of '!=': getStatus() is a reference type (it can be null), so '!='
        // would be a reference comparison whenever ENABLE.val is boxed too.
        if (!Objects.equals(aiMcpMount.getStatus(), StatusType.ENABLE.val)) {
            return;
        }
        List<String> conflictCodes = resolveConflictCodes(aiMcpMount.getBuiltinCode());
        if (conflictCodes.isEmpty()) {
            return;
        }
        LambdaQueryWrapper<AiMcpMount> queryWrapper = new LambdaQueryWrapper<AiMcpMount>()
                .eq(AiMcpMount::getTenantId, tenantId)
                .eq(AiMcpMount::getTransportType, AiDefaults.MCP_TRANSPORT_BUILTIN)
                .eq(AiMcpMount::getStatus, StatusType.ENABLE.val)
                .eq(AiMcpMount::getDeleted, 0)
                .in(AiMcpMount::getBuiltinCode, conflictCodes);
        if (aiMcpMount.getId() != null) {
            queryWrapper.ne(AiMcpMount::getId, aiMcpMount.getId());
        }
        List<AiMcpMount> conflictMounts = this.list(queryWrapper);
        if (conflictMounts.isEmpty()) {
            return;
        }
        String conflictNames = String.join("、", conflictMounts.stream().map(AiMcpMount::getName).toList());
        throw new CoolException("当前内置 MCP 与已启用挂载冲突,请关闭后再启用: " + conflictNames);
    }

    /**
     * Returns the builtin codes that must not be enabled together with {@code builtinCode}.
     *
     * <p>{@code MCP_BUILTIN_RSF_WMS} conflicts with each of {@code _STOCK}, {@code _TASK} and
     * {@code _BASE}, and each of those conflicts with {@code MCP_BUILTIN_RSF_WMS}. (Presumably the
     * former is the combined toolset of the latter three; confirm.) Every other code has no
     * conflicts.
     */
    private List<String> resolveConflictCodes(String builtinCode) {
        List<String> codes = new ArrayList<>();
        if (AiDefaults.MCP_BUILTIN_RSF_WMS.equals(builtinCode)) {
            codes.add(AiDefaults.MCP_BUILTIN_RSF_WMS_STOCK);
            codes.add(AiDefaults.MCP_BUILTIN_RSF_WMS_TASK);
            codes.add(AiDefaults.MCP_BUILTIN_RSF_WMS_BASE);
            return codes;
        }
        if (AiDefaults.MCP_BUILTIN_RSF_WMS_STOCK.equals(builtinCode)
                || AiDefaults.MCP_BUILTIN_RSF_WMS_TASK.equals(builtinCode)
                || AiDefaults.MCP_BUILTIN_RSF_WMS_BASE.equals(builtinCode)) {
            codes.add(AiDefaults.MCP_BUILTIN_RSF_WMS);
        }
        return codes;
    }

    /** Throws if there is no current tenant. */
    private void ensureTenantId(Long tenantId) {
        if (tenantId == null) {
            throw new CoolException("当前租户不存在");
        }
    }

    /** Converts tool callbacks into preview DTOs, skipping null callbacks and missing definitions. */
    private List<AiMcpToolPreviewDto> buildToolPreviewDtos(ToolCallback[] callbacks) {
        List<AiMcpToolPreviewDto> tools = new ArrayList<>();
        for (ToolCallback callback : callbacks) {
            if (callback == null || callback.getToolDefinition() == null) {
                continue;
            }
            tools.add(AiMcpToolPreviewDto.builder()
                    .name(callback.getToolDefinition().name())
                    .description(callback.getToolDefinition().description())
                    .inputSchema(callback.getToolDefinition().inputSchema())
                    .returnDirect(callback.getToolMetadata() == null ? null : callback.getToolMetadata().returnDirect())
                    .build());
        }
        return tools;
    }

    /**
     * Persists a test outcome on the mount row.
     *
     * <p>Writes the health status, the current time as the last test time, the message and the
     * elapsed time.
     */
    private void updateHealthStatus(Long mountId, String healthStatus, String message, Long initElapsedMs) {
        this.update(new LambdaUpdateWrapper<AiMcpMount>()
                .eq(AiMcpMount::getId, mountId)
                .set(AiMcpMount::getHealthStatus, healthStatus)
                .set(AiMcpMount::getLastTestTime, new Date())
                .set(AiMcpMount::getLastTestMessage, message)
                .set(AiMcpMount::getLastInitElapsedMs, initElapsedMs));
    }

    /**
     * Persists a connectivity-test outcome, then reloads the mount and builds the result DTO.
     *
     * <p>The DTO therefore carries the stored health status and test time.
     */
    private AiMcpConnectivityTestDto recordConnectivityResult(Long mountId, Long tenantId, String healthStatus,
                                                              String message, long elapsedMs, int toolCount) {
        updateHealthStatus(mountId, healthStatus, message, elapsedMs);
        AiMcpMount latest = requireMount(mountId, tenantId);
        return buildConnectivityDto(latest, message, elapsedMs, toolCount);
    }

    /** Assembles the connectivity-test DTO from a (freshly reloaded) mount and the test outcome. */
    private AiMcpConnectivityTestDto buildConnectivityDto(AiMcpMount mount, String message, Long initElapsedMs, Integer toolCount) {
        return AiMcpConnectivityTestDto.builder()
                .mountId(mount.getId())
                .mountName(mount.getName())
                .healthStatus(mount.getHealthStatus())
                .message(message)
                .initElapsedMs(initElapsedMs)
                .toolCount(toolCount)
                .testedAt(mount.getLastTestTime$())
                .build();
    }

    /** Milliseconds elapsed since {@code startedAt}, a {@link System#currentTimeMillis()} value. */
    private static long elapsedSince(long startedAt) {
        return System.currentTimeMillis() - startedAt;
    }

    /**
     * Text used for an exception in health and error messages.
     *
     * <p>Returns the exception's message, or its simple class name when the message is null or
     * blank, so stored messages never end in a literal "null".
     */
    private static String describe(Throwable e) {
        String message = e.getMessage();
        return StringUtils.hasText(message) ? message : e.getClass().getSimpleName();
    }
}