package com.vincent.rsf.server.ai.tool;
|
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
import com.vincent.rsf.framework.exception.CoolException;
|
import com.vincent.rsf.server.manager.entity.BasStation;
|
import com.vincent.rsf.server.manager.entity.Warehouse;
|
import com.vincent.rsf.server.manager.service.BasStationService;
|
import com.vincent.rsf.server.manager.service.WarehouseService;
|
import com.vincent.rsf.server.system.entity.DictData;
|
import com.vincent.rsf.server.system.service.DictDataService;
|
import lombok.RequiredArgsConstructor;
|
import org.springframework.ai.tool.annotation.Tool;
|
import org.springframework.ai.tool.annotation.ToolParam;
|
import org.springframework.stereotype.Component;
|
import org.springframework.util.StringUtils;
|
|
import java.util.ArrayList;
|
import java.util.LinkedHashMap;
|
import java.util.List;
|
import java.util.Map;
|
|
@Component
|
@RequiredArgsConstructor
|
public class RsfWmsBaseTools {
|
|
private final WarehouseService warehouseService;
|
private final BasStationService basStationService;
|
private final DictDataService dictDataService;
|
|
@Tool(name = "rsf_query_warehouses", description = "按仓库编码或名称查询仓库基础信息。")
|
public List<Map<String, Object>> queryWarehouses(
|
@ToolParam(description = "仓库编码,可选") String code,
|
@ToolParam(description = "仓库名称,可选") String name,
|
@ToolParam(description = "返回条数,默认 10,最大 50") Integer limit) {
|
LambdaQueryWrapper<Warehouse> queryWrapper = new LambdaQueryWrapper<>();
|
int finalLimit = normalizeLimit(limit, 10, 50);
|
if (StringUtils.hasText(code)) {
|
queryWrapper.like(Warehouse::getCode, code);
|
}
|
if (StringUtils.hasText(name)) {
|
queryWrapper.like(Warehouse::getName, name);
|
}
|
queryWrapper.orderByAsc(Warehouse::getCode).last("LIMIT " + finalLimit);
|
List<Warehouse> warehouses = warehouseService.list(queryWrapper);
|
List<Map<String, Object>> result = new ArrayList<>();
|
for (Warehouse warehouse : warehouses) {
|
Map<String, Object> item = new LinkedHashMap<>();
|
item.put("id", warehouse.getId());
|
item.put("code", warehouse.getCode());
|
item.put("name", warehouse.getName());
|
item.put("factory", warehouse.getFactory());
|
item.put("address", warehouse.getAddress());
|
item.put("status", warehouse.getStatus());
|
item.put("statusLabel", warehouse.getStatus$());
|
item.put("memo", warehouse.getMemo());
|
result.add(item);
|
}
|
return result;
|
}
|
|
@Tool(name = "rsf_query_bas_stations", description = "按站点编号、站点名称或使用状态查询基础站点。")
|
public List<Map<String, Object>> queryBasStations(
|
@ToolParam(description = "站点编号,可选") String stationName,
|
@ToolParam(description = "站点名称,可选") String stationId,
|
@ToolParam(description = "使用状态,可选") String useStatus,
|
@ToolParam(description = "返回条数,默认 10,最大 50") Integer limit) {
|
LambdaQueryWrapper<BasStation> queryWrapper = new LambdaQueryWrapper<>();
|
int finalLimit = normalizeLimit(limit, 10, 50);
|
if (StringUtils.hasText(stationName)) {
|
queryWrapper.like(BasStation::getStationName, stationName);
|
}
|
if (StringUtils.hasText(stationId)) {
|
queryWrapper.like(BasStation::getStationId, stationId);
|
}
|
if (StringUtils.hasText(useStatus)) {
|
queryWrapper.eq(BasStation::getUseStatus, useStatus);
|
}
|
queryWrapper.orderByAsc(BasStation::getStationName).last("LIMIT " + finalLimit);
|
List<BasStation> stations = basStationService.list(queryWrapper);
|
List<Map<String, Object>> result = new ArrayList<>();
|
for (BasStation station : stations) {
|
Map<String, Object> item = new LinkedHashMap<>();
|
item.put("id", station.getId());
|
item.put("stationName", station.getStationName());
|
item.put("stationId", station.getStationId());
|
item.put("type", station.getType());
|
item.put("typeLabel", station.getType$());
|
item.put("useStatus", station.getUseStatus());
|
item.put("useStatusLabel", station.getUseStatus$());
|
item.put("area", station.getArea());
|
item.put("areaLabel", station.getArea$());
|
item.put("isWcs", station.getIsWcs());
|
item.put("inAble", station.getInAble());
|
item.put("outAble", station.getOutAble());
|
item.put("status", station.getStatus());
|
result.add(item);
|
}
|
return result;
|
}
|
|
@Tool(name = "rsf_query_dict_data", description = "根据字典类型编码查询字典数据,可按值或标签进一步过滤。")
|
public List<Map<String, Object>> queryDictData(
|
@ToolParam(required = true, description = "字典类型编码") String dictTypeCode,
|
@ToolParam(description = "字典值,可选") String value,
|
@ToolParam(description = "字典标签,可选") String label,
|
@ToolParam(description = "返回条数,默认 20,最大 100") Integer limit) {
|
if (!StringUtils.hasText(dictTypeCode)) {
|
throw new CoolException("字典类型编码不能为空");
|
}
|
int finalLimit = normalizeLimit(limit, 20, 100);
|
LambdaQueryWrapper<DictData> queryWrapper = new LambdaQueryWrapper<DictData>()
|
.eq(DictData::getDictTypeCode, dictTypeCode);
|
if (StringUtils.hasText(value)) {
|
queryWrapper.like(DictData::getValue, value);
|
}
|
if (StringUtils.hasText(label)) {
|
queryWrapper.like(DictData::getLabel, label);
|
}
|
queryWrapper.orderByAsc(DictData::getSort).last("LIMIT " + finalLimit);
|
List<DictData> dictDataList = dictDataService.list(queryWrapper);
|
List<Map<String, Object>> result = new ArrayList<>();
|
for (DictData dictData : dictDataList) {
|
Map<String, Object> item = new LinkedHashMap<>();
|
item.put("id", dictData.getId());
|
item.put("dictTypeCode", dictData.getDictTypeCode());
|
item.put("value", dictData.getValue());
|
item.put("label", dictData.getLabel());
|
item.put("sort", dictData.getSort());
|
item.put("color", dictData.getColor());
|
item.put("group", dictData.getGroup());
|
item.put("status", dictData.getStatus());
|
item.put("statusLabel", dictData.getStatus$());
|
result.add(item);
|
}
|
return result;
|
}
|
|
private int normalizeLimit(Integer limit, int defaultValue, int maxValue) {
|
if (limit == null) {
|
return defaultValue;
|
}
|
if (limit < 1 || limit > maxValue) {
|
throw new CoolException("limit 必须在 1 到 " + maxValue + " 之间");
|
}
|
return limit;
|
}
|
}
|