package com.vincent.rsf.server.ai.service.impl;
|
|
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.dto.AiParamValidateResultDto;
import com.vincent.rsf.server.ai.entity.AiParam;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestClient;
import org.springframework.web.client.RestClientResponseException;

import java.text.SimpleDateFormat;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
|
|
@Component
|
@Slf4j
|
@RequiredArgsConstructor
|
public class AiParamValidationSupport {
|
|
private static final String PROBE_MESSAGE = "请回复 OK";
|
private static final int MAX_DETAIL_TEXT_LENGTH = 8000;
|
|
private final ObjectMapper objectMapper;
|
|
/**
|
* 对一份 AI 参数草稿做真实连通性校验。
|
* 校验方式不是简单判断字段非空,而是直接构造聊天模型并发起一次最小探测调用,
|
* 用返回结果和耗时生成前端可展示的校验报告。
|
*/
|
public AiParamValidateResultDto validate(AiParam aiParam) {
|
long startedAt = System.currentTimeMillis();
|
ValidationRequestContext context = buildValidationRequestContext(aiParam);
|
try {
|
ValidationResponseSnapshot responseSnapshot = executeValidation(context);
|
String assistantText = extractAssistantText(responseSnapshot.responseBody());
|
if (!StringUtils.hasText(assistantText)) {
|
throw new CoolException("模型已连接,但未返回有效响应");
|
}
|
long elapsedMs = System.currentTimeMillis() - startedAt;
|
String detail = buildValidationDetail(
|
context.requestUrl(),
|
context.requestBody(),
|
responseSnapshot.statusCode(),
|
responseSnapshot.responseBody(),
|
assistantText,
|
null
|
);
|
log.info("AI 参数草稿校验成功, model={}, requestUrl={}, responseStatus={}",
|
aiParam.getModel(), context.requestUrl(), responseSnapshot.statusCode());
|
return AiParamValidateResultDto.builder()
|
.status(AiDefaults.PARAM_VALIDATE_VALID)
|
.message("模型连通成功")
|
.detail(detail)
|
.model(aiParam.getModel())
|
.elapsedMs(elapsedMs)
|
.validatedAt(formatDate(new Date()))
|
.requestUrl(context.requestUrl())
|
.requestBody(context.requestBody())
|
.responseStatus(responseSnapshot.statusCode())
|
.responseBody(abbreviate(responseSnapshot.responseBody()))
|
.build();
|
} catch (Exception e) {
|
long elapsedMs = System.currentTimeMillis() - startedAt;
|
ValidationErrorSnapshot errorSnapshot = extractErrorSnapshot(e);
|
String message = e instanceof CoolException
|
? e.getMessage()
|
: "模型验证失败: " + firstNonBlank(
|
extractResponseErrorMessage(errorSnapshot.responseBody()),
|
errorSnapshot.errorMessage(),
|
e.getMessage());
|
String detail = buildValidationDetail(
|
context.requestUrl(),
|
context.requestBody(),
|
errorSnapshot.statusCode(),
|
errorSnapshot.responseBody(),
|
null,
|
e
|
);
|
log.warn("AI 参数草稿校验失败, model={}, requestUrl={}, responseStatus={}, error={}",
|
aiParam.getModel(), context.requestUrl(), errorSnapshot.statusCode(), errorSnapshot.errorMessage(), e);
|
return AiParamValidateResultDto.builder()
|
.status(AiDefaults.PARAM_VALIDATE_INVALID)
|
.message(message)
|
.detail(detail)
|
.model(aiParam.getModel())
|
.elapsedMs(elapsedMs)
|
.validatedAt(formatDate(new Date()))
|
.requestUrl(context.requestUrl())
|
.requestBody(context.requestBody())
|
.responseStatus(errorSnapshot.statusCode())
|
.responseBody(abbreviate(errorSnapshot.responseBody()))
|
.build();
|
}
|
}
|
|
private ValidationRequestContext buildValidationRequestContext(AiParam aiParam) {
|
AiOpenAiApiSupport.EndpointConfig endpointConfig = AiOpenAiApiSupport.resolveEndpointConfig(aiParam.getBaseUrl());
|
String requestUrl = endpointConfig.baseUrl() + endpointConfig.completionsPath();
|
Map<String, Object> requestBody = new LinkedHashMap<>();
|
requestBody.put("model", aiParam.getModel());
|
requestBody.put("messages", List.of(Map.of("role", "user", "content", PROBE_MESSAGE)));
|
requestBody.put("stream", false);
|
if (aiParam.getTemperature() != null) {
|
requestBody.put("temperature", aiParam.getTemperature());
|
}
|
if (aiParam.getTopP() != null) {
|
requestBody.put("top_p", aiParam.getTopP());
|
}
|
if (aiParam.getMaxTokens() != null) {
|
requestBody.put("max_tokens", aiParam.getMaxTokens());
|
}
|
return new ValidationRequestContext(
|
buildRestClient(aiParam),
|
requestUrl,
|
requestBody,
|
toPrettyJson(requestBody)
|
);
|
}
|
|
private RestClient buildRestClient(AiParam aiParam) {
|
int timeoutMs = aiParam.getTimeoutMs() == null ? AiDefaults.DEFAULT_TIMEOUT_MS : aiParam.getTimeoutMs();
|
SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
|
requestFactory.setConnectTimeout(timeoutMs);
|
requestFactory.setReadTimeout(timeoutMs);
|
return RestClient.builder()
|
.requestFactory(requestFactory)
|
.defaultHeader("Authorization", "Bearer " + aiParam.getApiKey())
|
.build();
|
}
|
|
private ValidationResponseSnapshot executeValidation(ValidationRequestContext context) {
|
ResponseEntity<String> responseEntity = context.restClient().post()
|
.uri(context.requestUrl())
|
.contentType(MediaType.APPLICATION_JSON)
|
.accept(MediaType.APPLICATION_JSON)
|
.body(context.requestBodyMap())
|
.retrieve()
|
.toEntity(String.class);
|
return new ValidationResponseSnapshot(
|
responseEntity.getStatusCode().value(),
|
responseEntity.getBody()
|
);
|
}
|
|
private String extractAssistantText(String responseBody) {
|
if (!StringUtils.hasText(responseBody)) {
|
return "";
|
}
|
try {
|
JsonNode root = objectMapper.readTree(responseBody);
|
List<JsonNode> candidates = List.of(
|
root.at("/choices/0/message/content"),
|
root.at("/choices/0/text"),
|
root.at("/output_text"),
|
root.at("/output/0/content/0/text")
|
);
|
for (JsonNode candidate : candidates) {
|
String text = flattenText(candidate);
|
if (StringUtils.hasText(text)) {
|
return text;
|
}
|
}
|
} catch (Exception e) {
|
log.debug("解析 AI 参数草稿校验响应失败: {}", e.getMessage());
|
}
|
return "";
|
}
|
|
private String flattenText(JsonNode node) {
|
if (node == null || node.isMissingNode() || node.isNull()) {
|
return "";
|
}
|
if (node.isTextual()) {
|
return node.asText();
|
}
|
if (node.isArray()) {
|
List<String> parts = new ArrayList<>();
|
for (JsonNode item : node) {
|
String text = flattenText(item);
|
if (StringUtils.hasText(text)) {
|
parts.add(text);
|
}
|
}
|
return String.join("\n", parts).trim();
|
}
|
if (node.isObject()) {
|
String text = flattenText(node.get("text"));
|
if (StringUtils.hasText(text)) {
|
return text;
|
}
|
text = flattenText(node.get("content"));
|
if (StringUtils.hasText(text)) {
|
return text;
|
}
|
}
|
return "";
|
}
|
|
private ValidationErrorSnapshot extractErrorSnapshot(Throwable throwable) {
|
RestClientResponseException responseException = findResponseException(throwable);
|
if (responseException != null) {
|
return new ValidationErrorSnapshot(
|
responseException.getStatusCode().value(),
|
responseException.getResponseBodyAsString(),
|
rootMessage(throwable)
|
);
|
}
|
return new ValidationErrorSnapshot(null, null, rootMessage(throwable));
|
}
|
|
private RestClientResponseException findResponseException(Throwable throwable) {
|
Throwable current = throwable;
|
while (current != null) {
|
if (current instanceof RestClientResponseException responseException) {
|
return responseException;
|
}
|
current = current.getCause();
|
}
|
return null;
|
}
|
|
private String extractResponseErrorMessage(String responseBody) {
|
if (!StringUtils.hasText(responseBody)) {
|
return "";
|
}
|
try {
|
JsonNode root = objectMapper.readTree(responseBody);
|
List<JsonNode> candidates = List.of(
|
root.at("/error/message"),
|
root.at("/message"),
|
root.at("/detail")
|
);
|
for (JsonNode candidate : candidates) {
|
String text = flattenText(candidate);
|
if (StringUtils.hasText(text)) {
|
return text;
|
}
|
}
|
} catch (Exception e) {
|
log.debug("解析 AI 参数草稿校验错误响应失败: {}", e.getMessage());
|
}
|
return "";
|
}
|
|
private String buildValidationDetail(String requestUrl,
|
String requestBody,
|
Integer responseStatus,
|
String responseBody,
|
String assistantText,
|
Throwable throwable) {
|
List<String> sections = new ArrayList<>();
|
sections.add("请求 URL:\n" + defaultText(requestUrl));
|
sections.add("请求体:\n" + defaultText(requestBody));
|
sections.add("HTTP 状态:\n" + (responseStatus == null ? "--" : responseStatus));
|
if (StringUtils.hasText(assistantText)) {
|
sections.add("解析结果:\n" + abbreviate(assistantText));
|
}
|
sections.add("响应结果:\n" + defaultText(abbreviate(responseBody)));
|
if (throwable != null) {
|
sections.add("异常类型:\n" + throwable.getClass().getName());
|
sections.add("异常信息:\n" + defaultText(rootMessage(throwable)));
|
}
|
return String.join("\n\n", sections);
|
}
|
|
private String toPrettyJson(Object value) {
|
try {
|
return objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(value);
|
} catch (Exception e) {
|
return String.valueOf(value);
|
}
|
}
|
|
private String rootMessage(Throwable throwable) {
|
Throwable current = throwable;
|
while (current != null && current.getCause() != null) {
|
current = current.getCause();
|
}
|
return current == null ? "" : current.getMessage();
|
}
|
|
private String defaultText(String value) {
|
return StringUtils.hasText(value) ? value : "--";
|
}
|
|
private String firstNonBlank(String... values) {
|
for (String value : values) {
|
if (StringUtils.hasText(value)) {
|
return value;
|
}
|
}
|
return "";
|
}
|
|
private String abbreviate(String value) {
|
if (!StringUtils.hasText(value)) {
|
return value;
|
}
|
if (value.length() <= MAX_DETAIL_TEXT_LENGTH) {
|
return value;
|
}
|
return value.substring(0, MAX_DETAIL_TEXT_LENGTH) + "\n...(已截断)";
|
}
|
|
private String formatDate(Date date) {
|
/** 统一输出给前端的校验时间格式。 */
|
return new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(date);
|
}
|
|
private record ValidationRequestContext(RestClient restClient,
|
String requestUrl,
|
Map<String, Object> requestBodyMap,
|
String requestBody) {
|
}
|
|
private record ValidationResponseSnapshot(int statusCode, String responseBody) {
|
}
|
|
private record ValidationErrorSnapshot(Integer statusCode, String responseBody, String errorMessage) {
|
}
|
}
|