package com.vincent.rsf.server.ai.tool;
|
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
import com.vincent.rsf.framework.exception.CoolException;
|
import com.vincent.rsf.server.manager.entity.Task;
|
import com.vincent.rsf.server.manager.service.TaskService;
|
import lombok.RequiredArgsConstructor;
|
import org.springframework.ai.tool.annotation.Tool;
|
import org.springframework.ai.tool.annotation.ToolParam;
|
import org.springframework.stereotype.Component;
|
import org.springframework.util.StringUtils;
|
|
import java.util.ArrayList;
|
import java.util.LinkedHashMap;
|
import java.util.List;
|
import java.util.Map;
|
|
@Component
|
@RequiredArgsConstructor
|
public class RsfWmsTaskTools {
|
|
private final TaskService taskService;
|
|
/**
|
* 查询任务列表。
|
* 方法要求至少带一个过滤条件,避免模型把任务表当作可直接遍历的数据源。
|
*/
|
@Tool(name = "rsf_query_task_list", description = "只读查询工具。按任务号、状态、任务类型、源站点、目标站点等条件查询任务列表。")
|
public List<Map<String, Object>> queryTaskList(
|
@ToolParam(description = "任务号,可模糊查询") String taskCode,
|
@ToolParam(description = "任务状态,可选") Integer taskStatus,
|
@ToolParam(description = "任务类型,可选") Integer taskType,
|
@ToolParam(description = "源站点,可选") String orgSite,
|
@ToolParam(description = "目标站点,可选") String targSite,
|
@ToolParam(description = "返回条数,默认 10,最大 50") Integer limit) {
|
String normalizedTaskCode = BuiltinToolGovernanceSupport.sanitizeQueryText(taskCode, "任务号", 64);
|
String normalizedOrgSite = BuiltinToolGovernanceSupport.sanitizeQueryText(orgSite, "源站点", 64);
|
String normalizedTargSite = BuiltinToolGovernanceSupport.sanitizeQueryText(targSite, "目标站点", 64);
|
BuiltinToolGovernanceSupport.requireAnyFilter("任务列表查询至少需要提供一个过滤条件",
|
normalizedTaskCode, normalizedOrgSite, normalizedTargSite,
|
taskStatus == null ? null : String.valueOf(taskStatus),
|
taskType == null ? null : String.valueOf(taskType));
|
LambdaQueryWrapper<Task> queryWrapper = new LambdaQueryWrapper<>();
|
int finalLimit = BuiltinToolGovernanceSupport.normalizeLimit(limit, 10, 50);
|
if (StringUtils.hasText(normalizedTaskCode)) {
|
queryWrapper.like(Task::getTaskCode, normalizedTaskCode);
|
}
|
if (taskStatus != null) {
|
queryWrapper.eq(Task::getTaskStatus, taskStatus);
|
}
|
if (taskType != null) {
|
queryWrapper.eq(Task::getTaskType, taskType);
|
}
|
if (StringUtils.hasText(normalizedOrgSite)) {
|
queryWrapper.eq(Task::getOrgSite, normalizedOrgSite);
|
}
|
if (StringUtils.hasText(normalizedTargSite)) {
|
queryWrapper.eq(Task::getTargSite, normalizedTargSite);
|
}
|
queryWrapper.orderByDesc(Task::getCreateTime).last("LIMIT " + finalLimit);
|
List<Task> tasks = taskService.list(queryWrapper);
|
List<Map<String, Object>> result = new ArrayList<>();
|
for (Task task : tasks) {
|
result.add(buildTaskSummary(task));
|
}
|
return result;
|
}
|
|
/**
|
* 查询单个任务详情。
|
* 与列表查询不同,这里允许返回更丰富的字段,但仍然要求调用方通过任务 ID 或任务号做精确定位。
|
*/
|
@Tool(name = "rsf_query_task_detail", description = "只读查询工具。根据任务 ID 或任务号查询任务详情。")
|
public Map<String, Object> queryTaskDetail(
|
@ToolParam(description = "任务 ID") Long taskId,
|
@ToolParam(description = "任务号") String taskCode) {
|
String normalizedTaskCode = BuiltinToolGovernanceSupport.sanitizeQueryText(taskCode, "任务号", 64);
|
if (taskId == null && !StringUtils.hasText(normalizedTaskCode)) {
|
throw new CoolException("任务 ID 和任务号至少需要提供一个");
|
}
|
Task task;
|
if (taskId != null) {
|
task = taskService.getById(taskId);
|
} else {
|
task = taskService.getOne(new LambdaQueryWrapper<Task>().eq(Task::getTaskCode, normalizedTaskCode));
|
}
|
if (task == null) {
|
throw new CoolException("未查询到任务");
|
}
|
Map<String, Object> result = buildTaskSummary(task);
|
result.put("resource", task.getResource());
|
result.put("exceStatus", task.getExceStatus());
|
result.put("orgLoc", task.getOrgLoc());
|
result.put("targLoc", task.getTargLoc());
|
result.put("orgSite", task.getOrgSite());
|
result.put("orgSiteLabel", task.getOrgSite$());
|
result.put("targSite", task.getTargSite());
|
result.put("targSiteLabel", task.getTargSite$());
|
result.put("barcode", task.getBarcode());
|
result.put("robotCode", task.getRobotCode());
|
result.put("memo", task.getMemo());
|
result.put("expCode", task.getExpCode());
|
result.put("expDesc", task.getExpDesc());
|
result.put("startTime", task.getStartTime$());
|
result.put("endTime", task.getEndTime$());
|
result.put("createTime", task.getCreateTime$());
|
result.put("updateTime", task.getUpdateTime$());
|
return result;
|
}
|
|
private Map<String, Object> buildTaskSummary(Task task) {
|
/** 把任务实体收敛为适合模型阅读和前端展示的摘要结构。 */
|
Map<String, Object> item = new LinkedHashMap<>();
|
item.put("id", task.getId());
|
item.put("taskCode", task.getTaskCode());
|
item.put("taskStatus", task.getTaskStatus());
|
item.put("taskStatusLabel", task.getTaskStatus$());
|
item.put("taskType", task.getTaskType());
|
item.put("taskTypeLabel", task.getTaskType$());
|
item.put("orgSite", task.getOrgSite());
|
item.put("orgSiteLabel", task.getOrgSite$());
|
item.put("targSite", task.getTargSite());
|
item.put("targSiteLabel", task.getTargSite$());
|
item.put("status", task.getStatus());
|
item.put("statusLabel", task.getStatus$());
|
item.put("createTime", task.getCreateTime$());
|
item.put("updateTime", task.getUpdateTime$());
|
return item;
|
}
|
|
}
|