package com.vincent.rsf.server.ai.tool; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.vincent.rsf.framework.exception.CoolException; import com.vincent.rsf.server.manager.entity.Task; import com.vincent.rsf.server.manager.service.TaskService; import lombok.RequiredArgsConstructor; import org.springframework.ai.tool.annotation.Tool; import org.springframework.ai.tool.annotation.ToolParam; import org.springframework.stereotype.Component; import org.springframework.util.StringUtils; import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @Component @RequiredArgsConstructor public class RsfWmsTaskTools { private final TaskService taskService; @Tool(name = "rsf_query_task_list", description = "按任务号、状态、任务类型、源站点、目标站点等条件查询任务列表。") public List> queryTaskList( @ToolParam(description = "任务号,可模糊查询") String taskCode, @ToolParam(description = "任务状态,可选") Integer taskStatus, @ToolParam(description = "任务类型,可选") Integer taskType, @ToolParam(description = "源站点,可选") String orgSite, @ToolParam(description = "目标站点,可选") String targSite, @ToolParam(description = "返回条数,默认 10,最大 50") Integer limit) { LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper<>(); int finalLimit = normalizeLimit(limit, 10, 50); if (StringUtils.hasText(taskCode)) { queryWrapper.like(Task::getTaskCode, taskCode); } if (taskStatus != null) { queryWrapper.eq(Task::getTaskStatus, taskStatus); } if (taskType != null) { queryWrapper.eq(Task::getTaskType, taskType); } if (StringUtils.hasText(orgSite)) { queryWrapper.eq(Task::getOrgSite, orgSite); } if (StringUtils.hasText(targSite)) { queryWrapper.eq(Task::getTargSite, targSite); } queryWrapper.orderByDesc(Task::getCreateTime).last("LIMIT " + finalLimit); List tasks = taskService.list(queryWrapper); List> result = new ArrayList<>(); for (Task task : tasks) { result.add(buildTaskSummary(task)); } return result; } @Tool(name = "rsf_query_task_detail", description = "根据任务 ID 或任务号查询任务详情。") public Map queryTaskDetail( @ToolParam(description = "任务 ID") Long taskId, @ToolParam(description = "任务号") String taskCode) { if (taskId == null && !StringUtils.hasText(taskCode)) { throw new CoolException("任务 ID 和任务号至少需要提供一个"); } Task task; if (taskId != null) { task = taskService.getById(taskId); } else { task = taskService.getOne(new LambdaQueryWrapper().eq(Task::getTaskCode, taskCode)); } if (task == null) { throw new CoolException("未查询到任务"); } Map result = buildTaskSummary(task); result.put("resource", task.getResource()); result.put("exceStatus", task.getExceStatus()); result.put("orgLoc", task.getOrgLoc()); result.put("targLoc", task.getTargLoc()); result.put("orgSite", task.getOrgSite()); result.put("orgSiteLabel", task.getOrgSite$()); result.put("targSite", task.getTargSite()); result.put("targSiteLabel", task.getTargSite$()); result.put("barcode", task.getBarcode()); result.put("robotCode", task.getRobotCode()); result.put("memo", task.getMemo()); result.put("expCode", task.getExpCode()); result.put("expDesc", task.getExpDesc()); result.put("startTime", task.getStartTime$()); result.put("endTime", task.getEndTime$()); result.put("createTime", task.getCreateTime$()); result.put("updateTime", task.getUpdateTime$()); return result; } private Map buildTaskSummary(Task task) { Map item = new LinkedHashMap<>(); item.put("id", task.getId()); item.put("taskCode", task.getTaskCode()); item.put("taskStatus", task.getTaskStatus()); item.put("taskStatusLabel", task.getTaskStatus$()); item.put("taskType", task.getTaskType()); item.put("taskTypeLabel", task.getTaskType$()); item.put("orgSite", task.getOrgSite()); item.put("orgSiteLabel", task.getOrgSite$()); item.put("targSite", task.getTargSite()); item.put("targSiteLabel", task.getTargSite$()); item.put("status", task.getStatus()); item.put("statusLabel", task.getStatus$()); item.put("createTime", task.getCreateTime$()); item.put("updateTime", task.getUpdateTime$()); return item; } private int normalizeLimit(Integer limit, int defaultValue, int maxValue) { if (limit == null) { return defaultValue; } if (limit < 1 || limit > maxValue) { throw new CoolException("limit 必须在 1 到 " + maxValue + " 之间"); } return limit; } }