package com.vincent.rsf.server.ai.service.impl; import com.vincent.rsf.framework.exception.CoolException; import com.vincent.rsf.server.ai.config.AiDefaults; import com.vincent.rsf.server.ai.dto.AiParamValidateResultDto; import com.vincent.rsf.server.ai.entity.AiParam; import io.micrometer.observation.ObservationRegistry; import lombok.RequiredArgsConstructor; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.tool.DefaultToolCallingManager; import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver; import org.springframework.ai.util.json.schema.SchemaType; import org.springframework.context.support.GenericApplicationContext; import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.stereotype.Component; import org.springframework.util.StringUtils; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; import java.text.SimpleDateFormat; import java.util.Date; import java.util.List; @Component @RequiredArgsConstructor public class AiParamValidationSupport { private final GenericApplicationContext applicationContext; private final ObservationRegistry observationRegistry; /** * 对一份 AI 参数草稿做真实连通性校验。 * 校验方式不是简单判断字段非空,而是直接构造聊天模型并发起一次最小探测调用, * 用返回结果和耗时生成前端可展示的校验报告。 */ public AiParamValidateResultDto validate(AiParam aiParam) { long startedAt = System.currentTimeMillis(); try { OpenAiChatModel chatModel = createChatModel(aiParam); ChatResponse response = chatModel.call(new Prompt(List.of(new UserMessage("请回复 OK")))); if (response == null || response.getResult() == null || response.getResult().getOutput() == null || !StringUtils.hasText(response.getResult().getOutput().getText())) { throw new CoolException("模型已连接,但未返回有效响应"); } long elapsedMs = System.currentTimeMillis() - startedAt; return AiParamValidateResultDto.builder() .status(AiDefaults.PARAM_VALIDATE_VALID) .message("模型连通成功") .model(aiParam.getModel()) .elapsedMs(elapsedMs) .validatedAt(formatDate(new Date())) .build(); } catch (Exception e) { long elapsedMs = System.currentTimeMillis() - startedAt; String message = e instanceof CoolException ? e.getMessage() : "模型验证失败: " + e.getMessage(); return AiParamValidateResultDto.builder() .status(AiDefaults.PARAM_VALIDATE_INVALID) .message(message) .model(aiParam.getModel()) .elapsedMs(elapsedMs) .validatedAt(formatDate(new Date())) .build(); } } private OpenAiChatModel createChatModel(AiParam aiParam) { /** * 构造仅用于校验的轻量聊天模型。 * 这里沿用正式链路的 Observation 和 ToolCalling 依赖, * 保证校验结论与真实运行环境尽量一致。 */ OpenAiApi openAiApi = buildOpenAiApi(aiParam); ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder() .observationRegistry(observationRegistry) .toolCallbackResolver(new SpringBeanToolCallbackResolver(applicationContext, SchemaType.OPEN_API_SCHEMA)) .toolExecutionExceptionProcessor(new DefaultToolExecutionExceptionProcessor(false)) .build(); return new OpenAiChatModel( openAiApi, OpenAiChatOptions.builder() .model(aiParam.getModel()) .temperature(aiParam.getTemperature()) .topP(aiParam.getTopP()) .maxTokens(aiParam.getMaxTokens()) .streamUsage(true) .build(), toolCallingManager, org.springframework.retry.support.RetryTemplate.builder().maxAttempts(1).build(), observationRegistry ); } private OpenAiApi buildOpenAiApi(AiParam aiParam) { /** * 根据表单里的 Base URL、API Key 和超时参数构造 OpenAI 兼容客户端。 * 该方法被显式拆出来,是为了让“网络连接参数”和“模型选项”职责分离。 */ int timeoutMs = aiParam.getTimeoutMs() == null ? AiDefaults.DEFAULT_TIMEOUT_MS : aiParam.getTimeoutMs(); SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); requestFactory.setConnectTimeout(timeoutMs); requestFactory.setReadTimeout(timeoutMs); return OpenAiApi.builder() .baseUrl(aiParam.getBaseUrl()) .apiKey(aiParam.getApiKey()) .restClientBuilder(RestClient.builder().requestFactory(requestFactory)) .webClientBuilder(WebClient.builder()) .build(); } private String formatDate(Date date) { /** 统一输出给前端的校验时间格式。 */ return new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(date); } }