package com.vincent.rsf.server.ai.tool; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.vincent.rsf.framework.exception.CoolException; import com.vincent.rsf.server.manager.entity.BasStation; import com.vincent.rsf.server.manager.entity.Warehouse; import com.vincent.rsf.server.manager.service.BasStationService; import com.vincent.rsf.server.manager.service.WarehouseService; import com.vincent.rsf.server.system.entity.DictData; import com.vincent.rsf.server.system.service.DictDataService; import lombok.RequiredArgsConstructor; import org.springframework.ai.tool.annotation.Tool; import org.springframework.ai.tool.annotation.ToolParam; import org.springframework.stereotype.Component; import org.springframework.util.StringUtils; import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @Component @RequiredArgsConstructor public class RsfWmsBaseTools { private final WarehouseService warehouseService; private final BasStationService basStationService; private final DictDataService dictDataService; @Tool(name = "rsf_query_warehouses", description = "只读查询工具。按仓库编码或名称查询仓库基础信息。") public List> queryWarehouses( @ToolParam(description = "仓库编码,可选") String code, @ToolParam(description = "仓库名称,可选") String name, @ToolParam(description = "返回条数,默认 10,最大 50") Integer limit) { String normalizedCode = BuiltinToolGovernanceSupport.sanitizeQueryText(code, "仓库编码", 64); String normalizedName = BuiltinToolGovernanceSupport.sanitizeQueryText(name, "仓库名称", 100); BuiltinToolGovernanceSupport.requireAnyFilter("仓库查询至少需要提供仓库编码或名称", normalizedCode, normalizedName); LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper<>(); int finalLimit = BuiltinToolGovernanceSupport.normalizeLimit(limit, 10, 50); if (StringUtils.hasText(normalizedCode)) { queryWrapper.like(Warehouse::getCode, normalizedCode); } if (StringUtils.hasText(normalizedName)) { queryWrapper.like(Warehouse::getName, normalizedName); } queryWrapper.orderByAsc(Warehouse::getCode).last("LIMIT " + finalLimit); List warehouses = warehouseService.list(queryWrapper); List> result = new ArrayList<>(); for (Warehouse warehouse : warehouses) { Map item = new LinkedHashMap<>(); item.put("id", warehouse.getId()); item.put("code", warehouse.getCode()); item.put("name", warehouse.getName()); item.put("factory", warehouse.getFactory()); item.put("address", warehouse.getAddress()); item.put("status", warehouse.getStatus()); item.put("statusLabel", warehouse.getStatus$()); item.put("memo", warehouse.getMemo()); result.add(item); } return result; } @Tool(name = "rsf_query_bas_stations", description = "只读查询工具。按站点编号、站点名称或使用状态查询基础站点。") public List> queryBasStations( @ToolParam(description = "站点名称,可选") String stationName, @ToolParam(description = "站点编号,可选") String stationId, @ToolParam(description = "使用状态,可选") String useStatus, @ToolParam(description = "返回条数,默认 10,最大 50") Integer limit) { String normalizedStationName = BuiltinToolGovernanceSupport.sanitizeQueryText(stationName, "站点名称", 100); String normalizedStationId = BuiltinToolGovernanceSupport.sanitizeQueryText(stationId, "站点编号", 64); String normalizedUseStatus = BuiltinToolGovernanceSupport.sanitizeQueryText(useStatus, "使用状态", 32); BuiltinToolGovernanceSupport.requireAnyFilter("基础站点查询至少需要提供站点名称、站点编号或使用状态", normalizedStationName, normalizedStationId, normalizedUseStatus); LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper<>(); int finalLimit = BuiltinToolGovernanceSupport.normalizeLimit(limit, 10, 50); if (StringUtils.hasText(normalizedStationName)) { queryWrapper.like(BasStation::getStationName, normalizedStationName); } if (StringUtils.hasText(normalizedStationId)) { queryWrapper.like(BasStation::getStationId, normalizedStationId); } if (StringUtils.hasText(normalizedUseStatus)) { queryWrapper.eq(BasStation::getUseStatus, normalizedUseStatus); } queryWrapper.orderByAsc(BasStation::getStationName).last("LIMIT " + finalLimit); List stations = basStationService.list(queryWrapper); List> result = new ArrayList<>(); for (BasStation station : stations) { Map item = new LinkedHashMap<>(); item.put("id", station.getId()); item.put("stationName", station.getStationName()); item.put("stationId", station.getStationId()); item.put("type", station.getType()); item.put("typeLabel", station.getType$()); item.put("useStatus", station.getUseStatus()); item.put("useStatusLabel", station.getUseStatus$()); item.put("area", station.getArea()); item.put("areaLabel", station.getArea$()); item.put("isWcs", station.getIsWcs()); item.put("inAble", station.getInAble()); item.put("outAble", station.getOutAble()); item.put("status", station.getStatus()); result.add(item); } return result; } @Tool(name = "rsf_query_dict_data", description = "只读查询工具。根据字典类型编码查询字典数据,可按值或标签进一步过滤。") public List> queryDictData( @ToolParam(required = true, description = "字典类型编码") String dictTypeCode, @ToolParam(description = "字典值,可选") String value, @ToolParam(description = "字典标签,可选") String label, @ToolParam(description = "返回条数,默认 20,最大 100") Integer limit) { String normalizedDictTypeCode = BuiltinToolGovernanceSupport.sanitizeQueryText(dictTypeCode, "字典类型编码", 64); String normalizedValue = BuiltinToolGovernanceSupport.sanitizeQueryText(value, "字典值", 64); String normalizedLabel = BuiltinToolGovernanceSupport.sanitizeQueryText(label, "字典标签", 100); if (!StringUtils.hasText(normalizedDictTypeCode)) { throw new CoolException("字典类型编码不能为空"); } int finalLimit = BuiltinToolGovernanceSupport.normalizeLimit(limit, 20, 100); LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper() .eq(DictData::getDictTypeCode, normalizedDictTypeCode); if (StringUtils.hasText(normalizedValue)) { queryWrapper.like(DictData::getValue, normalizedValue); } if (StringUtils.hasText(normalizedLabel)) { queryWrapper.like(DictData::getLabel, normalizedLabel); } queryWrapper.orderByAsc(DictData::getSort).last("LIMIT " + finalLimit); List dictDataList = dictDataService.list(queryWrapper); List> result = new ArrayList<>(); for (DictData dictData : dictDataList) { Map item = new LinkedHashMap<>(); item.put("id", dictData.getId()); item.put("dictTypeCode", dictData.getDictTypeCode()); item.put("value", dictData.getValue()); item.put("label", dictData.getLabel()); item.put("sort", dictData.getSort()); item.put("color", dictData.getColor()); item.put("group", dictData.getGroup()); item.put("status", dictData.getStatus()); item.put("statusLabel", dictData.getStatus$()); result.add(item); } return result; } }