package com.vincent.rsf.server.ai.tool;
|
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
import com.vincent.rsf.framework.exception.CoolException;
|
import com.vincent.rsf.server.manager.entity.BasStation;
|
import com.vincent.rsf.server.manager.entity.Warehouse;
|
import com.vincent.rsf.server.manager.service.BasStationService;
|
import com.vincent.rsf.server.manager.service.WarehouseService;
|
import com.vincent.rsf.server.system.entity.DictData;
|
import com.vincent.rsf.server.system.service.DictDataService;
|
import lombok.RequiredArgsConstructor;
|
import org.springframework.ai.tool.annotation.Tool;
|
import org.springframework.ai.tool.annotation.ToolParam;
|
import org.springframework.stereotype.Component;
|
import org.springframework.util.StringUtils;
|
|
import java.util.ArrayList;
|
import java.util.LinkedHashMap;
|
import java.util.List;
|
import java.util.Map;
|
|
@Component
|
@RequiredArgsConstructor
|
public class RsfWmsBaseTools {
|
|
private final WarehouseService warehouseService;
|
private final BasStationService basStationService;
|
private final DictDataService dictDataService;
|
|
/**
|
* 查询仓库基础信息。
|
* 该工具面向“按编码/名称定位仓库”的问答场景,不负责提供全量仓库主数据导出能力。
|
*/
|
@Tool(name = "rsf_query_warehouses", description = "只读查询工具。按仓库编码或名称查询仓库基础信息。")
|
public List<Map<String, Object>> queryWarehouses(
|
@ToolParam(description = "仓库编码,可选") String code,
|
@ToolParam(description = "仓库名称,可选") String name,
|
@ToolParam(description = "返回条数,默认 10,最大 50") Integer limit) {
|
String normalizedCode = BuiltinToolGovernanceSupport.sanitizeQueryText(code, "仓库编码", 64);
|
String normalizedName = BuiltinToolGovernanceSupport.sanitizeQueryText(name, "仓库名称", 100);
|
BuiltinToolGovernanceSupport.requireAnyFilter("仓库查询至少需要提供仓库编码或名称", normalizedCode, normalizedName);
|
LambdaQueryWrapper<Warehouse> queryWrapper = new LambdaQueryWrapper<>();
|
int finalLimit = BuiltinToolGovernanceSupport.normalizeLimit(limit, 10, 50);
|
if (StringUtils.hasText(normalizedCode)) {
|
queryWrapper.like(Warehouse::getCode, normalizedCode);
|
}
|
if (StringUtils.hasText(normalizedName)) {
|
queryWrapper.like(Warehouse::getName, normalizedName);
|
}
|
queryWrapper.orderByAsc(Warehouse::getCode).last("LIMIT " + finalLimit);
|
List<Warehouse> warehouses = warehouseService.list(queryWrapper);
|
List<Map<String, Object>> result = new ArrayList<>();
|
for (Warehouse warehouse : warehouses) {
|
Map<String, Object> item = new LinkedHashMap<>();
|
item.put("id", warehouse.getId());
|
item.put("code", warehouse.getCode());
|
item.put("name", warehouse.getName());
|
item.put("factory", warehouse.getFactory());
|
item.put("address", warehouse.getAddress());
|
item.put("status", warehouse.getStatus());
|
item.put("statusLabel", warehouse.getStatus$());
|
item.put("memo", warehouse.getMemo());
|
result.add(item);
|
}
|
return result;
|
}
|
|
/**
|
* 查询基础站点信息。
|
* 查询条件允许按站点名称、编号或使用状态组合过滤,返回值只保留 AI 对话需要的字段。
|
*/
|
@Tool(name = "rsf_query_bas_stations", description = "只读查询工具。按站点编号、站点名称或使用状态查询基础站点。")
|
public List<Map<String, Object>> queryBasStations(
|
@ToolParam(description = "站点名称,可选") String stationName,
|
@ToolParam(description = "站点编号,可选") String stationId,
|
@ToolParam(description = "使用状态,可选") String useStatus,
|
@ToolParam(description = "返回条数,默认 10,最大 50") Integer limit) {
|
String normalizedStationName = BuiltinToolGovernanceSupport.sanitizeQueryText(stationName, "站点名称", 100);
|
String normalizedStationId = BuiltinToolGovernanceSupport.sanitizeQueryText(stationId, "站点编号", 64);
|
String normalizedUseStatus = BuiltinToolGovernanceSupport.sanitizeQueryText(useStatus, "使用状态", 32);
|
BuiltinToolGovernanceSupport.requireAnyFilter("基础站点查询至少需要提供站点名称、站点编号或使用状态",
|
normalizedStationName, normalizedStationId, normalizedUseStatus);
|
LambdaQueryWrapper<BasStation> queryWrapper = new LambdaQueryWrapper<>();
|
int finalLimit = BuiltinToolGovernanceSupport.normalizeLimit(limit, 10, 50);
|
if (StringUtils.hasText(normalizedStationName)) {
|
queryWrapper.like(BasStation::getStationName, normalizedStationName);
|
}
|
if (StringUtils.hasText(normalizedStationId)) {
|
queryWrapper.like(BasStation::getStationId, normalizedStationId);
|
}
|
if (StringUtils.hasText(normalizedUseStatus)) {
|
queryWrapper.eq(BasStation::getUseStatus, normalizedUseStatus);
|
}
|
queryWrapper.orderByAsc(BasStation::getStationName).last("LIMIT " + finalLimit);
|
List<BasStation> stations = basStationService.list(queryWrapper);
|
List<Map<String, Object>> result = new ArrayList<>();
|
for (BasStation station : stations) {
|
Map<String, Object> item = new LinkedHashMap<>();
|
item.put("id", station.getId());
|
item.put("stationName", station.getStationName());
|
item.put("stationId", station.getStationId());
|
item.put("type", station.getType());
|
item.put("typeLabel", station.getType$());
|
item.put("useStatus", station.getUseStatus());
|
item.put("useStatusLabel", station.getUseStatus$());
|
item.put("area", station.getArea());
|
item.put("areaLabel", station.getArea$());
|
item.put("isWcs", station.getIsWcs());
|
item.put("inAble", station.getInAble());
|
item.put("outAble", station.getOutAble());
|
item.put("status", station.getStatus());
|
result.add(item);
|
}
|
return result;
|
}
|
|
/**
|
* 查询字典数据。
|
* 字典类型编码是强制条件,用来确保模型不会越过业务边界直接遍历整张字典表。
|
*/
|
@Tool(name = "rsf_query_dict_data", description = "只读查询工具。根据字典类型编码查询字典数据,可按值或标签进一步过滤。")
|
public List<Map<String, Object>> queryDictData(
|
@ToolParam(required = true, description = "字典类型编码") String dictTypeCode,
|
@ToolParam(description = "字典值,可选") String value,
|
@ToolParam(description = "字典标签,可选") String label,
|
@ToolParam(description = "返回条数,默认 20,最大 100") Integer limit) {
|
String normalizedDictTypeCode = BuiltinToolGovernanceSupport.sanitizeQueryText(dictTypeCode, "字典类型编码", 64);
|
String normalizedValue = BuiltinToolGovernanceSupport.sanitizeQueryText(value, "字典值", 64);
|
String normalizedLabel = BuiltinToolGovernanceSupport.sanitizeQueryText(label, "字典标签", 100);
|
if (!StringUtils.hasText(normalizedDictTypeCode)) {
|
throw new CoolException("字典类型编码不能为空");
|
}
|
int finalLimit = BuiltinToolGovernanceSupport.normalizeLimit(limit, 20, 100);
|
LambdaQueryWrapper<DictData> queryWrapper = new LambdaQueryWrapper<DictData>()
|
.eq(DictData::getDictTypeCode, normalizedDictTypeCode);
|
if (StringUtils.hasText(normalizedValue)) {
|
queryWrapper.like(DictData::getValue, normalizedValue);
|
}
|
if (StringUtils.hasText(normalizedLabel)) {
|
queryWrapper.like(DictData::getLabel, normalizedLabel);
|
}
|
queryWrapper.orderByAsc(DictData::getSort).last("LIMIT " + finalLimit);
|
List<DictData> dictDataList = dictDataService.list(queryWrapper);
|
List<Map<String, Object>> result = new ArrayList<>();
|
for (DictData dictData : dictDataList) {
|
Map<String, Object> item = new LinkedHashMap<>();
|
item.put("id", dictData.getId());
|
item.put("dictTypeCode", dictData.getDictTypeCode());
|
item.put("value", dictData.getValue());
|
item.put("label", dictData.getLabel());
|
item.put("sort", dictData.getSort());
|
item.put("color", dictData.getColor());
|
item.put("group", dictData.getGroup());
|
item.put("status", dictData.getStatus());
|
item.put("statusLabel", dictData.getStatus$());
|
result.add(item);
|
}
|
return result;
|
}
|
|
}
|