zhou zhou
14 小时以前 6477d7156272a6f1fe126c781958369bb10970c6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
package com.vincent.rsf.server.ai.service.impl;
 
import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.entity.AiParam;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;

import java.util.List;
import java.util.Locale;
 
final class AiOpenAiApiSupport {
 
    private static final String DEFAULT_COMPLETIONS_PATH = "/v1/chat/completions";
    private static final String DEFAULT_EMBEDDINGS_PATH = "/v1/embeddings";
    private static final String V1_SEGMENT = "/v1";
    private static final String COMPLETIONS_SEGMENT = "/chat/completions";
    private static final String EMBEDDINGS_SEGMENT = "/embeddings";
 
    private AiOpenAiApiSupport() {
    }
 
    static OpenAiApi buildOpenAiApi(AiParam aiParam) {
        int timeoutMs = aiParam.getTimeoutMs() == null ? AiDefaults.DEFAULT_TIMEOUT_MS : aiParam.getTimeoutMs();
        SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
        requestFactory.setConnectTimeout(timeoutMs);
        requestFactory.setReadTimeout(timeoutMs);
 
        EndpointConfig endpointConfig = resolveEndpointConfig(aiParam.getBaseUrl());
        return OpenAiApi.builder()
                .baseUrl(endpointConfig.baseUrl())
                .completionsPath(endpointConfig.completionsPath())
                .embeddingsPath(endpointConfig.embeddingsPath())
                .apiKey(aiParam.getApiKey())
                .restClientBuilder(RestClient.builder().requestFactory(requestFactory))
                .webClientBuilder(WebClient.builder())
                .build();
    }
 
    static EndpointConfig resolveEndpointConfig(String rawBaseUrl) {
        String normalizedBaseUrl = trimTrailingSlash(rawBaseUrl);
        String lowerCaseBaseUrl = normalizedBaseUrl.toLowerCase(Locale.ROOT);
 
        if (lowerCaseBaseUrl.endsWith(DEFAULT_COMPLETIONS_PATH)) {
            String baseUrl = trimTrailingSlash(normalizedBaseUrl.substring(0,
                    normalizedBaseUrl.length() - DEFAULT_COMPLETIONS_PATH.length()));
            return new EndpointConfig(baseUrl, DEFAULT_COMPLETIONS_PATH, DEFAULT_EMBEDDINGS_PATH);
        }
        if (lowerCaseBaseUrl.endsWith(DEFAULT_EMBEDDINGS_PATH)) {
            String baseUrl = trimTrailingSlash(normalizedBaseUrl.substring(0,
                    normalizedBaseUrl.length() - DEFAULT_EMBEDDINGS_PATH.length()));
            return new EndpointConfig(baseUrl, DEFAULT_COMPLETIONS_PATH, DEFAULT_EMBEDDINGS_PATH);
        }
        if (lowerCaseBaseUrl.endsWith(COMPLETIONS_SEGMENT)) {
            String baseUrl = trimTrailingSlash(normalizedBaseUrl.substring(0,
                    normalizedBaseUrl.length() - COMPLETIONS_SEGMENT.length()));
            return new EndpointConfig(baseUrl, COMPLETIONS_SEGMENT, EMBEDDINGS_SEGMENT);
        }
        if (lowerCaseBaseUrl.endsWith(EMBEDDINGS_SEGMENT)) {
            String baseUrl = trimTrailingSlash(normalizedBaseUrl.substring(0,
                    normalizedBaseUrl.length() - EMBEDDINGS_SEGMENT.length()));
            return new EndpointConfig(baseUrl, COMPLETIONS_SEGMENT, EMBEDDINGS_SEGMENT);
        }
        if (lowerCaseBaseUrl.endsWith(V1_SEGMENT)) {
            String baseUrl = trimTrailingSlash(normalizedBaseUrl.substring(0,
                    normalizedBaseUrl.length() - V1_SEGMENT.length()));
            return new EndpointConfig(baseUrl, DEFAULT_COMPLETIONS_PATH, DEFAULT_EMBEDDINGS_PATH);
        }
        return new EndpointConfig(normalizedBaseUrl, DEFAULT_COMPLETIONS_PATH, DEFAULT_EMBEDDINGS_PATH);
    }
 
    private static String trimTrailingSlash(String baseUrl) {
        String normalized = baseUrl == null ? "" : baseUrl.trim();
        if (!StringUtils.hasText(normalized)) {
            return normalized;
        }
        while (normalized.endsWith("/")) {
            normalized = normalized.substring(0, normalized.length() - 1);
        }
        return normalized;
    }
 
    record EndpointConfig(String baseUrl, String completionsPath, String embeddingsPath) {
    }
}