zhou zhou
2 天以前 1dcfa3702505f0c431757312b5304531029f90f6
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiParamValidationSupport.java
@@ -1,36 +1,37 @@
package com.vincent.rsf.server.ai.service.impl;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.dto.AiParamValidateResultDto;
import com.vincent.rsf.server.ai.entity.AiParam;
import io.micrometer.observation.ObservationRegistry;
import lombok.extern.slf4j.Slf4j;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.tool.DefaultToolCallingManager;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver;
import org.springframework.ai.util.json.schema.SchemaType;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.stereotype.Component;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestClient;
import org.springframework.web.client.RestClientResponseException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@Component
@Slf4j
@RequiredArgsConstructor
public class AiParamValidationSupport {
    // NOTE(review): applicationContext and observationRegistry are only needed by the
    // OpenAiChatModel / DefaultToolCallingManager based probe. The RestClient based probe
    // (buildValidationRequestContext / executeValidation) never reads them. Because of
    // @RequiredArgsConstructor they are still constructor-injected; if nothing else in the
    // class references them, remove them together with their imports.
    private final GenericApplicationContext applicationContext;
    private final ObservationRegistry observationRegistry;
    /** User prompt sent as the single message of the connectivity probe. */
    private static final String PROBE_MESSAGE = "请回复 OK";
    /** Maximum number of characters of response/assistant text kept by abbreviate(). */
    private static final int MAX_DETAIL_TEXT_LENGTH = 8000;
    /** Used to pretty-print request bodies and to parse provider responses. */
    private final ObjectMapper objectMapper;
    /**
     * 对一份 AI 参数草稿做真实连通性校验。
@@ -39,68 +40,293 @@
     */
    public AiParamValidateResultDto validate(AiParam aiParam) {
        long startedAt = System.currentTimeMillis();
        ValidationRequestContext context = buildValidationRequestContext(aiParam);
        try {
            OpenAiChatModel chatModel = createChatModel(aiParam);
            ChatResponse response = chatModel.call(new Prompt(List.of(new UserMessage("请回复 OK"))));
            if (response == null || response.getResult() == null || response.getResult().getOutput() == null
                    || !StringUtils.hasText(response.getResult().getOutput().getText())) {
            ValidationResponseSnapshot responseSnapshot = executeValidation(context);
            String assistantText = extractAssistantText(responseSnapshot.responseBody());
            if (!StringUtils.hasText(assistantText)) {
                throw new CoolException("模型已连接,但未返回有效响应");
            }
            long elapsedMs = System.currentTimeMillis() - startedAt;
            String detail = buildValidationDetail(
                    context.requestUrl(),
                    context.requestBody(),
                    responseSnapshot.statusCode(),
                    responseSnapshot.responseBody(),
                    assistantText,
                    null
            );
            log.info("AI 参数草稿校验成功, model={}, requestUrl={}, responseStatus={}",
                    aiParam.getModel(), context.requestUrl(), responseSnapshot.statusCode());
            return AiParamValidateResultDto.builder()
                    .status(AiDefaults.PARAM_VALIDATE_VALID)
                    .message("模型连通成功")
                    .detail(detail)
                    .model(aiParam.getModel())
                    .elapsedMs(elapsedMs)
                    .validatedAt(formatDate(new Date()))
                    .requestUrl(context.requestUrl())
                    .requestBody(context.requestBody())
                    .responseStatus(responseSnapshot.statusCode())
                    .responseBody(abbreviate(responseSnapshot.responseBody()))
                    .build();
        } catch (Exception e) {
            long elapsedMs = System.currentTimeMillis() - startedAt;
            String message = e instanceof CoolException ? e.getMessage() : "模型验证失败: " + e.getMessage();
            ValidationErrorSnapshot errorSnapshot = extractErrorSnapshot(e);
            String message = e instanceof CoolException
                    ? e.getMessage()
                    : "模型验证失败: " + firstNonBlank(
                    extractResponseErrorMessage(errorSnapshot.responseBody()),
                    errorSnapshot.errorMessage(),
                    e.getMessage());
            String detail = buildValidationDetail(
                    context.requestUrl(),
                    context.requestBody(),
                    errorSnapshot.statusCode(),
                    errorSnapshot.responseBody(),
                    null,
                    e
            );
            log.warn("AI 参数草稿校验失败, model={}, requestUrl={}, responseStatus={}, error={}",
                    aiParam.getModel(), context.requestUrl(), errorSnapshot.statusCode(), errorSnapshot.errorMessage(), e);
            return AiParamValidateResultDto.builder()
                    .status(AiDefaults.PARAM_VALIDATE_INVALID)
                    .message(message)
                    .detail(detail)
                    .model(aiParam.getModel())
                    .elapsedMs(elapsedMs)
                    .validatedAt(formatDate(new Date()))
                    .requestUrl(context.requestUrl())
                    .requestBody(context.requestBody())
                    .responseStatus(errorSnapshot.statusCode())
                    .responseBody(abbreviate(errorSnapshot.responseBody()))
                    .build();
        }
    }
    private OpenAiChatModel createChatModel(AiParam aiParam) {
        /**
         * 构造仅用于校验的轻量聊天模型。
         * 这里沿用正式链路的 Observation 和 ToolCalling 依赖,
         * 保证校验结论与真实运行环境尽量一致。
         */
        OpenAiApi openAiApi = buildOpenAiApi(aiParam);
        ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder()
                .observationRegistry(observationRegistry)
                .toolCallbackResolver(new SpringBeanToolCallbackResolver(applicationContext, SchemaType.OPEN_API_SCHEMA))
                .toolExecutionExceptionProcessor(new DefaultToolExecutionExceptionProcessor(false))
                .build();
        return new OpenAiChatModel(
                openAiApi,
                OpenAiChatOptions.builder()
                        .model(aiParam.getModel())
                        .temperature(aiParam.getTemperature())
                        .topP(aiParam.getTopP())
                        .maxTokens(aiParam.getMaxTokens())
                        .streamUsage(true)
                        .build(),
                toolCallingManager,
                org.springframework.retry.support.RetryTemplate.builder().maxAttempts(1).build(),
                observationRegistry
    private ValidationRequestContext buildValidationRequestContext(AiParam aiParam) {
        AiOpenAiApiSupport.EndpointConfig endpointConfig = AiOpenAiApiSupport.resolveEndpointConfig(aiParam.getBaseUrl());
        String requestUrl = endpointConfig.baseUrl() + endpointConfig.completionsPath();
        Map<String, Object> requestBody = new LinkedHashMap<>();
        requestBody.put("model", aiParam.getModel());
        requestBody.put("messages", List.of(Map.of("role", "user", "content", PROBE_MESSAGE)));
        requestBody.put("stream", false);
        if (aiParam.getTemperature() != null) {
            requestBody.put("temperature", aiParam.getTemperature());
        }
        if (aiParam.getTopP() != null) {
            requestBody.put("top_p", aiParam.getTopP());
        }
        if (aiParam.getMaxTokens() != null) {
            requestBody.put("max_tokens", aiParam.getMaxTokens());
        }
        return new ValidationRequestContext(
                buildRestClient(aiParam),
                requestUrl,
                requestBody,
                toPrettyJson(requestBody)
        );
    }
    private OpenAiApi buildOpenAiApi(AiParam aiParam) {
        /** 统一兼容根地址、/v1 前缀和完整 completions endpoint 三种常见填法。 */
        return AiOpenAiApiSupport.buildOpenAiApi(aiParam);
    private RestClient buildRestClient(AiParam aiParam) {
        int timeoutMs = aiParam.getTimeoutMs() == null ? AiDefaults.DEFAULT_TIMEOUT_MS : aiParam.getTimeoutMs();
        SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
        requestFactory.setConnectTimeout(timeoutMs);
        requestFactory.setReadTimeout(timeoutMs);
        return RestClient.builder()
                .requestFactory(requestFactory)
                .defaultHeader("Authorization", "Bearer " + aiParam.getApiKey())
                .build();
    }
    private ValidationResponseSnapshot executeValidation(ValidationRequestContext context) {
        ResponseEntity<String> responseEntity = context.restClient().post()
                .uri(context.requestUrl())
                .contentType(MediaType.APPLICATION_JSON)
                .accept(MediaType.APPLICATION_JSON)
                .body(context.requestBodyMap())
                .retrieve()
                .toEntity(String.class);
        return new ValidationResponseSnapshot(
                responseEntity.getStatusCode().value(),
                responseEntity.getBody()
        );
    }
    private String extractAssistantText(String responseBody) {
        if (!StringUtils.hasText(responseBody)) {
            return "";
        }
        try {
            JsonNode root = objectMapper.readTree(responseBody);
            List<JsonNode> candidates = List.of(
                    root.at("/choices/0/message/content"),
                    root.at("/choices/0/text"),
                    root.at("/output_text"),
                    root.at("/output/0/content/0/text")
            );
            for (JsonNode candidate : candidates) {
                String text = flattenText(candidate);
                if (StringUtils.hasText(text)) {
                    return text;
                }
            }
        } catch (Exception e) {
            log.debug("解析 AI 参数草稿校验响应失败: {}", e.getMessage());
        }
        return "";
    }
    private String flattenText(JsonNode node) {
        if (node == null || node.isMissingNode() || node.isNull()) {
            return "";
        }
        if (node.isTextual()) {
            return node.asText();
        }
        if (node.isArray()) {
            List<String> parts = new ArrayList<>();
            for (JsonNode item : node) {
                String text = flattenText(item);
                if (StringUtils.hasText(text)) {
                    parts.add(text);
                }
            }
            return String.join("\n", parts).trim();
        }
        if (node.isObject()) {
            String text = flattenText(node.get("text"));
            if (StringUtils.hasText(text)) {
                return text;
            }
            text = flattenText(node.get("content"));
            if (StringUtils.hasText(text)) {
                return text;
            }
        }
        return "";
    }
    private ValidationErrorSnapshot extractErrorSnapshot(Throwable throwable) {
        RestClientResponseException responseException = findResponseException(throwable);
        if (responseException != null) {
            return new ValidationErrorSnapshot(
                    responseException.getStatusCode().value(),
                    responseException.getResponseBodyAsString(),
                    rootMessage(throwable)
            );
        }
        return new ValidationErrorSnapshot(null, null, rootMessage(throwable));
    }
    private RestClientResponseException findResponseException(Throwable throwable) {
        Throwable current = throwable;
        while (current != null) {
            if (current instanceof RestClientResponseException responseException) {
                return responseException;
            }
            current = current.getCause();
        }
        return null;
    }
    private String extractResponseErrorMessage(String responseBody) {
        if (!StringUtils.hasText(responseBody)) {
            return "";
        }
        try {
            JsonNode root = objectMapper.readTree(responseBody);
            List<JsonNode> candidates = List.of(
                    root.at("/error/message"),
                    root.at("/message"),
                    root.at("/detail")
            );
            for (JsonNode candidate : candidates) {
                String text = flattenText(candidate);
                if (StringUtils.hasText(text)) {
                    return text;
                }
            }
        } catch (Exception e) {
            log.debug("解析 AI 参数草稿校验错误响应失败: {}", e.getMessage());
        }
        return "";
    }
    private String buildValidationDetail(String requestUrl,
                                         String requestBody,
                                         Integer responseStatus,
                                         String responseBody,
                                         String assistantText,
                                         Throwable throwable) {
        List<String> sections = new ArrayList<>();
        sections.add("请求 URL:\n" + defaultText(requestUrl));
        sections.add("请求体:\n" + defaultText(requestBody));
        sections.add("HTTP 状态:\n" + (responseStatus == null ? "--" : responseStatus));
        if (StringUtils.hasText(assistantText)) {
            sections.add("解析结果:\n" + abbreviate(assistantText));
        }
        sections.add("响应结果:\n" + defaultText(abbreviate(responseBody)));
        if (throwable != null) {
            sections.add("异常类型:\n" + throwable.getClass().getName());
            sections.add("异常信息:\n" + defaultText(rootMessage(throwable)));
        }
        return String.join("\n\n", sections);
    }
    private String toPrettyJson(Object value) {
        try {
            return objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(value);
        } catch (Exception e) {
            return String.valueOf(value);
        }
    }
    private String rootMessage(Throwable throwable) {
        Throwable current = throwable;
        while (current != null && current.getCause() != null) {
            current = current.getCause();
        }
        return current == null ? "" : current.getMessage();
    }
    private String defaultText(String value) {
        return StringUtils.hasText(value) ? value : "--";
    }
    private String firstNonBlank(String... values) {
        for (String value : values) {
            if (StringUtils.hasText(value)) {
                return value;
            }
        }
        return "";
    }
    private String abbreviate(String value) {
        if (!StringUtils.hasText(value)) {
            return value;
        }
        if (value.length() <= MAX_DETAIL_TEXT_LENGTH) {
            return value;
        }
        return value.substring(0, MAX_DETAIL_TEXT_LENGTH) + "\n...(已截断)";
    }
    private String formatDate(Date date) {
        /** 统一输出给前端的校验时间格式。 */
        return new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(date);
    }
    private record ValidationRequestContext(RestClient restClient,
                                            String requestUrl,
                                            Map<String, Object> requestBodyMap,
                                            String requestBody) {
    }
    private record ValidationResponseSnapshot(int statusCode, String responseBody) {
    }
    private record ValidationErrorSnapshot(Integer statusCode, String responseBody, String errorMessage) {
    }
}