zhou zhou
1 天以前 1dcfa3702505f0c431757312b5304531029f90f6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
package com.vincent.rsf.server.ai.service.impl;
 
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.dto.AiParamValidateResultDto;
import com.vincent.rsf.server.ai.entity.AiParam;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestClient;
import org.springframework.web.client.RestClientResponseException;

import java.text.SimpleDateFormat;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
 
@Component
@Slf4j
@RequiredArgsConstructor
public class AiParamValidationSupport {
 
    private static final String PROBE_MESSAGE = "请回复 OK";
    private static final int MAX_DETAIL_TEXT_LENGTH = 8000;
 
    private final ObjectMapper objectMapper;
 
    /**
     * 对一份 AI 参数草稿做真实连通性校验。
     * 校验方式不是简单判断字段非空,而是直接构造聊天模型并发起一次最小探测调用,
     * 用返回结果和耗时生成前端可展示的校验报告。
     */
    public AiParamValidateResultDto validate(AiParam aiParam) {
        long startedAt = System.currentTimeMillis();
        ValidationRequestContext context = buildValidationRequestContext(aiParam);
        try {
            ValidationResponseSnapshot responseSnapshot = executeValidation(context);
            String assistantText = extractAssistantText(responseSnapshot.responseBody());
            if (!StringUtils.hasText(assistantText)) {
                throw new CoolException("模型已连接,但未返回有效响应");
            }
            long elapsedMs = System.currentTimeMillis() - startedAt;
            String detail = buildValidationDetail(
                    context.requestUrl(),
                    context.requestBody(),
                    responseSnapshot.statusCode(),
                    responseSnapshot.responseBody(),
                    assistantText,
                    null
            );
            log.info("AI 参数草稿校验成功, model={}, requestUrl={}, responseStatus={}",
                    aiParam.getModel(), context.requestUrl(), responseSnapshot.statusCode());
            return AiParamValidateResultDto.builder()
                    .status(AiDefaults.PARAM_VALIDATE_VALID)
                    .message("模型连通成功")
                    .detail(detail)
                    .model(aiParam.getModel())
                    .elapsedMs(elapsedMs)
                    .validatedAt(formatDate(new Date()))
                    .requestUrl(context.requestUrl())
                    .requestBody(context.requestBody())
                    .responseStatus(responseSnapshot.statusCode())
                    .responseBody(abbreviate(responseSnapshot.responseBody()))
                    .build();
        } catch (Exception e) {
            long elapsedMs = System.currentTimeMillis() - startedAt;
            ValidationErrorSnapshot errorSnapshot = extractErrorSnapshot(e);
            String message = e instanceof CoolException
                    ? e.getMessage()
                    : "模型验证失败: " + firstNonBlank(
                    extractResponseErrorMessage(errorSnapshot.responseBody()),
                    errorSnapshot.errorMessage(),
                    e.getMessage());
            String detail = buildValidationDetail(
                    context.requestUrl(),
                    context.requestBody(),
                    errorSnapshot.statusCode(),
                    errorSnapshot.responseBody(),
                    null,
                    e
            );
            log.warn("AI 参数草稿校验失败, model={}, requestUrl={}, responseStatus={}, error={}",
                    aiParam.getModel(), context.requestUrl(), errorSnapshot.statusCode(), errorSnapshot.errorMessage(), e);
            return AiParamValidateResultDto.builder()
                    .status(AiDefaults.PARAM_VALIDATE_INVALID)
                    .message(message)
                    .detail(detail)
                    .model(aiParam.getModel())
                    .elapsedMs(elapsedMs)
                    .validatedAt(formatDate(new Date()))
                    .requestUrl(context.requestUrl())
                    .requestBody(context.requestBody())
                    .responseStatus(errorSnapshot.statusCode())
                    .responseBody(abbreviate(errorSnapshot.responseBody()))
                    .build();
        }
    }
 
    private ValidationRequestContext buildValidationRequestContext(AiParam aiParam) {
        AiOpenAiApiSupport.EndpointConfig endpointConfig = AiOpenAiApiSupport.resolveEndpointConfig(aiParam.getBaseUrl());
        String requestUrl = endpointConfig.baseUrl() + endpointConfig.completionsPath();
        Map<String, Object> requestBody = new LinkedHashMap<>();
        requestBody.put("model", aiParam.getModel());
        requestBody.put("messages", List.of(Map.of("role", "user", "content", PROBE_MESSAGE)));
        requestBody.put("stream", false);
        if (aiParam.getTemperature() != null) {
            requestBody.put("temperature", aiParam.getTemperature());
        }
        if (aiParam.getTopP() != null) {
            requestBody.put("top_p", aiParam.getTopP());
        }
        if (aiParam.getMaxTokens() != null) {
            requestBody.put("max_tokens", aiParam.getMaxTokens());
        }
        return new ValidationRequestContext(
                buildRestClient(aiParam),
                requestUrl,
                requestBody,
                toPrettyJson(requestBody)
        );
    }
 
    private RestClient buildRestClient(AiParam aiParam) {
        int timeoutMs = aiParam.getTimeoutMs() == null ? AiDefaults.DEFAULT_TIMEOUT_MS : aiParam.getTimeoutMs();
        SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
        requestFactory.setConnectTimeout(timeoutMs);
        requestFactory.setReadTimeout(timeoutMs);
        return RestClient.builder()
                .requestFactory(requestFactory)
                .defaultHeader("Authorization", "Bearer " + aiParam.getApiKey())
                .build();
    }
 
