zhou zhou
8 小时以前 82624affb0251b75b62b35567d3eb260c06efe78
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiMcpMountServiceImpl.java
@@ -1,123 +1,159 @@
package com.vincent.rsf.server.ai.service.impl;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.dto.AiMcpConnectivityTestDto;
import com.vincent.rsf.server.ai.dto.AiMcpToolPreviewDto;
import com.vincent.rsf.server.ai.dto.AiMcpToolTestDto;
import com.vincent.rsf.server.ai.dto.AiMcpToolTestRequest;
import com.vincent.rsf.server.ai.entity.AiMcpMount;
import com.vincent.rsf.server.ai.mapper.AiMcpMountMapper;
import com.vincent.rsf.server.ai.service.impl.mcp.AiMcpAdminService;
import com.vincent.rsf.server.ai.store.AiConfigCacheStore;
import com.vincent.rsf.server.ai.store.AiConversationCacheStore;
import com.vincent.rsf.server.ai.store.AiMcpCacheStore;
import com.vincent.rsf.server.ai.service.AiMcpMountService;
import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry;
import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
import com.vincent.rsf.server.system.enums.StatusType;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
@Service("aiMcpMountService")
@RequiredArgsConstructor
public class AiMcpMountServiceImpl extends ServiceImpl<AiMcpMountMapper, AiMcpMount> implements AiMcpMountService {
    // Validates builtin MCP codes when a mount uses the builtin transport.
    private final BuiltinMcpToolRegistry builtinMcpToolRegistry;
    // NOTE(review): in this view only the pre-change previewTools/testTool lines use this field; after the
    // change it looks unused. Confirm before removing, because removal also changes the Lombok-generated constructor.
    private final McpMountRuntimeFactory mcpMountRuntimeFactory;
    // NOTE(review): same as above; here it is only used by the pre-change tool-test input-JSON check.
    private final ObjectMapper objectMapper;
    // Tenant-scoped mount lookup plus the delegated preview / connectivity / tool-test implementations.
    private final AiMcpAdminService aiMcpAdminService;
    // Per-tenant, per-mount cache of tool previews and connectivity results.
    private final AiMcpCacheStore aiMcpCacheStore;
    // Tenant-level AI config caches, evicted whenever a mount is saved, updated or removed.
    private final AiConfigCacheStore aiConfigCacheStore;
    // Tenant-level conversation runtime caches, evicted whenever a mount is saved, updated or removed.
    private final AiConversationCacheStore aiConversationCacheStore;
    /** 查询某个租户下当前启用的 MCP 挂载列表。 */
    @Override
    public List<AiMcpMount> listActiveMounts() {
    public List<AiMcpMount> listActiveMounts(Long tenantId) {
        ensureTenantId(tenantId);
        return this.list(new LambdaQueryWrapper<AiMcpMount>()
                .eq(AiMcpMount::getTenantId, tenantId)
                .eq(AiMcpMount::getStatus, StatusType.ENABLE.val)
                .eq(AiMcpMount::getDeleted, 0)
                .orderByAsc(AiMcpMount::getSort)
                .orderByAsc(AiMcpMount::getId));
    }
    /** 保存前校验 MCP 挂载草稿,并补全运行时默认值。 */
    @Override
    public void validateBeforeSave(AiMcpMount aiMcpMount) {
    public void validateBeforeSave(AiMcpMount aiMcpMount, Long tenantId) {
        ensureTenantId(tenantId);
        aiMcpMount.setTenantId(tenantId);
        fillDefaults(aiMcpMount);
        ensureRequiredFields(aiMcpMount);
        ensureRequiredFields(aiMcpMount, tenantId);
    }
    /** 更新前校验并锁定记录所属租户,防止跨租户修改。 */
    @Override
    public void validateBeforeUpdate(AiMcpMount aiMcpMount) {
    public void validateBeforeUpdate(AiMcpMount aiMcpMount, Long tenantId) {
        ensureTenantId(tenantId);
        fillDefaults(aiMcpMount);
        if (aiMcpMount.getId() == null) {
            throw new CoolException("MCP 挂载 ID 不能为空");
        }
        ensureRequiredFields(aiMcpMount);
        AiMcpMount current = aiMcpAdminService.requireMount(aiMcpMount.getId(), tenantId);
        aiMcpMount.setTenantId(current.getTenantId());
        ensureRequiredFields(aiMcpMount, tenantId);
    }
    /**
     * 预览当前挂载最终会暴露给模型的工具目录。
     * 对内置 MCP 会额外合并治理目录信息,对外部 MCP 则以实际解析结果为准。
     */
    @Override
    public List<AiMcpToolPreviewDto> previewTools(Long mountId, Long userId, Long tenantId) {
        AiMcpMount mount = requireMount(mountId);
        try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(List.of(mount), userId)) {
            List<AiMcpToolPreviewDto> tools = new ArrayList<>();
            for (ToolCallback callback : runtime.getToolCallbacks()) {
                if (callback == null || callback.getToolDefinition() == null) {
                    continue;
                }
                tools.add(AiMcpToolPreviewDto.builder()
                        .name(callback.getToolDefinition().name())
                        .description(callback.getToolDefinition().description())
                        .inputSchema(callback.getToolDefinition().inputSchema())
                        .returnDirect(callback.getToolMetadata() == null ? null : callback.getToolMetadata().returnDirect())
                        .build());
            }
            return tools;
        AiMcpMount mount = aiMcpAdminService.requireMount(mountId, tenantId);
        List<AiMcpToolPreviewDto> cached = aiMcpCacheStore.getToolPreview(tenantId, mountId);
        if (cached != null) {
            return cached;
        }
        List<AiMcpToolPreviewDto> tools = aiMcpAdminService.previewTools(mount, userId);
        aiMcpCacheStore.cacheToolPreview(tenantId, mountId, tools);
        return tools;
    }
    /** 对已保存的挂载做真实连通性测试,并把结果回写到运行态字段。 */
    @Override
    public AiMcpToolTestDto testTool(Long mountId, Long userId, Long tenantId, AiMcpToolTestRequest request) {
    public AiMcpConnectivityTestDto testConnectivity(Long mountId, Long userId, Long tenantId) {
        AiMcpMount mount = aiMcpAdminService.requireMount(mountId, tenantId);
        AiMcpConnectivityTestDto connectivity = aiMcpAdminService.testConnectivity(mount, userId, true);
        aiMcpCacheStore.cacheConnectivity(tenantId, mountId, connectivity);
        return connectivity;
    }
    /** 对表单里的草稿配置做临时连通性测试,不落库。 */
    @Override
    public AiMcpConnectivityTestDto testDraftConnectivity(AiMcpMount mount, Long userId, Long tenantId) {
        ensureTenantId(tenantId);
        if (userId == null) {
            throw new CoolException("当前登录用户不存在");
        }
        if (tenantId == null) {
            throw new CoolException("当前租户不存在");
        if (mount == null) {
            throw new CoolException("MCP 挂载参数不能为空");
        }
        if (request == null) {
            throw new CoolException("工具测试参数不能为空");
        mount.setTenantId(tenantId);
        fillDefaults(mount);
        ensureRequiredFields(mount, tenantId);
        return aiMcpAdminService.testConnectivity(mount, userId, false);
    }
    /**
     * 直接执行某一个工具的测试调用。
     * 该方法主要服务于管理端的“工具测试”面板,不参与正式对话链路。
     */
    @Override
    public AiMcpToolTestDto testTool(Long mountId, Long userId, Long tenantId, AiMcpToolTestRequest request) {
        AiMcpMount mount = aiMcpAdminService.requireMount(mountId, tenantId);
        return aiMcpAdminService.testTool(mount, userId, tenantId, request);
    }
    @Override
    public boolean save(AiMcpMount entity) {
        boolean saved = super.save(entity);
        if (saved && entity != null && entity.getTenantId() != null) {
            evictMountRelatedCaches(entity.getTenantId(), entity.getId());
        }
        if (!StringUtils.hasText(request.getToolName())) {
            throw new CoolException("工具名称不能为空");
        return saved;
    }
    @Override
    public boolean updateById(AiMcpMount entity) {
        boolean updated = super.updateById(entity);
        if (updated && entity != null && entity.getTenantId() != null) {
            evictMountRelatedCaches(entity.getTenantId(), entity.getId());
        }
        if (!StringUtils.hasText(request.getInputJson())) {
            throw new CoolException("工具输入 JSON 不能为空");
        return updated;
    }
    @Override
    public boolean removeByIds(java.util.Collection<?> list) {
        java.util.List<java.io.Serializable> ids = list == null ? java.util.List.of() : list.stream()
                .filter(java.util.Objects::nonNull)
                .map(item -> (java.io.Serializable) item)
                .toList();
        java.util.List<AiMcpMount> records = this.listByIds(ids);
        boolean removed = super.removeByIds(list);
        if (removed) {
            records.stream()
                    .filter(java.util.Objects::nonNull)
                    .forEach(item -> evictMountRelatedCaches(item.getTenantId(), item.getId()));
        }
        try {
            objectMapper.readTree(request.getInputJson());
        } catch (Exception e) {
            throw new CoolException("工具输入 JSON 格式错误: " + e.getMessage());
        }
        AiMcpMount mount = requireMount(mountId);
        try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(List.of(mount), userId)) {
            ToolCallback callback = Arrays.stream(runtime.getToolCallbacks())
                    .filter(item -> item != null && item.getToolDefinition() != null)
                    .filter(item -> request.getToolName().equals(item.getToolDefinition().name()))
                    .findFirst()
                    .orElseThrow(() -> new CoolException("未找到要测试的工具: " + request.getToolName()));
            String output = callback.call(
                    request.getInputJson(),
                    new ToolContext(Map.of("userId", userId, "tenantId", tenantId, "mountId", mountId))
            );
            return AiMcpToolTestDto.builder()
                    .toolName(request.getToolName())
                    .inputJson(request.getInputJson())
                    .output(output)
                    .build();
        }
        return removed;
    }
    private void fillDefaults(AiMcpMount aiMcpMount) {
        /** 为挂载草稿补齐统一默认值,保证后续运行时代码不需要重复判断空值。 */
        if (!StringUtils.hasText(aiMcpMount.getTransportType())) {
            aiMcpMount.setTransportType(AiDefaults.MCP_TRANSPORT_SSE_HTTP);
        }
@@ -130,15 +166,22 @@
        if (aiMcpMount.getStatus() == null) {
            aiMcpMount.setStatus(StatusType.ENABLE.val);
        }
        if (!StringUtils.hasText(aiMcpMount.getHealthStatus())) {
            aiMcpMount.setHealthStatus(AiDefaults.MCP_HEALTH_NOT_TESTED);
        }
    }
    private void ensureRequiredFields(AiMcpMount aiMcpMount) {
    private void ensureRequiredFields(AiMcpMount aiMcpMount, Long tenantId) {
        /**
         * 按 transportType 校验挂载必填项。
         * 这里把“字段合法性”和“跨记录冲突”一起收口,避免校验逻辑分散在 controller 层。
         */
        if (!StringUtils.hasText(aiMcpMount.getName())) {
            throw new CoolException("MCP 挂载名称不能为空");
        }
        if (AiDefaults.MCP_TRANSPORT_BUILTIN.equals(aiMcpMount.getTransportType())) {
            builtinMcpToolRegistry.validateBuiltinCode(aiMcpMount.getBuiltinCode());
            ensureBuiltinConflictFree(aiMcpMount);
            ensureBuiltinConflictFree(aiMcpMount, tenantId);
            return;
        }
        if (AiDefaults.MCP_TRANSPORT_SSE_HTTP.equals(aiMcpMount.getTransportType())) {
@@ -156,18 +199,8 @@
        throw new CoolException("不支持的 MCP 传输类型: " + aiMcpMount.getTransportType());
    }
    private AiMcpMount requireMount(Long mountId) {
        if (mountId == null) {
            throw new CoolException("MCP 挂载 ID 不能为空");
        }
        AiMcpMount mount = this.getById(mountId);
        if (mount == null || (mount.getDeleted() != null && mount.getDeleted() == 1)) {
            throw new CoolException("MCP 挂载不存在");
        }
        return mount;
    }
    private void ensureBuiltinConflictFree(AiMcpMount aiMcpMount) {
    private void ensureBuiltinConflictFree(AiMcpMount aiMcpMount, Long tenantId) {
        /** 校验同租户下是否存在与当前内置编码互斥的启用挂载。 */
        if (aiMcpMount.getStatus() == null || aiMcpMount.getStatus() != StatusType.ENABLE.val) {
            return;
        }
@@ -176,8 +209,10 @@
            return;
        }
        LambdaQueryWrapper<AiMcpMount> queryWrapper = new LambdaQueryWrapper<AiMcpMount>()
                .eq(AiMcpMount::getTenantId, tenantId)
                .eq(AiMcpMount::getTransportType, AiDefaults.MCP_TRANSPORT_BUILTIN)
                .eq(AiMcpMount::getStatus, StatusType.ENABLE.val)
                .eq(AiMcpMount::getDeleted, 0)
                .in(AiMcpMount::getBuiltinCode, conflictCodes);
        if (aiMcpMount.getId() != null) {
            queryWrapper.ne(AiMcpMount::getId, aiMcpMount.getId());
@@ -191,18 +226,21 @@
    }
    private List<String> resolveConflictCodes(String builtinCode) {
        List<String> codes = new ArrayList<>();
        if (AiDefaults.MCP_BUILTIN_RSF_WMS.equals(builtinCode)) {
            codes.add(AiDefaults.MCP_BUILTIN_RSF_WMS_STOCK);
            codes.add(AiDefaults.MCP_BUILTIN_RSF_WMS_TASK);
            codes.add(AiDefaults.MCP_BUILTIN_RSF_WMS_BASE);
            return codes;
            return List.of();
        }
        if (AiDefaults.MCP_BUILTIN_RSF_WMS_STOCK.equals(builtinCode)
                || AiDefaults.MCP_BUILTIN_RSF_WMS_TASK.equals(builtinCode)
                || AiDefaults.MCP_BUILTIN_RSF_WMS_BASE.equals(builtinCode)) {
            codes.add(AiDefaults.MCP_BUILTIN_RSF_WMS);
        throw new CoolException("不支持的内置 MCP 编码: " + builtinCode);
    }
    private void ensureTenantId(Long tenantId) {
        if (tenantId == null) {
            throw new CoolException("当前租户不存在");
        }
        return codes;
    }
    private void evictMountRelatedCaches(Long tenantId, Long mountId) {
        aiMcpCacheStore.evictMcpMountCaches(tenantId, mountId);
        aiConfigCacheStore.evictTenantConfigCaches(tenantId);
        aiConversationCacheStore.evictTenantRuntimeCaches(tenantId);
    }
}