package com.vincent.rsf.server.ai.service.impl; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.vincent.rsf.framework.exception.CoolException; import com.vincent.rsf.server.ai.config.AiDefaults; import com.vincent.rsf.server.ai.dto.AiParamValidateResultDto; import com.vincent.rsf.server.ai.entity.AiParam; import lombok.extern.slf4j.Slf4j; import lombok.RequiredArgsConstructor; import org.springframework.stereotype.Component; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.util.StringUtils; import org.springframework.web.client.RestClient; import org.springframework.web.client.RestClientResponseException; import java.text.SimpleDateFormat; import java.util.ArrayList; import java.util.Date; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @Component @Slf4j @RequiredArgsConstructor public class AiParamValidationSupport { private static final String PROBE_MESSAGE = "请回复 OK"; private static final int MAX_DETAIL_TEXT_LENGTH = 8000; private final ObjectMapper objectMapper; /** * 对一份 AI 参数草稿做真实连通性校验。 * 校验方式不是简单判断字段非空,而是直接构造聊天模型并发起一次最小探测调用, * 用返回结果和耗时生成前端可展示的校验报告。 */ public AiParamValidateResultDto validate(AiParam aiParam) { long startedAt = System.currentTimeMillis(); ValidationRequestContext context = buildValidationRequestContext(aiParam); try { ValidationResponseSnapshot responseSnapshot = executeValidation(context); String assistantText = extractAssistantText(responseSnapshot.responseBody()); if (!StringUtils.hasText(assistantText)) { throw new CoolException("模型已连接,但未返回有效响应"); } long elapsedMs = System.currentTimeMillis() - startedAt; String detail = buildValidationDetail( context.requestUrl(), context.requestBody(), responseSnapshot.statusCode(), responseSnapshot.responseBody(), assistantText, null ); log.info("AI 参数草稿校验成功, model={}, requestUrl={}, responseStatus={}", aiParam.getModel(), context.requestUrl(), responseSnapshot.statusCode()); return AiParamValidateResultDto.builder() .status(AiDefaults.PARAM_VALIDATE_VALID) .message("模型连通成功") .detail(detail) .model(aiParam.getModel()) .elapsedMs(elapsedMs) .validatedAt(formatDate(new Date())) .requestUrl(context.requestUrl()) .requestBody(context.requestBody()) .responseStatus(responseSnapshot.statusCode()) .responseBody(abbreviate(responseSnapshot.responseBody())) .build(); } catch (Exception e) { long elapsedMs = System.currentTimeMillis() - startedAt; ValidationErrorSnapshot errorSnapshot = extractErrorSnapshot(e); String message = e instanceof CoolException ? e.getMessage() : "模型验证失败: " + firstNonBlank( extractResponseErrorMessage(errorSnapshot.responseBody()), errorSnapshot.errorMessage(), e.getMessage()); String detail = buildValidationDetail( context.requestUrl(), context.requestBody(), errorSnapshot.statusCode(), errorSnapshot.responseBody(), null, e ); log.warn("AI 参数草稿校验失败, model={}, requestUrl={}, responseStatus={}, error={}", aiParam.getModel(), context.requestUrl(), errorSnapshot.statusCode(), errorSnapshot.errorMessage(), e); return AiParamValidateResultDto.builder() .status(AiDefaults.PARAM_VALIDATE_INVALID) .message(message) .detail(detail) .model(aiParam.getModel()) .elapsedMs(elapsedMs) .validatedAt(formatDate(new Date())) .requestUrl(context.requestUrl()) .requestBody(context.requestBody()) .responseStatus(errorSnapshot.statusCode()) .responseBody(abbreviate(errorSnapshot.responseBody())) .build(); } } private ValidationRequestContext buildValidationRequestContext(AiParam aiParam) { AiOpenAiApiSupport.EndpointConfig endpointConfig = AiOpenAiApiSupport.resolveEndpointConfig(aiParam.getBaseUrl()); String requestUrl = endpointConfig.baseUrl() + endpointConfig.completionsPath(); Map requestBody = new LinkedHashMap<>(); requestBody.put("model", aiParam.getModel()); requestBody.put("messages", List.of(Map.of("role", "user", "content", PROBE_MESSAGE))); requestBody.put("stream", false); if (aiParam.getTemperature() != null) { requestBody.put("temperature", aiParam.getTemperature()); } if (aiParam.getTopP() != null) { requestBody.put("top_p", aiParam.getTopP()); } if (aiParam.getMaxTokens() != null) { requestBody.put("max_tokens", aiParam.getMaxTokens()); } return new ValidationRequestContext( buildRestClient(aiParam), requestUrl, requestBody, toPrettyJson(requestBody) ); } private RestClient buildRestClient(AiParam aiParam) { int timeoutMs = aiParam.getTimeoutMs() == null ? AiDefaults.DEFAULT_TIMEOUT_MS : aiParam.getTimeoutMs(); SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); requestFactory.setConnectTimeout(timeoutMs); requestFactory.setReadTimeout(timeoutMs); return RestClient.builder() .requestFactory(requestFactory) .defaultHeader("Authorization", "Bearer " + aiParam.getApiKey()) .build(); } private ValidationResponseSnapshot executeValidation(ValidationRequestContext context) { ResponseEntity responseEntity = context.restClient().post() .uri(context.requestUrl()) .contentType(MediaType.APPLICATION_JSON) .accept(MediaType.APPLICATION_JSON) .body(context.requestBodyMap()) .retrieve() .toEntity(String.class); return new ValidationResponseSnapshot( responseEntity.getStatusCode().value(), responseEntity.getBody() ); } private String extractAssistantText(String responseBody) { if (!StringUtils.hasText(responseBody)) { return ""; } try { JsonNode root = objectMapper.readTree(responseBody); List candidates = List.of( root.at("/choices/0/message/content"), root.at("/choices/0/text"), root.at("/output_text"), root.at("/output/0/content/0/text") ); for (JsonNode candidate : candidates) { String text = flattenText(candidate); if (StringUtils.hasText(text)) { return text; } } } catch (Exception e) { log.debug("解析 AI 参数草稿校验响应失败: {}", e.getMessage()); } return ""; } private String flattenText(JsonNode node) { if (node == null || node.isMissingNode() || node.isNull()) { return ""; } if (node.isTextual()) { return node.asText(); } if (node.isArray()) { List parts = new ArrayList<>(); for (JsonNode item : node) { String text = flattenText(item); if (StringUtils.hasText(text)) { parts.add(text); } } return String.join("\n", parts).trim(); } if (node.isObject()) { String text = flattenText(node.get("text")); if (StringUtils.hasText(text)) { return text; } text = flattenText(node.get("content")); if (StringUtils.hasText(text)) { return text; } } return ""; } private ValidationErrorSnapshot extractErrorSnapshot(Throwable throwable) { RestClientResponseException responseException = findResponseException(throwable); if (responseException != null) { return new ValidationErrorSnapshot( responseException.getStatusCode().value(), responseException.getResponseBodyAsString(), rootMessage(throwable) ); } return new ValidationErrorSnapshot(null, null, rootMessage(throwable)); } private RestClientResponseException findResponseException(Throwable throwable) { Throwable current = throwable; while (current != null) { if (current instanceof RestClientResponseException responseException) { return responseException; } current = current.getCause(); } return null; } private String extractResponseErrorMessage(String responseBody) { if (!StringUtils.hasText(responseBody)) { return ""; } try { JsonNode root = objectMapper.readTree(responseBody); List candidates = List.of( root.at("/error/message"), root.at("/message"), root.at("/detail") ); for (JsonNode candidate : candidates) { String text = flattenText(candidate); if (StringUtils.hasText(text)) { return text; } } } catch (Exception e) { log.debug("解析 AI 参数草稿校验错误响应失败: {}", e.getMessage()); } return ""; } private String buildValidationDetail(String requestUrl, String requestBody, Integer responseStatus, String responseBody, String assistantText, Throwable throwable) { List sections = new ArrayList<>(); sections.add("请求 URL:\n" + defaultText(requestUrl)); sections.add("请求体:\n" + defaultText(requestBody)); sections.add("HTTP 状态:\n" + (responseStatus == null ? "--" : responseStatus)); if (StringUtils.hasText(assistantText)) { sections.add("解析结果:\n" + abbreviate(assistantText)); } sections.add("响应结果:\n" + defaultText(abbreviate(responseBody))); if (throwable != null) { sections.add("异常类型:\n" + throwable.getClass().getName()); sections.add("异常信息:\n" + defaultText(rootMessage(throwable))); } return String.join("\n\n", sections); } private String toPrettyJson(Object value) { try { return objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(value); } catch (Exception e) { return String.valueOf(value); } } private String rootMessage(Throwable throwable) { Throwable current = throwable; while (current != null && current.getCause() != null) { current = current.getCause(); } return current == null ? "" : current.getMessage(); } private String defaultText(String value) { return StringUtils.hasText(value) ? value : "--"; } private String firstNonBlank(String... values) { for (String value : values) { if (StringUtils.hasText(value)) { return value; } } return ""; } private String abbreviate(String value) { if (!StringUtils.hasText(value)) { return value; } if (value.length() <= MAX_DETAIL_TEXT_LENGTH) { return value; } return value.substring(0, MAX_DETAIL_TEXT_LENGTH) + "\n...(已截断)"; } private String formatDate(Date date) { /** 统一输出给前端的校验时间格式。 */ return new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(date); } private record ValidationRequestContext(RestClient restClient, String requestUrl, Map requestBodyMap, String requestBody) { } private record ValidationResponseSnapshot(int statusCode, String responseBody) { } private record ValidationErrorSnapshot(Integer statusCode, String responseBody, String errorMessage) { } }