package com.vincent.rsf.server.ai.service.impl;
|
|
import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.dto.AiParamValidateResultDto;
import com.vincent.rsf.server.ai.entity.AiParam;
import io.micrometer.observation.ObservationRegistry;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.tool.DefaultToolCallingManager;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver;
import org.springframework.ai.util.json.schema.SchemaType;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;

import java.text.SimpleDateFormat;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Date;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
|
|
@Component
|
@RequiredArgsConstructor
|
public class AiParamValidationSupport {
|
|
private final GenericApplicationContext applicationContext;
|
private final ObservationRegistry observationRegistry;
|
|
public AiParamValidateResultDto validate(AiParam aiParam) {
|
long startedAt = System.currentTimeMillis();
|
try {
|
OpenAiChatModel chatModel = createChatModel(aiParam);
|
ChatResponse response = chatModel.call(new Prompt(List.of(new UserMessage("请回复 OK"))));
|
if (response == null || response.getResult() == null || response.getResult().getOutput() == null
|
|| !StringUtils.hasText(response.getResult().getOutput().getText())) {
|
throw new CoolException("模型已连接,但未返回有效响应");
|
}
|
long elapsedMs = System.currentTimeMillis() - startedAt;
|
return AiParamValidateResultDto.builder()
|
.status(AiDefaults.PARAM_VALIDATE_VALID)
|
.message("模型连通成功")
|
.model(aiParam.getModel())
|
.elapsedMs(elapsedMs)
|
.validatedAt(formatDate(new Date()))
|
.build();
|
} catch (Exception e) {
|
long elapsedMs = System.currentTimeMillis() - startedAt;
|
String message = e instanceof CoolException ? e.getMessage() : "模型验证失败: " + e.getMessage();
|
return AiParamValidateResultDto.builder()
|
.status(AiDefaults.PARAM_VALIDATE_INVALID)
|
.message(message)
|
.model(aiParam.getModel())
|
.elapsedMs(elapsedMs)
|
.validatedAt(formatDate(new Date()))
|
.build();
|
}
|
}
|
|
private OpenAiChatModel createChatModel(AiParam aiParam) {
|
OpenAiApi openAiApi = buildOpenAiApi(aiParam);
|
ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder()
|
.observationRegistry(observationRegistry)
|
.toolCallbackResolver(new SpringBeanToolCallbackResolver(applicationContext, SchemaType.OPEN_API_SCHEMA))
|
.toolExecutionExceptionProcessor(new DefaultToolExecutionExceptionProcessor(false))
|
.build();
|
return new OpenAiChatModel(
|
openAiApi,
|
OpenAiChatOptions.builder()
|
.model(aiParam.getModel())
|
.temperature(aiParam.getTemperature())
|
.topP(aiParam.getTopP())
|
.maxTokens(aiParam.getMaxTokens())
|
.streamUsage(true)
|
.build(),
|
toolCallingManager,
|
org.springframework.retry.support.RetryTemplate.builder().maxAttempts(1).build(),
|
observationRegistry
|
);
|
}
|
|
private OpenAiApi buildOpenAiApi(AiParam aiParam) {
|
int timeoutMs = aiParam.getTimeoutMs() == null ? AiDefaults.DEFAULT_TIMEOUT_MS : aiParam.getTimeoutMs();
|
SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
|
requestFactory.setConnectTimeout(timeoutMs);
|
requestFactory.setReadTimeout(timeoutMs);
|
return OpenAiApi.builder()
|
.baseUrl(aiParam.getBaseUrl())
|
.apiKey(aiParam.getApiKey())
|
.restClientBuilder(RestClient.builder().requestFactory(requestFactory))
|
.webClientBuilder(WebClient.builder())
|
.build();
|
}
|
|
private String formatDate(Date date) {
|
return new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(date);
|
}
|
}
|