package com.vincent.rsf.server.ai.service.impl; import com.vincent.rsf.framework.exception.CoolException; import com.vincent.rsf.server.ai.config.AiDefaults; import com.vincent.rsf.server.ai.dto.AiParamValidateResultDto; import com.vincent.rsf.server.ai.entity.AiParam; import io.micrometer.observation.ObservationRegistry; import lombok.RequiredArgsConstructor; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.tool.DefaultToolCallingManager; import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver; import org.springframework.ai.util.json.schema.SchemaType; import org.springframework.context.support.GenericApplicationContext; import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.stereotype.Component; import org.springframework.util.StringUtils; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; import java.text.SimpleDateFormat; import java.util.Date; import java.util.List; @Component @RequiredArgsConstructor public class AiParamValidationSupport { private final GenericApplicationContext applicationContext; private final ObservationRegistry observationRegistry; public AiParamValidateResultDto validate(AiParam aiParam) { long startedAt = System.currentTimeMillis(); try { OpenAiChatModel chatModel = createChatModel(aiParam); ChatResponse response = chatModel.call(new Prompt(List.of(new UserMessage("请回复 OK")))); if (response == null || response.getResult() == null || response.getResult().getOutput() == null || !StringUtils.hasText(response.getResult().getOutput().getText())) { throw new CoolException("模型已连接,但未返回有效响应"); } long elapsedMs = System.currentTimeMillis() - startedAt; return AiParamValidateResultDto.builder() .status(AiDefaults.PARAM_VALIDATE_VALID) .message("模型连通成功") .model(aiParam.getModel()) .elapsedMs(elapsedMs) .validatedAt(formatDate(new Date())) .build(); } catch (Exception e) { long elapsedMs = System.currentTimeMillis() - startedAt; String message = e instanceof CoolException ? e.getMessage() : "模型验证失败: " + e.getMessage(); return AiParamValidateResultDto.builder() .status(AiDefaults.PARAM_VALIDATE_INVALID) .message(message) .model(aiParam.getModel()) .elapsedMs(elapsedMs) .validatedAt(formatDate(new Date())) .build(); } } private OpenAiChatModel createChatModel(AiParam aiParam) { OpenAiApi openAiApi = buildOpenAiApi(aiParam); ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder() .observationRegistry(observationRegistry) .toolCallbackResolver(new SpringBeanToolCallbackResolver(applicationContext, SchemaType.OPEN_API_SCHEMA)) .toolExecutionExceptionProcessor(new DefaultToolExecutionExceptionProcessor(false)) .build(); return new OpenAiChatModel( openAiApi, OpenAiChatOptions.builder() .model(aiParam.getModel()) .temperature(aiParam.getTemperature()) .topP(aiParam.getTopP()) .maxTokens(aiParam.getMaxTokens()) .streamUsage(true) .build(), toolCallingManager, org.springframework.retry.support.RetryTemplate.builder().maxAttempts(1).build(), observationRegistry ); } private OpenAiApi buildOpenAiApi(AiParam aiParam) { int timeoutMs = aiParam.getTimeoutMs() == null ? AiDefaults.DEFAULT_TIMEOUT_MS : aiParam.getTimeoutMs(); SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); requestFactory.setConnectTimeout(timeoutMs); requestFactory.setReadTimeout(timeoutMs); return OpenAiApi.builder() .baseUrl(aiParam.getBaseUrl()) .apiKey(aiParam.getApiKey()) .restClientBuilder(RestClient.builder().requestFactory(requestFactory)) .webClientBuilder(WebClient.builder()) .build(); } private String formatDate(Date date) { return new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(date); } }