#ai
zhou zhou
12 小时以前 1668b4ce8fb82ddfd54b44b86e78e3080b99a1cc
#ai
1个文件已添加
3个文件已修改
236 ■■■■ 已修改文件
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java 15 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiOpenAiApiSupport.java 86 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiParamValidationSupport.java 19 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
rsf-server/src/main/java/com/vincent/rsf/server/ai/tool/RsfWmsTaskTools.java 116 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
@@ -51,12 +51,9 @@
import org.springframework.ai.util.json.schema.SchemaType;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.http.MediaType;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import reactor.core.publisher.Flux;
@@ -479,17 +476,7 @@
    }
    // NOTE(review): this hunk is rendered from a commit diff whose +/- markers were lost, so the
    // pre-commit body (local SimpleClientHttpRequestFactory with connect/read timeouts plus a direct
    // OpenAiApi.builder() call) and the post-commit delegating `return` both appear. As shown, the
    // second `return` is unreachable and the span does not compile. Judging by the hunk sizes
    // (17 old lines -> 7 new lines), after the commit this method only delegates to
    // AiOpenAiApiSupport.buildOpenAiApi(aiParam). That helper also splits the configured Base URL
    // into base URL, completions path and embeddings path.
    /**
     * Builds the OpenAI-compatible client for the given AI parameters.
     */
    private OpenAiApi buildOpenAiApi(AiParam aiParam) {
        int timeoutMs = aiParam.getTimeoutMs() == null ? AiDefaults.DEFAULT_TIMEOUT_MS : aiParam.getTimeoutMs();
        SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
        requestFactory.setConnectTimeout(timeoutMs);
        requestFactory.setReadTimeout(timeoutMs);
        return OpenAiApi.builder()
                .baseUrl(aiParam.getBaseUrl())
                .apiKey(aiParam.getApiKey())
                .restClientBuilder(RestClient.builder().requestFactory(requestFactory))
                .webClientBuilder(WebClient.builder())
                .build();
        return AiOpenAiApiSupport.buildOpenAiApi(aiParam);
    }
    private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Long tenantId,
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiOpenAiApiSupport.java
New file
@@ -0,0 +1,86 @@
package com.vincent.rsf.server.ai.service.impl;
import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.entity.AiParam;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;
import java.util.Locale;
final class AiOpenAiApiSupport {
    private static final String DEFAULT_COMPLETIONS_PATH = "/v1/chat/completions";
    private static final String DEFAULT_EMBEDDINGS_PATH = "/v1/embeddings";
    private static final String V1_SEGMENT = "/v1";
    private static final String COMPLETIONS_SEGMENT = "/chat/completions";
    private static final String EMBEDDINGS_SEGMENT = "/embeddings";
    private AiOpenAiApiSupport() {
    }
    static OpenAiApi buildOpenAiApi(AiParam aiParam) {
        int timeoutMs = aiParam.getTimeoutMs() == null ? AiDefaults.DEFAULT_TIMEOUT_MS : aiParam.getTimeoutMs();
        SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
        requestFactory.setConnectTimeout(timeoutMs);
        requestFactory.setReadTimeout(timeoutMs);
        EndpointConfig endpointConfig = resolveEndpointConfig(aiParam.getBaseUrl());
        return OpenAiApi.builder()
                .baseUrl(endpointConfig.baseUrl())
                .completionsPath(endpointConfig.completionsPath())
                .embeddingsPath(endpointConfig.embeddingsPath())
                .apiKey(aiParam.getApiKey())
                .restClientBuilder(RestClient.builder().requestFactory(requestFactory))
                .webClientBuilder(WebClient.builder())
                .build();
    }
    static EndpointConfig resolveEndpointConfig(String rawBaseUrl) {
        String normalizedBaseUrl = trimTrailingSlash(rawBaseUrl);
        String lowerCaseBaseUrl = normalizedBaseUrl.toLowerCase(Locale.ROOT);
        if (lowerCaseBaseUrl.endsWith(DEFAULT_COMPLETIONS_PATH)) {
            String baseUrl = trimTrailingSlash(normalizedBaseUrl.substring(0,
                    normalizedBaseUrl.length() - DEFAULT_COMPLETIONS_PATH.length()));
            return new EndpointConfig(baseUrl, DEFAULT_COMPLETIONS_PATH, DEFAULT_EMBEDDINGS_PATH);
        }
        if (lowerCaseBaseUrl.endsWith(DEFAULT_EMBEDDINGS_PATH)) {
            String baseUrl = trimTrailingSlash(normalizedBaseUrl.substring(0,
                    normalizedBaseUrl.length() - DEFAULT_EMBEDDINGS_PATH.length()));
            return new EndpointConfig(baseUrl, DEFAULT_COMPLETIONS_PATH, DEFAULT_EMBEDDINGS_PATH);
        }
        if (lowerCaseBaseUrl.endsWith(COMPLETIONS_SEGMENT)) {
            String baseUrl = trimTrailingSlash(normalizedBaseUrl.substring(0,
                    normalizedBaseUrl.length() - COMPLETIONS_SEGMENT.length()));
            return new EndpointConfig(baseUrl, COMPLETIONS_SEGMENT, EMBEDDINGS_SEGMENT);
        }
        if (lowerCaseBaseUrl.endsWith(EMBEDDINGS_SEGMENT)) {
            String baseUrl = trimTrailingSlash(normalizedBaseUrl.substring(0,
                    normalizedBaseUrl.length() - EMBEDDINGS_SEGMENT.length()));
            return new EndpointConfig(baseUrl, COMPLETIONS_SEGMENT, EMBEDDINGS_SEGMENT);
        }
        if (lowerCaseBaseUrl.endsWith(V1_SEGMENT)) {
            String baseUrl = trimTrailingSlash(normalizedBaseUrl.substring(0,
                    normalizedBaseUrl.length() - V1_SEGMENT.length()));
            return new EndpointConfig(baseUrl, DEFAULT_COMPLETIONS_PATH, DEFAULT_EMBEDDINGS_PATH);
        }
        return new EndpointConfig(normalizedBaseUrl, DEFAULT_COMPLETIONS_PATH, DEFAULT_EMBEDDINGS_PATH);
    }
    private static String trimTrailingSlash(String baseUrl) {
        String normalized = baseUrl == null ? "" : baseUrl.trim();
        if (!StringUtils.hasText(normalized)) {
            return normalized;
        }
        while (normalized.endsWith("/")) {
            normalized = normalized.substring(0, normalized.length() - 1);
        }
        return normalized;
    }
    record EndpointConfig(String baseUrl, String completionsPath, String embeddingsPath) {
    }
}
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiParamValidationSupport.java
@@ -18,11 +18,8 @@
import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver;
import org.springframework.ai.util.json.schema.SchemaType;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;
import java.text.SimpleDateFormat;
import java.util.Date;
@@ -98,20 +95,8 @@
    }
    // NOTE(review): this hunk is rendered from a commit diff whose +/- markers were lost, so the
    // statements from `int timeoutMs` through `.build();` (the pre-commit body) and the delegating
    // `return` (the post-commit body) both appear. As shown, the second `return` is unreachable and
    // the span does not compile. After the commit the method appears to only delegate to
    // AiOpenAiApiSupport.buildOpenAiApi(aiParam).
    private OpenAiApi buildOpenAiApi(AiParam aiParam) {
        /**
         * Builds an OpenAI-compatible client from the Base URL, API key and timeout entered in the form.
         * The method is split out explicitly to keep the "network connection parameters" concern
         * separate from the "model options" concern.
         */
        int timeoutMs = aiParam.getTimeoutMs() == null ? AiDefaults.DEFAULT_TIMEOUT_MS : aiParam.getTimeoutMs();
        SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
        requestFactory.setConnectTimeout(timeoutMs);
        requestFactory.setReadTimeout(timeoutMs);
        return OpenAiApi.builder()
                .baseUrl(aiParam.getBaseUrl())
                .apiKey(aiParam.getApiKey())
                .restClientBuilder(RestClient.builder().requestFactory(requestFactory))
                .webClientBuilder(WebClient.builder())
                .build();
        /** Accepts the three common ways of entering the address: root address, /v1 prefix, and full completions endpoint. */
        return AiOpenAiApiSupport.buildOpenAiApi(aiParam);
    }
    private String formatDate(Date date) {
rsf-server/src/main/java/com/vincent/rsf/server/ai/tool/RsfWmsTaskTools.java
@@ -2,7 +2,10 @@
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.common.utils.FieldsUtils;
import com.vincent.rsf.server.manager.entity.Task;
import com.vincent.rsf.server.manager.entity.TaskItem;
import com.vincent.rsf.server.manager.service.TaskItemService;
import com.vincent.rsf.server.manager.service.TaskService;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.tool.annotation.Tool;
@@ -14,12 +17,14 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@Component
@RequiredArgsConstructor
public class RsfWmsTaskTools {
    private final TaskService taskService;
    private final TaskItemService taskItemService;
    /**
     * 查询任务列表。
@@ -70,42 +75,21 @@
     * 查询单个任务详情。
     * 与列表查询不同,这里允许返回更丰富的字段,但仍然要求调用方通过任务 ID 或任务号做精确定位。
     */
    // NOTE(review): this hunk is rendered from a commit diff whose +/- markers were lost. Pre- and
    // post-commit lines are interleaved (two @Tool annotations, two parameter lists, two returns),
    // so the span does not compile as shown. Read as the post-commit version (inferred from which
    // lines pair up), the method:
    //   - takes only taskId and throws "任务 ID 需要提供" ("task ID is required") when it is null;
    //   - loads the TaskItem rows whose taskId matches via taskItemService and throws
    //     "未查询到任务" ("task not found") when there are none;
    //   - returns buildTaskItemDetail(taskItems).
    // Points to confirm on the post-commit version:
    //   - the Javadoc above still promises lookup by task ID *or task code* and richer task-level
    //     fields; the task-code lookup and the Task fields (status, locations, barcode, times, ...)
    //     appear to have been removed, so that Javadoc is stale;
    //   - the Task row itself is no longer read, so a task that exists but has no items is reported
    //     as "未查询到任务" (task not found);
    //   - the `new ArrayList<>()` assigned to taskItems is immediately overwritten by list().
    @Tool(name = "rsf_query_task_detail", description = "只读查询工具。根据任务 ID 或任务号查询任务详情。")
    @Tool(name = "rsf_query_task_detail", description = "只读查询工具。查询任务列表有正常返回值可以根据任务ID查询任务详情。")
    public Map<String, Object> queryTaskDetail(
            @ToolParam(description = "任务 ID") Long taskId,
            @ToolParam(description = "任务号") String taskCode) {
        String normalizedTaskCode = BuiltinToolGovernanceSupport.sanitizeQueryText(taskCode, "任务号", 64);
        if (taskId == null && !StringUtils.hasText(normalizedTaskCode)) {
            throw new CoolException("任务 ID 和任务号至少需要提供一个");
            @ToolParam(description = "任务 ID") Long taskId
            ) {
        if (taskId == null) {
            throw new CoolException("任务 ID 需要提供");
        }
        Task task;
        if (taskId != null) {
            task = taskService.getById(taskId);
        } else {
            task = taskService.getOne(new LambdaQueryWrapper<Task>().eq(Task::getTaskCode, normalizedTaskCode));
        }
        if (task == null) {
        List<TaskItem> taskItems = new ArrayList<>();
        taskItems = taskItemService.list(new LambdaQueryWrapper<TaskItem>().eq(TaskItem::getTaskId, taskId));
        if (taskItems.isEmpty()) {
            throw new CoolException("未查询到任务");
        }
        Map<String, Object> result = buildTaskSummary(task);
        result.put("resource", task.getResource());
        result.put("exceStatus", task.getExceStatus());
        result.put("orgLoc", task.getOrgLoc());
        result.put("targLoc", task.getTargLoc());
        result.put("orgSite", task.getOrgSite());
        result.put("orgSiteLabel", task.getOrgSite$());
        result.put("targSite", task.getTargSite());
        result.put("targSiteLabel", task.getTargSite$());
        result.put("barcode", task.getBarcode());
        result.put("robotCode", task.getRobotCode());
        result.put("memo", task.getMemo());
        result.put("expCode", task.getExpCode());
        result.put("expDesc", task.getExpDesc());
        result.put("startTime", task.getStartTime$());
        result.put("endTime", task.getEndTime$());
        result.put("createTime", task.getCreateTime$());
        result.put("updateTime", task.getUpdateTime$());
        return result;
        return buildTaskItemDetail(taskItems);
    }
    private Map<String, Object> buildTaskSummary(Task task) {
@@ -128,4 +112,72 @@
        return item;
    }
    private Map<String, Object> buildTaskItemDetail(List<TaskItem> taskItems) {
        Map<String, Object> result = new LinkedHashMap<>();
        result.put("taskId", taskItems.get(0).getTaskId());
        result.put("itemCount", taskItems.size());
        double totalAnfme = 0D;
        double totalWorkQty = 0D;
        double totalQty = 0D;
        List<Map<String, Object>> items = new ArrayList<>();
        for (TaskItem taskItem : taskItems) {
            totalAnfme += taskItem.getAnfme() == null ? 0D : taskItem.getAnfme();
            totalWorkQty += taskItem.getWorkQty() == null ? 0D : taskItem.getWorkQty();
            totalQty += taskItem.getQty() == null ? 0D : taskItem.getQty();
            items.add(buildTaskItemRow(taskItem));
        }
        result.put("totalAnfme", totalAnfme);
        result.put("totalWorkQty", totalWorkQty);
        result.put("totalQty", totalQty);
        result.put("items", items);
        return result;
    }
    private Map<String, Object> buildTaskItemRow(TaskItem taskItem) {
        if (!Objects.isNull(taskItem.getFieldsIndex())) {
            taskItem.setExtendFields(FieldsUtils.getFields(taskItem.getFieldsIndex()));
        }
        Map<String, Object> item = new LinkedHashMap<>();
        item.put("id", taskItem.getId());
        item.put("taskId", taskItem.getTaskId());
        item.put("matnrId", taskItem.getMatnrId());
        item.put("matnrCode", taskItem.getMatnrCode());
        item.put("maktx", taskItem.getMaktx());
        item.put("trackCode", taskItem.getTrackCode());
        item.put("splrBatch", taskItem.getSplrBatch());
        item.put("batch", taskItem.getBatch());
        item.put("spec", taskItem.getSpec());
        item.put("model", taskItem.getModel());
        item.put("unit", taskItem.getUnit());
        item.put("anfme", taskItem.getAnfme());
        item.put("workQty", taskItem.getWorkQty());
        item.put("qty", taskItem.getQty());
        item.put("ableQty", taskItem.getAbleQty());
        item.put("source", taskItem.getSource());
        item.put("sourceId", taskItem.getSourceId());
        item.put("sourceCode", taskItem.getSourceCode());
        item.put("orderId", taskItem.getOrderId());
        item.put("orderItemId", taskItem.getOrderItemId());
        item.put("platItemId", taskItem.getPlatItemId());
        item.put("platOrderCode", taskItem.getPlatOrderCode());
        item.put("platWorkCode", taskItem.getPlatWorkCode());
        item.put("projectCode", taskItem.getProjectCode());
        item.put("orderType", taskItem.getOrderType());
        item.put("orderTypeLabel", taskItem.getOrderType$());
        item.put("wkType", taskItem.getWkType());
        item.put("wkTypeLabel", taskItem.getWkType$());
        item.put("isptResult", taskItem.getIsptResult());
        item.put("isptResultLabel", taskItem.getIsptResult$());
        item.put("fieldsIndex", taskItem.getFieldsIndex());
        item.put("extendFields", taskItem.getExtendFields());
        item.put("status", taskItem.getStatus());
        item.put("statusLabel", taskItem.getStatus$());
        item.put("memo", taskItem.getMemo());
        item.put("createTime", taskItem.getCreateTime$());
        item.put("updateTime", taskItem.getUpdateTime$());
        return item;
    }
}