package com.vincent.rsf.server.ai.service.impl;
|
|
import com.vincent.rsf.server.ai.config.AiDefaults;
|
import com.vincent.rsf.server.ai.entity.AiParam;
|
import org.springframework.ai.openai.api.OpenAiApi;
|
import org.springframework.http.client.SimpleClientHttpRequestFactory;
|
import org.springframework.util.StringUtils;
|
import org.springframework.web.client.RestClient;
|
import org.springframework.web.reactive.function.client.WebClient;
|
|
import java.util.Locale;
|
|
final class AiOpenAiApiSupport {
|
|
private static final String DEFAULT_COMPLETIONS_PATH = "/v1/chat/completions";
|
private static final String DEFAULT_EMBEDDINGS_PATH = "/v1/embeddings";
|
private static final String V1_SEGMENT = "/v1";
|
private static final String COMPLETIONS_SEGMENT = "/chat/completions";
|
private static final String EMBEDDINGS_SEGMENT = "/embeddings";
|
|
private AiOpenAiApiSupport() {
|
}
|
|
static OpenAiApi buildOpenAiApi(AiParam aiParam) {
|
int timeoutMs = aiParam.getTimeoutMs() == null ? AiDefaults.DEFAULT_TIMEOUT_MS : aiParam.getTimeoutMs();
|
SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
|
requestFactory.setConnectTimeout(timeoutMs);
|
requestFactory.setReadTimeout(timeoutMs);
|
|
EndpointConfig endpointConfig = resolveEndpointConfig(aiParam.getBaseUrl());
|
return OpenAiApi.builder()
|
.baseUrl(endpointConfig.baseUrl())
|
.completionsPath(endpointConfig.completionsPath())
|
.embeddingsPath(endpointConfig.embeddingsPath())
|
.apiKey(aiParam.getApiKey())
|
.restClientBuilder(RestClient.builder().requestFactory(requestFactory))
|
.webClientBuilder(WebClient.builder())
|
.build();
|
}
|
|
static EndpointConfig resolveEndpointConfig(String rawBaseUrl) {
|
String normalizedBaseUrl = trimTrailingSlash(rawBaseUrl);
|
String lowerCaseBaseUrl = normalizedBaseUrl.toLowerCase(Locale.ROOT);
|
|
if (lowerCaseBaseUrl.endsWith(DEFAULT_COMPLETIONS_PATH)) {
|
String baseUrl = trimTrailingSlash(normalizedBaseUrl.substring(0,
|
normalizedBaseUrl.length() - DEFAULT_COMPLETIONS_PATH.length()));
|
return new EndpointConfig(baseUrl, DEFAULT_COMPLETIONS_PATH, DEFAULT_EMBEDDINGS_PATH);
|
}
|
if (lowerCaseBaseUrl.endsWith(DEFAULT_EMBEDDINGS_PATH)) {
|
String baseUrl = trimTrailingSlash(normalizedBaseUrl.substring(0,
|
normalizedBaseUrl.length() - DEFAULT_EMBEDDINGS_PATH.length()));
|
return new EndpointConfig(baseUrl, DEFAULT_COMPLETIONS_PATH, DEFAULT_EMBEDDINGS_PATH);
|
}
|
if (lowerCaseBaseUrl.endsWith(COMPLETIONS_SEGMENT)) {
|
String baseUrl = trimTrailingSlash(normalizedBaseUrl.substring(0,
|
normalizedBaseUrl.length() - COMPLETIONS_SEGMENT.length()));
|
return new EndpointConfig(baseUrl, COMPLETIONS_SEGMENT, EMBEDDINGS_SEGMENT);
|
}
|
if (lowerCaseBaseUrl.endsWith(EMBEDDINGS_SEGMENT)) {
|
String baseUrl = trimTrailingSlash(normalizedBaseUrl.substring(0,
|
normalizedBaseUrl.length() - EMBEDDINGS_SEGMENT.length()));
|
return new EndpointConfig(baseUrl, COMPLETIONS_SEGMENT, EMBEDDINGS_SEGMENT);
|
}
|
if (lowerCaseBaseUrl.endsWith(V1_SEGMENT)) {
|
String baseUrl = trimTrailingSlash(normalizedBaseUrl.substring(0,
|
normalizedBaseUrl.length() - V1_SEGMENT.length()));
|
return new EndpointConfig(baseUrl, DEFAULT_COMPLETIONS_PATH, DEFAULT_EMBEDDINGS_PATH);
|
}
|
return new EndpointConfig(normalizedBaseUrl, DEFAULT_COMPLETIONS_PATH, DEFAULT_EMBEDDINGS_PATH);
|
}
|
|
private static String trimTrailingSlash(String baseUrl) {
|
String normalized = baseUrl == null ? "" : baseUrl.trim();
|
if (!StringUtils.hasText(normalized)) {
|
return normalized;
|
}
|
while (normalized.endsWith("/")) {
|
normalized = normalized.substring(0, normalized.length() - 1);
|
}
|
return normalized;
|
}
|
|
record EndpointConfig(String baseUrl, String completionsPath, String embeddingsPath) {
|
}
|
}
|