package com.vincent.rsf.server.ai.tool; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.vincent.rsf.framework.exception.CoolException; import com.vincent.rsf.server.manager.entity.BasStation; import com.vincent.rsf.server.manager.entity.Warehouse; import com.vincent.rsf.server.manager.service.BasStationService; import com.vincent.rsf.server.manager.service.WarehouseService; import com.vincent.rsf.server.system.entity.DictData; import com.vincent.rsf.server.system.service.DictDataService; import lombok.RequiredArgsConstructor; import org.springframework.ai.tool.annotation.Tool; import org.springframework.ai.tool.annotation.ToolParam; import org.springframework.stereotype.Component; import org.springframework.util.StringUtils; import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @Component @RequiredArgsConstructor public class RsfWmsBaseTools { private final WarehouseService warehouseService; private final BasStationService basStationService; private final DictDataService dictDataService; @Tool(name = "rsf_query_warehouses", description = "按仓库编码或名称查询仓库基础信息。") public List> queryWarehouses( @ToolParam(description = "仓库编码,可选") String code, @ToolParam(description = "仓库名称,可选") String name, @ToolParam(description = "返回条数,默认 10,最大 50") Integer limit) { LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper<>(); int finalLimit = normalizeLimit(limit, 10, 50); if (StringUtils.hasText(code)) { queryWrapper.like(Warehouse::getCode, code); } if (StringUtils.hasText(name)) { queryWrapper.like(Warehouse::getName, name); } queryWrapper.orderByAsc(Warehouse::getCode).last("LIMIT " + finalLimit); List warehouses = warehouseService.list(queryWrapper); List> result = new ArrayList<>(); for (Warehouse warehouse : warehouses) { Map item = new LinkedHashMap<>(); item.put("id", warehouse.getId()); item.put("code", warehouse.getCode()); item.put("name", warehouse.getName()); item.put("factory", warehouse.getFactory()); item.put("address", warehouse.getAddress()); item.put("status", warehouse.getStatus()); item.put("statusLabel", warehouse.getStatus$()); item.put("memo", warehouse.getMemo()); result.add(item); } return result; } @Tool(name = "rsf_query_bas_stations", description = "按站点编号、站点名称或使用状态查询基础站点。") public List> queryBasStations( @ToolParam(description = "站点编号,可选") String stationName, @ToolParam(description = "站点名称,可选") String stationId, @ToolParam(description = "使用状态,可选") String useStatus, @ToolParam(description = "返回条数,默认 10,最大 50") Integer limit) { LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper<>(); int finalLimit = normalizeLimit(limit, 10, 50); if (StringUtils.hasText(stationName)) { queryWrapper.like(BasStation::getStationName, stationName); } if (StringUtils.hasText(stationId)) { queryWrapper.like(BasStation::getStationId, stationId); } if (StringUtils.hasText(useStatus)) { queryWrapper.eq(BasStation::getUseStatus, useStatus); } queryWrapper.orderByAsc(BasStation::getStationName).last("LIMIT " + finalLimit); List stations = basStationService.list(queryWrapper); List> result = new ArrayList<>(); for (BasStation station : stations) { Map item = new LinkedHashMap<>(); item.put("id", station.getId()); item.put("stationName", station.getStationName()); item.put("stationId", station.getStationId()); item.put("type", station.getType()); item.put("typeLabel", station.getType$()); item.put("useStatus", station.getUseStatus()); item.put("useStatusLabel", station.getUseStatus$()); item.put("area", station.getArea()); item.put("areaLabel", station.getArea$()); item.put("isWcs", station.getIsWcs()); item.put("inAble", station.getInAble()); item.put("outAble", station.getOutAble()); item.put("status", station.getStatus()); result.add(item); } return result; } @Tool(name = "rsf_query_dict_data", description = "根据字典类型编码查询字典数据,可按值或标签进一步过滤。") public List> queryDictData( @ToolParam(required = true, description = "字典类型编码") String dictTypeCode, @ToolParam(description = "字典值,可选") String value, @ToolParam(description = "字典标签,可选") String label, @ToolParam(description = "返回条数,默认 20,最大 100") Integer limit) { if (!StringUtils.hasText(dictTypeCode)) { throw new CoolException("字典类型编码不能为空"); } int finalLimit = normalizeLimit(limit, 20, 100); LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper() .eq(DictData::getDictTypeCode, dictTypeCode); if (StringUtils.hasText(value)) { queryWrapper.like(DictData::getValue, value); } if (StringUtils.hasText(label)) { queryWrapper.like(DictData::getLabel, label); } queryWrapper.orderByAsc(DictData::getSort).last("LIMIT " + finalLimit); List dictDataList = dictDataService.list(queryWrapper); List> result = new ArrayList<>(); for (DictData dictData : dictDataList) { Map item = new LinkedHashMap<>(); item.put("id", dictData.getId()); item.put("dictTypeCode", dictData.getDictTypeCode()); item.put("value", dictData.getValue()); item.put("label", dictData.getLabel()); item.put("sort", dictData.getSort()); item.put("color", dictData.getColor()); item.put("group", dictData.getGroup()); item.put("status", dictData.getStatus()); item.put("statusLabel", dictData.getStatus$()); result.add(item); } return result; } private int normalizeLimit(Integer limit, int defaultValue, int maxValue) { if (limit == null) { return defaultValue; } if (limit < 1 || limit > maxValue) { throw new CoolException("limit 必须在 1 到 " + maxValue + " 之间"); } return limit; } }