package com.vincent.rsf.server.ai.service.impl;
|
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.dto.AiMcpConnectivityTestDto;
import com.vincent.rsf.server.ai.dto.AiMcpToolPreviewDto;
import com.vincent.rsf.server.ai.dto.AiMcpToolTestDto;
import com.vincent.rsf.server.ai.dto.AiMcpToolTestRequest;
import com.vincent.rsf.server.ai.entity.AiMcpMount;
import com.vincent.rsf.server.ai.mapper.AiMcpMountMapper;
import com.vincent.rsf.server.ai.service.AiMcpMountService;
import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry;
import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
import com.vincent.rsf.server.system.enums.StatusType;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
|
|
@Service("aiMcpMountService")
|
@RequiredArgsConstructor
|
public class AiMcpMountServiceImpl extends ServiceImpl<AiMcpMountMapper, AiMcpMount> implements AiMcpMountService {
|
|
private final BuiltinMcpToolRegistry builtinMcpToolRegistry;
|
private final McpMountRuntimeFactory mcpMountRuntimeFactory;
|
private final ObjectMapper objectMapper;
|
|
@Override
|
public List<AiMcpMount> listActiveMounts(Long tenantId) {
|
ensureTenantId(tenantId);
|
return this.list(new LambdaQueryWrapper<AiMcpMount>()
|
.eq(AiMcpMount::getTenantId, tenantId)
|
.eq(AiMcpMount::getStatus, StatusType.ENABLE.val)
|
.eq(AiMcpMount::getDeleted, 0)
|
.orderByAsc(AiMcpMount::getSort)
|
.orderByAsc(AiMcpMount::getId));
|
}
|
|
@Override
|
public void validateBeforeSave(AiMcpMount aiMcpMount, Long tenantId) {
|
ensureTenantId(tenantId);
|
aiMcpMount.setTenantId(tenantId);
|
fillDefaults(aiMcpMount);
|
ensureRequiredFields(aiMcpMount, tenantId);
|
}
|
|
@Override
|
public void validateBeforeUpdate(AiMcpMount aiMcpMount, Long tenantId) {
|
ensureTenantId(tenantId);
|
fillDefaults(aiMcpMount);
|
if (aiMcpMount.getId() == null) {
|
throw new CoolException("MCP 挂载 ID 不能为空");
|
}
|
AiMcpMount current = requireMount(aiMcpMount.getId(), tenantId);
|
aiMcpMount.setTenantId(current.getTenantId());
|
ensureRequiredFields(aiMcpMount, tenantId);
|
}
|
|
@Override
|
public List<AiMcpToolPreviewDto> previewTools(Long mountId, Long userId, Long tenantId) {
|
AiMcpMount mount = requireMount(mountId, tenantId);
|
long startedAt = System.currentTimeMillis();
|
try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(List.of(mount), userId)) {
|
List<AiMcpToolPreviewDto> tools = buildToolPreviewDtos(runtime.getToolCallbacks(),
|
AiDefaults.MCP_TRANSPORT_BUILTIN.equals(mount.getTransportType())
|
? builtinMcpToolRegistry.listBuiltinToolCatalog(mount.getBuiltinCode())
|
: List.of());
|
if (!runtime.getErrors().isEmpty()) {
|
String message = String.join(";", runtime.getErrors());
|
updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY, message, System.currentTimeMillis() - startedAt);
|
throw new CoolException(message);
|
}
|
updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_HEALTHY,
|
"工具解析成功,共 " + tools.size() + " 个工具", System.currentTimeMillis() - startedAt);
|
return tools;
|
} catch (CoolException e) {
|
throw e;
|
} catch (Exception e) {
|
updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY,
|
"工具解析失败: " + e.getMessage(), System.currentTimeMillis() - startedAt);
|
throw new CoolException("获取工具列表失败: " + e.getMessage());
|
}
|
}
|
|
@Override
|
public AiMcpConnectivityTestDto testConnectivity(Long mountId, Long userId, Long tenantId) {
|
AiMcpMount mount = requireMount(mountId, tenantId);
|
long startedAt = System.currentTimeMillis();
|
try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(List.of(mount), userId)) {
|
long elapsedMs = System.currentTimeMillis() - startedAt;
|
if (!runtime.getErrors().isEmpty()) {
|
String message = String.join(";", runtime.getErrors());
|
updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY, message, elapsedMs);
|
AiMcpMount latest = requireMount(mount.getId(), tenantId);
|
return buildConnectivityDto(latest, message, elapsedMs, runtime.getToolCallbacks().length);
|
}
|
String message = "连通性测试成功,解析出 " + runtime.getToolCallbacks().length + " 个工具";
|
updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_HEALTHY, message, elapsedMs);
|
AiMcpMount latest = requireMount(mount.getId(), tenantId);
|
return buildConnectivityDto(latest, message, elapsedMs, runtime.getToolCallbacks().length);
|
} catch (CoolException e) {
|
throw e;
|
} catch (Exception e) {
|
long elapsedMs = System.currentTimeMillis() - startedAt;
|
String message = "连通性测试失败: " + e.getMessage();
|
updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY, message, elapsedMs);
|
AiMcpMount latest = requireMount(mount.getId(), tenantId);
|
return buildConnectivityDto(latest, message, elapsedMs, 0);
|
}
|
}
|
|
@Override
|
public AiMcpToolTestDto testTool(Long mountId, Long userId, Long tenantId, AiMcpToolTestRequest request) {
|
if (userId == null) {
|
throw new CoolException("当前登录用户不存在");
|
}
|
if (tenantId == null) {
|
throw new CoolException("当前租户不存在");
|
}
|
if (request == null) {
|
throw new CoolException("工具测试参数不能为空");
|
}
|
if (!StringUtils.hasText(request.getToolName())) {
|
throw new CoolException("工具名称不能为空");
|
}
|
if (!StringUtils.hasText(request.getInputJson())) {
|
throw new CoolException("工具输入 JSON 不能为空");
|
}
|
try {
|
objectMapper.readTree(request.getInputJson());
|
} catch (Exception e) {
|
throw new CoolException("工具输入 JSON 格式错误: " + e.getMessage());
|
}
|
AiMcpMount mount = requireMount(mountId, tenantId);
|
long startedAt = System.currentTimeMillis();
|
try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(List.of(mount), userId)) {
|
ToolCallback callback = Arrays.stream(runtime.getToolCallbacks())
|
.filter(item -> item != null && item.getToolDefinition() != null)
|
.filter(item -> request.getToolName().equals(item.getToolDefinition().name()))
|
.findFirst()
|
.orElseThrow(() -> new CoolException("未找到要测试的工具: " + request.getToolName()));
|
String output = callback.call(
|
request.getInputJson(),
|
new ToolContext(Map.of("userId", userId, "tenantId", tenantId, "mountId", mountId))
|
);
|
updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_HEALTHY,
|
"工具测试成功: " + request.getToolName(), System.currentTimeMillis() - startedAt);
|
return AiMcpToolTestDto.builder()
|
.toolName(request.getToolName())
|
.inputJson(request.getInputJson())
|
.output(output)
|
.build();
|
} catch (CoolException e) {
|
updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY,
|
"工具测试失败: " + e.getMessage(), System.currentTimeMillis() - startedAt);
|
throw e;
|
} catch (Exception e) {
|
updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY,
|
"工具测试失败: " + e.getMessage(), System.currentTimeMillis() - startedAt);
|
throw new CoolException("工具测试失败: " + e.getMessage());
|
}
|
}
|
|
private void fillDefaults(AiMcpMount aiMcpMount) {
|
if (!StringUtils.hasText(aiMcpMount.getTransportType())) {
|
aiMcpMount.setTransportType(AiDefaults.MCP_TRANSPORT_SSE_HTTP);
|
}
|
if (aiMcpMount.getRequestTimeoutMs() == null) {
|
aiMcpMount.setRequestTimeoutMs(AiDefaults.DEFAULT_TIMEOUT_MS);
|
}
|
if (aiMcpMount.getSort() == null) {
|
aiMcpMount.setSort(0);
|
}
|
if (aiMcpMount.getStatus() == null) {
|
aiMcpMount.setStatus(StatusType.ENABLE.val);
|
}
|
if (!StringUtils.hasText(aiMcpMount.getHealthStatus())) {
|
aiMcpMount.setHealthStatus(AiDefaults.MCP_HEALTH_NOT_TESTED);
|
}
|
}
|
|
private void ensureRequiredFields(AiMcpMount aiMcpMount, Long tenantId) {
|
if (!StringUtils.hasText(aiMcpMount.getName())) {
|
throw new CoolException("MCP 挂载名称不能为空");
|
}
|
if (AiDefaults.MCP_TRANSPORT_BUILTIN.equals(aiMcpMount.getTransportType())) {
|
builtinMcpToolRegistry.validateBuiltinCode(aiMcpMount.getBuiltinCode());
|
ensureBuiltinConflictFree(aiMcpMount, tenantId);
|
return;
|
}
|
if (AiDefaults.MCP_TRANSPORT_SSE_HTTP.equals(aiMcpMount.getTransportType())) {
|
if (!StringUtils.hasText(aiMcpMount.getServerUrl())) {
|
throw new CoolException("远程 MCP 服务地址不能为空");
|
}
|
return;
|
}
|
if (AiDefaults.MCP_TRANSPORT_STDIO.equals(aiMcpMount.getTransportType())) {
|
if (!StringUtils.hasText(aiMcpMount.getCommand())) {
|
throw new CoolException("STDIO MCP 命令不能为空");
|
}
|
return;
|
}
|
throw new CoolException("不支持的 MCP 传输类型: " + aiMcpMount.getTransportType());
|
}
|
|
private AiMcpMount requireMount(Long mountId, Long tenantId) {
|
ensureTenantId(tenantId);
|
if (mountId == null) {
|
throw new CoolException("MCP 挂载 ID 不能为空");
|
}
|
AiMcpMount mount = this.getOne(new LambdaQueryWrapper<AiMcpMount>()
|
.eq(AiMcpMount::getId, mountId)
|
.eq(AiMcpMount::getTenantId, tenantId)
|
.eq(AiMcpMount::getDeleted, 0)
|
.last("limit 1"));
|
if (mount == null) {
|
throw new CoolException("MCP 挂载不存在");
|
}
|
return mount;
|
}
|
|
private void ensureBuiltinConflictFree(AiMcpMount aiMcpMount, Long tenantId) {
|
if (aiMcpMount.getStatus() == null || aiMcpMount.getStatus() != StatusType.ENABLE.val) {
|
return;
|
}
|
List<String> conflictCodes = resolveConflictCodes(aiMcpMount.getBuiltinCode());
|
if (conflictCodes.isEmpty()) {
|
return;
|
}
|
LambdaQueryWrapper<AiMcpMount> queryWrapper = new LambdaQueryWrapper<AiMcpMount>()
|
.eq(AiMcpMount::getTenantId, tenantId)
|
.eq(AiMcpMount::getTransportType, AiDefaults.MCP_TRANSPORT_BUILTIN)
|
.eq(AiMcpMount::getStatus, StatusType.ENABLE.val)
|
.eq(AiMcpMount::getDeleted, 0)
|
.in(AiMcpMount::getBuiltinCode, conflictCodes);
|
if (aiMcpMount.getId() != null) {
|
queryWrapper.ne(AiMcpMount::getId, aiMcpMount.getId());
|
}
|
List<AiMcpMount> conflictMounts = this.list(queryWrapper);
|
if (conflictMounts.isEmpty()) {
|
return;
|
}
|
String conflictNames = String.join("、", conflictMounts.stream().map(AiMcpMount::getName).toList());
|
throw new CoolException("当前内置 MCP 与已启用挂载冲突,请关闭后再启用: " + conflictNames);
|
}
|
|
private List<String> resolveConflictCodes(String builtinCode) {
|
List<String> codes = new ArrayList<>();
|
if (AiDefaults.MCP_BUILTIN_RSF_WMS.equals(builtinCode)) {
|
codes.add(AiDefaults.MCP_BUILTIN_RSF_WMS_STOCK);
|
codes.add(AiDefaults.MCP_BUILTIN_RSF_WMS_TASK);
|
codes.add(AiDefaults.MCP_BUILTIN_RSF_WMS_BASE);
|
return codes;
|
}
|
if (AiDefaults.MCP_BUILTIN_RSF_WMS_STOCK.equals(builtinCode)
|
|| AiDefaults.MCP_BUILTIN_RSF_WMS_TASK.equals(builtinCode)
|
|| AiDefaults.MCP_BUILTIN_RSF_WMS_BASE.equals(builtinCode)) {
|
codes.add(AiDefaults.MCP_BUILTIN_RSF_WMS);
|
}
|
return codes;
|
}
|
|
private void ensureTenantId(Long tenantId) {
|
if (tenantId == null) {
|
throw new CoolException("当前租户不存在");
|
}
|
}
|
|
private List<AiMcpToolPreviewDto> buildToolPreviewDtos(ToolCallback[] callbacks, List<AiMcpToolPreviewDto> governedCatalog) {
|
List<AiMcpToolPreviewDto> tools = new ArrayList<>();
|
Map<String, AiMcpToolPreviewDto> catalogMap = new java.util.LinkedHashMap<>();
|
for (AiMcpToolPreviewDto item : governedCatalog) {
|
if (item == null || !StringUtils.hasText(item.getName())) {
|
continue;
|
}
|
catalogMap.put(item.getName(), item);
|
}
|
for (ToolCallback callback : callbacks) {
|
if (callback == null || callback.getToolDefinition() == null) {
|
continue;
|
}
|
AiMcpToolPreviewDto governedItem = catalogMap.get(callback.getToolDefinition().name());
|
tools.add(AiMcpToolPreviewDto.builder()
|
.name(callback.getToolDefinition().name())
|
.description(callback.getToolDefinition().description())
|
.inputSchema(callback.getToolDefinition().inputSchema())
|
.returnDirect(callback.getToolMetadata() == null ? null : callback.getToolMetadata().returnDirect())
|
.toolGroup(governedItem == null ? null : governedItem.getToolGroup())
|
.toolPurpose(governedItem == null ? null : governedItem.getToolPurpose())
|
.queryBoundary(governedItem == null ? null : governedItem.getQueryBoundary())
|
.exampleQuestions(governedItem == null ? List.of() : governedItem.getExampleQuestions())
|
.build());
|
}
|
return tools;
|
}
|
|
private void updateHealthStatus(Long mountId, String healthStatus, String message, Long initElapsedMs) {
|
this.update(new LambdaUpdateWrapper<AiMcpMount>()
|
.eq(AiMcpMount::getId, mountId)
|
.set(AiMcpMount::getHealthStatus, healthStatus)
|
.set(AiMcpMount::getLastTestTime, new Date())
|
.set(AiMcpMount::getLastTestMessage, message)
|
.set(AiMcpMount::getLastInitElapsedMs, initElapsedMs));
|
}
|
|
private AiMcpConnectivityTestDto buildConnectivityDto(AiMcpMount mount, String message, Long initElapsedMs, Integer toolCount) {
|
return AiMcpConnectivityTestDto.builder()
|
.mountId(mount.getId())
|
.mountName(mount.getName())
|
.healthStatus(mount.getHealthStatus())
|
.message(message)
|
.initElapsedMs(initElapsedMs)
|
.toolCount(toolCount)
|
.testedAt(mount.getLastTestTime$())
|
.build();
|
}
|
}
|