    private ValidationResponseSnapshot executeValidation(ValidationRequestContext context) {
        ResponseEntity<String> responseEntity = context.restClient().post()
                .uri(context.requestUrl())
                .contentType(MediaType.APPLICATION_JSON)
                .accept(MediaType.APPLICATION_JSON)
                .body(context.requestBodyMap())
                .retrieve()
                .toEntity(String.class);
        return new ValidationResponseSnapshot(
                responseEntity.getStatusCode().value(),
                responseEntity.getBody()
        );
    }
 
    private String extractAssistantText(String responseBody) {
        if (!StringUtils.hasText(responseBody)) {
            return "";
        }
        try {
            JsonNode root = objectMapper.readTree(responseBody);
            List<JsonNode> candidates = List.of(
                    root.at("/choices/0/message/content"),
                    root.at("/choices/0/text"),
                    root.at("/output_text"),
                    root.at("/output/0/content/0/text")
            );
            for (JsonNode candidate : candidates) {
                String text = flattenText(candidate);
                if (StringUtils.hasText(text)) {
                    return text;
                }
            }
        } catch (Exception e) {
            log.debug("解析 AI 参数草稿校验响应失败: {}", e.getMessage());
        }
        return "";
    }
 
    private String flattenText(JsonNode node) {
        if (node == null || node.isMissingNode() || node.isNull()) {
            return "";
        }
        if (node.isTextual()) {
            return node.asText();
        }
        if (node.isArray()) {
            List<String> parts = new ArrayList<>();
            for (JsonNode item : node) {
                String text = flattenText(item);
                if (StringUtils.hasText(text)) {
                    parts.add(text);
                }
            }
            return String.join("\n", parts).trim();
        }
        if (node.isObject()) {
            String text = flattenText(node.get("text"));
            if (StringUtils.hasText(text)) {
                return text;
            }
            text = flattenText(node.get("content"));
            if (StringUtils.hasText(text)) {
                return text;
            }
        }
        return "";
    }
 
    private ValidationErrorSnapshot extractErrorSnapshot(Throwable throwable) {
        RestClientResponseException responseException = findResponseException(throwable);
        if (responseException != null) {
            return new ValidationErrorSnapshot(
                    responseException.getStatusCode().value(),
                    responseException.getResponseBodyAsString(),
                    rootMessage(throwable)
            );
        }
        return new ValidationErrorSnapshot(null, null, rootMessage(throwable));
    }
 
    private RestClientResponseException findResponseException(Throwable throwable) {
        Throwable current = throwable;
        while (current != null) {
            if (current instanceof RestClientResponseException responseException) {
                return responseException;
            }
            current = current.getCause();
        }
        return null;
    }
 
    private String extractResponseErrorMessage(String responseBody) {
        if (!StringUtils.hasText(responseBody)) {
            return "";
        }
        try {
            JsonNode root = objectMapper.readTree(responseBody);
            List<JsonNode> candidates = List.of(
                    root.at("/error/message"),
                    root.at("/message"),
                    root.at("/detail")
            );
            for (JsonNode candidate : candidates) {
                String text = flattenText(candidate);
                if (StringUtils.hasText(text)) {
                    return text;
                }
            }
        } catch (Exception e) {
            log.debug("解析 AI 参数草稿校验错误响应失败: {}", e.getMessage());
        }
        return "";
    }
 
    private String buildValidationDetail(String requestUrl,
                                         String requestBody,
                                         Integer responseStatus,
                                         String responseBody,
                                         String assistantText,
                                         Throwable throwable) {
        List<String> sections = new ArrayList<>();
        sections.add("请求 URL:\n" + defaultText(requestUrl));
        sections.add("请求体:\n" + defaultText(requestBody));
        sections.add("HTTP 状态:\n" + (responseStatus == null ? "--" : responseStatus));
        if (StringUtils.hasText(assistantText)) {
            sections.add("解析结果:\n" + abbreviate(assistantText));
        }
        sections.add("响应结果:\n" + defaultText(abbreviate(responseBody)));
        if (throwable != null) {
            sections.add("异常类型:\n" + throwable.getClass().getName());
            sections.add("异常信息:\n" + defaultText(rootMessage(throwable)));
        }
        return String.join("\n\n", sections);
    }
 
    private String toPrettyJson(Object value) {
        try {
            return objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(value);
        } catch (Exception e) {
            return String.valueOf(value);
        }
    }
 
    private String rootMessage(Throwable throwable) {
        Throwable current = throwable;
        while (current != null && current.getCause() != null) {
            current = current.getCause();
        }
        return current == null ? "" : current.getMessage();
    }
 
    private String defaultText(String value) {
        return StringUtils.hasText(value) ? value : "--";
    }
 
    private String firstNonBlank(String... values) {
        for (String value : values) {
            if (StringUtils.hasText(value)) {
                return value;
            }
        }
        return "";
    }
 
    private String abbreviate(String value) {
        if (!StringUtils.hasText(value)) {
            return value;
        }
        if (value.length() <= MAX_DETAIL_TEXT_LENGTH) {
            return value;
        }
        return value.substring(0, MAX_DETAIL_TEXT_LENGTH) + "\n...(已截断)";
    }
 
    private String formatDate(Date date) {
        /** 统一输出给前端的校验时间格式。 */
        return new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(date);
    }
 
    private record ValidationRequestContext(RestClient restClient,
                                            String requestUrl,
                                            Map<String, Object> requestBodyMap,
                                            String requestBody) {
    }
 
    private record ValidationResponseSnapshot(int statusCode, String responseBody) {
    }
 
    private record ValidationErrorSnapshot(Integer statusCode, String responseBody, String errorMessage) {
    }
}