| | |
| | | package com.vincent.rsf.server.ai.service.impl; |
| | | |
| | | import com.fasterxml.jackson.databind.JsonNode; |
| | | import com.fasterxml.jackson.databind.ObjectMapper; |
| | | import com.vincent.rsf.framework.exception.CoolException; |
| | | import com.vincent.rsf.server.ai.config.AiDefaults; |
| | | import com.vincent.rsf.server.ai.dto.AiParamValidateResultDto; |
| | | import com.vincent.rsf.server.ai.entity.AiParam; |
| | | import io.micrometer.observation.ObservationRegistry; |
| | | import lombok.extern.slf4j.Slf4j; |
| | | import lombok.RequiredArgsConstructor; |
| | | import org.springframework.ai.chat.messages.UserMessage; |
| | | import org.springframework.ai.chat.model.ChatResponse; |
| | | import org.springframework.ai.chat.prompt.Prompt; |
| | | import org.springframework.ai.model.tool.DefaultToolCallingManager; |
| | | import org.springframework.ai.model.tool.ToolCallingManager; |
| | | import org.springframework.ai.openai.OpenAiChatModel; |
| | | import org.springframework.ai.openai.OpenAiChatOptions; |
| | | import org.springframework.ai.openai.api.OpenAiApi; |
| | | import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; |
| | | import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver; |
| | | import org.springframework.ai.util.json.schema.SchemaType; |
| | | import org.springframework.context.support.GenericApplicationContext; |
| | | import org.springframework.http.client.SimpleClientHttpRequestFactory; |
| | | import org.springframework.stereotype.Component; |
| | | import org.springframework.http.MediaType; |
| | | import org.springframework.http.ResponseEntity; |
| | | import org.springframework.http.client.SimpleClientHttpRequestFactory; |
| | | import org.springframework.util.StringUtils; |
| | | import org.springframework.web.client.RestClient; |
| | | import org.springframework.web.reactive.function.client.WebClient; |
| | | import org.springframework.web.client.RestClientResponseException; |
| | | |
| | | import java.text.SimpleDateFormat; |
| | | import java.util.ArrayList; |
| | | import java.util.Date; |
| | | import java.util.LinkedHashMap; |
| | | import java.util.List; |
| | | import java.util.Map; |
| | | |
| | | @Component |
| | | @Slf4j |
| | | @RequiredArgsConstructor |
| | | public class AiParamValidationSupport { |
| | | |
| | | private final GenericApplicationContext applicationContext; |
| | | private final ObservationRegistry observationRegistry; |
| | | private static final String PROBE_MESSAGE = "请回复 OK"; |
| | | private static final int MAX_DETAIL_TEXT_LENGTH = 8000; |
| | | |
| | | private final ObjectMapper objectMapper; |
| | | |
| | | /** |
| | | * 对一份 AI 参数草稿做真实连通性校验。 |
| | |
| | | */ |
| | | public AiParamValidateResultDto validate(AiParam aiParam) { |
| | | long startedAt = System.currentTimeMillis(); |
| | | ValidationRequestContext context = buildValidationRequestContext(aiParam); |
| | | try { |
| | | OpenAiChatModel chatModel = createChatModel(aiParam); |
| | | ChatResponse response = chatModel.call(new Prompt(List.of(new UserMessage("请回复 OK")))); |
| | | if (response == null || response.getResult() == null || response.getResult().getOutput() == null |
| | | || !StringUtils.hasText(response.getResult().getOutput().getText())) { |
| | | ValidationResponseSnapshot responseSnapshot = executeValidation(context); |
| | | String assistantText = extractAssistantText(responseSnapshot.responseBody()); |
| | | if (!StringUtils.hasText(assistantText)) { |
| | | throw new CoolException("模型已连接,但未返回有效响应"); |
| | | } |
| | | long elapsedMs = System.currentTimeMillis() - startedAt; |
| | | String detail = buildValidationDetail( |
| | | context.requestUrl(), |
| | | context.requestBody(), |
| | | responseSnapshot.statusCode(), |
| | | responseSnapshot.responseBody(), |
| | | assistantText, |
| | | null |
| | | ); |
| | | log.info("AI 参数草稿校验成功, model={}, requestUrl={}, responseStatus={}", |
| | | aiParam.getModel(), context.requestUrl(), responseSnapshot.statusCode()); |
| | | return AiParamValidateResultDto.builder() |
| | | .status(AiDefaults.PARAM_VALIDATE_VALID) |
| | | .message("模型连通成功") |
| | | .detail(detail) |
| | | .model(aiParam.getModel()) |
| | | .elapsedMs(elapsedMs) |
| | | .validatedAt(formatDate(new Date())) |
| | | .requestUrl(context.requestUrl()) |
| | | .requestBody(context.requestBody()) |
| | | .responseStatus(responseSnapshot.statusCode()) |
| | | .responseBody(abbreviate(responseSnapshot.responseBody())) |
| | | .build(); |
| | | } catch (Exception e) { |
| | | long elapsedMs = System.currentTimeMillis() - startedAt; |
| | | String message = e instanceof CoolException ? e.getMessage() : "模型验证失败: " + e.getMessage(); |
| | | ValidationErrorSnapshot errorSnapshot = extractErrorSnapshot(e); |
| | | String message = e instanceof CoolException |
| | | ? e.getMessage() |
| | | : "模型验证失败: " + firstNonBlank( |
| | | extractResponseErrorMessage(errorSnapshot.responseBody()), |
| | | errorSnapshot.errorMessage(), |
| | | e.getMessage()); |
| | | String detail = buildValidationDetail( |
| | | context.requestUrl(), |
| | | context.requestBody(), |
| | | errorSnapshot.statusCode(), |
| | | errorSnapshot.responseBody(), |
| | | null, |
| | | e |
| | | ); |
| | | log.warn("AI 参数草稿校验失败, model={}, requestUrl={}, responseStatus={}, error={}", |
| | | aiParam.getModel(), context.requestUrl(), errorSnapshot.statusCode(), errorSnapshot.errorMessage(), e); |
| | | return AiParamValidateResultDto.builder() |
| | | .status(AiDefaults.PARAM_VALIDATE_INVALID) |
| | | .message(message) |
| | | .detail(detail) |
| | | .model(aiParam.getModel()) |
| | | .elapsedMs(elapsedMs) |
| | | .validatedAt(formatDate(new Date())) |
| | | .requestUrl(context.requestUrl()) |
| | | .requestBody(context.requestBody()) |
| | | .responseStatus(errorSnapshot.statusCode()) |
| | | .responseBody(abbreviate(errorSnapshot.responseBody())) |
| | | .build(); |
| | | } |
| | | } |
| | | |
| | | private OpenAiChatModel createChatModel(AiParam aiParam) { |
| | | /** |
| | | * 构造仅用于校验的轻量聊天模型。 |
| | | * 这里沿用正式链路的 Observation 和 ToolCalling 依赖, |
| | | * 保证校验结论与真实运行环境尽量一致。 |
| | | */ |
| | | OpenAiApi openAiApi = buildOpenAiApi(aiParam); |
| | | ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder() |
| | | .observationRegistry(observationRegistry) |
| | | .toolCallbackResolver(new SpringBeanToolCallbackResolver(applicationContext, SchemaType.OPEN_API_SCHEMA)) |
| | | .toolExecutionExceptionProcessor(new DefaultToolExecutionExceptionProcessor(false)) |
| | | .build(); |
| | | return new OpenAiChatModel( |
| | | openAiApi, |
| | | OpenAiChatOptions.builder() |
| | | .model(aiParam.getModel()) |
| | | .temperature(aiParam.getTemperature()) |
| | | .topP(aiParam.getTopP()) |
| | | .maxTokens(aiParam.getMaxTokens()) |
| | | .streamUsage(true) |
| | | .build(), |
| | | toolCallingManager, |
| | | org.springframework.retry.support.RetryTemplate.builder().maxAttempts(1).build(), |
| | | observationRegistry |
| | | private ValidationRequestContext buildValidationRequestContext(AiParam aiParam) { |
| | | AiOpenAiApiSupport.EndpointConfig endpointConfig = AiOpenAiApiSupport.resolveEndpointConfig(aiParam.getBaseUrl()); |
| | | String requestUrl = endpointConfig.baseUrl() + endpointConfig.completionsPath(); |
| | | Map<String, Object> requestBody = new LinkedHashMap<>(); |
| | | requestBody.put("model", aiParam.getModel()); |
| | | requestBody.put("messages", List.of(Map.of("role", "user", "content", PROBE_MESSAGE))); |
| | | requestBody.put("stream", false); |
| | | if (aiParam.getTemperature() != null) { |
| | | requestBody.put("temperature", aiParam.getTemperature()); |
| | | } |
| | | if (aiParam.getTopP() != null) { |
| | | requestBody.put("top_p", aiParam.getTopP()); |
| | | } |
| | | if (aiParam.getMaxTokens() != null) { |
| | | requestBody.put("max_tokens", aiParam.getMaxTokens()); |
| | | } |
| | | return new ValidationRequestContext( |
| | | buildRestClient(aiParam), |
| | | requestUrl, |
| | | requestBody, |
| | | toPrettyJson(requestBody) |
| | | ); |
| | | } |
| | | |
| | | private OpenAiApi buildOpenAiApi(AiParam aiParam) { |
| | | /** |
| | | * 根据表单里的 Base URL、API Key 和超时参数构造 OpenAI 兼容客户端。 |
| | | * 该方法被显式拆出来,是为了让“网络连接参数”和“模型选项”职责分离。 |
| | | */ |
| | | private RestClient buildRestClient(AiParam aiParam) { |
| | | int timeoutMs = aiParam.getTimeoutMs() == null ? AiDefaults.DEFAULT_TIMEOUT_MS : aiParam.getTimeoutMs(); |
| | | SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); |
| | | requestFactory.setConnectTimeout(timeoutMs); |
| | | requestFactory.setReadTimeout(timeoutMs); |
| | | return OpenAiApi.builder() |
| | | .baseUrl(aiParam.getBaseUrl()) |
| | | .apiKey(aiParam.getApiKey()) |
| | | .restClientBuilder(RestClient.builder().requestFactory(requestFactory)) |
| | | .webClientBuilder(WebClient.builder()) |
| | | return RestClient.builder() |
| | | .requestFactory(requestFactory) |
| | | .defaultHeader("Authorization", "Bearer " + aiParam.getApiKey()) |
| | | .build(); |
| | | } |
| | | |
| | | private ValidationResponseSnapshot executeValidation(ValidationRequestContext context) { |
| | | ResponseEntity<String> responseEntity = context.restClient().post() |
| | | .uri(context.requestUrl()) |
| | | .contentType(MediaType.APPLICATION_JSON) |
| | | .accept(MediaType.APPLICATION_JSON) |
| | | .body(context.requestBodyMap()) |
| | | .retrieve() |
| | | .toEntity(String.class); |
| | | return new ValidationResponseSnapshot( |
| | | responseEntity.getStatusCode().value(), |
| | | responseEntity.getBody() |
| | | ); |
| | | } |
| | | |
| | | private String extractAssistantText(String responseBody) { |
| | | if (!StringUtils.hasText(responseBody)) { |
| | | return ""; |
| | | } |
| | | try { |
| | | JsonNode root = objectMapper.readTree(responseBody); |
| | | List<JsonNode> candidates = List.of( |
| | | root.at("/choices/0/message/content"), |
| | | root.at("/choices/0/text"), |
| | | root.at("/output_text"), |
| | | root.at("/output/0/content/0/text") |
| | | ); |
| | | for (JsonNode candidate : candidates) { |
| | | String text = flattenText(candidate); |
| | | if (StringUtils.hasText(text)) { |
| | | return text; |
| | | } |
| | | } |
| | | } catch (Exception e) { |
| | | log.debug("解析 AI 参数草稿校验响应失败: {}", e.getMessage()); |
| | | } |
| | | return ""; |
| | | } |
| | | |
| | | private String flattenText(JsonNode node) { |
| | | if (node == null || node.isMissingNode() || node.isNull()) { |
| | | return ""; |
| | | } |
| | | if (node.isTextual()) { |
| | | return node.asText(); |
| | | } |
| | | if (node.isArray()) { |
| | | List<String> parts = new ArrayList<>(); |
| | | for (JsonNode item : node) { |
| | | String text = flattenText(item); |
| | | if (StringUtils.hasText(text)) { |
| | | parts.add(text); |
| | | } |
| | | } |
| | | return String.join("\n", parts).trim(); |
| | | } |
| | | if (node.isObject()) { |
| | | String text = flattenText(node.get("text")); |
| | | if (StringUtils.hasText(text)) { |
| | | return text; |
| | | } |
| | | text = flattenText(node.get("content")); |
| | | if (StringUtils.hasText(text)) { |
| | | return text; |
| | | } |
| | | } |
| | | return ""; |
| | | } |
| | | |
| | | private ValidationErrorSnapshot extractErrorSnapshot(Throwable throwable) { |
| | | RestClientResponseException responseException = findResponseException(throwable); |
| | | if (responseException != null) { |
| | | return new ValidationErrorSnapshot( |
| | | responseException.getStatusCode().value(), |
| | | responseException.getResponseBodyAsString(), |
| | | rootMessage(throwable) |
| | | ); |
| | | } |
| | | return new ValidationErrorSnapshot(null, null, rootMessage(throwable)); |
| | | } |
| | | |
| | | private RestClientResponseException findResponseException(Throwable throwable) { |
| | | Throwable current = throwable; |
| | | while (current != null) { |
| | | if (current instanceof RestClientResponseException responseException) { |
| | | return responseException; |
| | | } |
| | | current = current.getCause(); |
| | | } |
| | | return null; |
| | | } |
| | | |
| | | private String extractResponseErrorMessage(String responseBody) { |
| | | if (!StringUtils.hasText(responseBody)) { |
| | | return ""; |
| | | } |
| | | try { |
| | | JsonNode root = objectMapper.readTree(responseBody); |
| | | List<JsonNode> candidates = List.of( |
| | | root.at("/error/message"), |
| | | root.at("/message"), |
| | | root.at("/detail") |
| | | ); |
| | | for (JsonNode candidate : candidates) { |
| | | String text = flattenText(candidate); |
| | | if (StringUtils.hasText(text)) { |
| | | return text; |
| | | } |
| | | } |
| | | } catch (Exception e) { |
| | | log.debug("解析 AI 参数草稿校验错误响应失败: {}", e.getMessage()); |
| | | } |
| | | return ""; |
| | | } |
| | | |
| | | private String buildValidationDetail(String requestUrl, |
| | | String requestBody, |
| | | Integer responseStatus, |
| | | String responseBody, |
| | | String assistantText, |
| | | Throwable throwable) { |
| | | List<String> sections = new ArrayList<>(); |
| | | sections.add("请求 URL:\n" + defaultText(requestUrl)); |
| | | sections.add("请求体:\n" + defaultText(requestBody)); |
| | | sections.add("HTTP 状态:\n" + (responseStatus == null ? "--" : responseStatus)); |
| | | if (StringUtils.hasText(assistantText)) { |
| | | sections.add("解析结果:\n" + abbreviate(assistantText)); |
| | | } |
| | | sections.add("响应结果:\n" + defaultText(abbreviate(responseBody))); |
| | | if (throwable != null) { |
| | | sections.add("异常类型:\n" + throwable.getClass().getName()); |
| | | sections.add("异常信息:\n" + defaultText(rootMessage(throwable))); |
| | | } |
| | | return String.join("\n\n", sections); |
| | | } |
| | | |
| | | private String toPrettyJson(Object value) { |
| | | try { |
| | | return objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(value); |
| | | } catch (Exception e) { |
| | | return String.valueOf(value); |
| | | } |
| | | } |
| | | |
| | | private String rootMessage(Throwable throwable) { |
| | | Throwable current = throwable; |
| | | while (current != null && current.getCause() != null) { |
| | | current = current.getCause(); |
| | | } |
| | | return current == null ? "" : current.getMessage(); |
| | | } |
| | | |
| | | private String defaultText(String value) { |
| | | return StringUtils.hasText(value) ? value : "--"; |
| | | } |
| | | |
| | | private String firstNonBlank(String... values) { |
| | | for (String value : values) { |
| | | if (StringUtils.hasText(value)) { |
| | | return value; |
| | | } |
| | | } |
| | | return ""; |
| | | } |
| | | |
| | | private String abbreviate(String value) { |
| | | if (!StringUtils.hasText(value)) { |
| | | return value; |
| | | } |
| | | if (value.length() <= MAX_DETAIL_TEXT_LENGTH) { |
| | | return value; |
| | | } |
| | | return value.substring(0, MAX_DETAIL_TEXT_LENGTH) + "\n...(已截断)"; |
| | | } |
| | | |
| | | private String formatDate(Date date) { |
| | | /** 统一输出给前端的校验时间格式。 */ |
| | | return new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(date); |
| | | } |
| | | |
| | | private record ValidationRequestContext(RestClient restClient, |
| | | String requestUrl, |
| | | Map<String, Object> requestBodyMap, |
| | | String requestBody) { |
| | | } |
| | | |
| | | private record ValidationResponseSnapshot(int statusCode, String responseBody) { |
| | | } |
| | | |
| | | private record ValidationErrorSnapshot(Integer statusCode, String responseBody, String errorMessage) { |
| | | } |
| | | } |