package com.zy.ai.gateway;
|
|
import com.alibaba.fastjson.JSON;
|
import com.zy.ai.entity.LlmCallLog;
|
import com.zy.ai.entity.LlmRouteConfig;
|
import com.zy.ai.gateway.adapter.AiProviderException;
|
import com.zy.ai.gateway.adapter.AiProviderAdapter;
|
import com.zy.ai.gateway.adapter.AiProviderAdapterRegistry;
|
import com.zy.ai.gateway.model.AiRequest;
|
import com.zy.ai.gateway.model.AiResponse;
|
import com.zy.ai.service.AiTokenUsageService;
|
import com.zy.ai.service.LlmCallLogService;
|
import com.zy.ai.service.LlmRoutingService;
|
import lombok.RequiredArgsConstructor;
|
import lombok.extern.slf4j.Slf4j;
|
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.stereotype.Service;
|
|
import java.util.ArrayList;
|
import java.util.Date;
|
import java.util.HashMap;
|
import java.util.List;
|
import java.util.Locale;
|
import java.util.Optional;
|
import java.util.UUID;
|
|
@Slf4j
|
@Service
|
@RequiredArgsConstructor
|
public class AiGatewayService {
|
|
private static final int LOG_TEXT_LIMIT = 16000;
|
|
private final LlmRoutingService llmRoutingService;
|
private final AiProviderAdapterRegistry adapterRegistry;
|
private final LlmCallLogService llmCallLogService;
|
private final AiTokenUsageService aiTokenUsageService;
|
|
@Value("${llm.base-url:}")
|
private String fallbackBaseUrl;
|
|
@Value("${llm.api-key:}")
|
private String fallbackApiKey;
|
|
@Value("${llm.model:}")
|
private String fallbackModel;
|
|
@Value("${llm.thinking:false}")
|
private String fallbackThinking;
|
|
public AiResponse generate(AiRequest request) {
|
String traceId = nextTraceId();
|
Throwable lastError = null;
|
List<LlmRouteConfig> routes = resolveRoutes();
|
for (int routeIndex = 0; routeIndex < routes.size(); routeIndex++) {
|
LlmRouteConfig route = routes.get(routeIndex);
|
Optional<AiProviderAdapter> adapter = adapterRegistry.findAdapter(route, request);
|
if (adapter.isEmpty()) {
|
continue;
|
}
|
long start = System.currentTimeMillis();
|
try {
|
AiResponse response = adapter.get().generate(route, request);
|
if (isValidResponse(response)) {
|
markSuccess(route);
|
recordCall(traceId, request, response, route, routeIndex + 1, true, 200,
|
System.currentTimeMillis() - start, "none", null);
|
return response;
|
}
|
lastError = new IllegalStateException("LLM 响应为空");
|
markFailure(route, lastError, false);
|
recordCall(traceId, request, response, route, routeIndex + 1, false, null,
|
System.currentTimeMillis() - start, "error", lastError);
|
} catch (Throwable ex) {
|
lastError = ex;
|
boolean quota = isQuotaExhausted(ex);
|
boolean canSwitch = shouldSwitch(route, quota);
|
markFailure(route, ex, canSwitch);
|
recordCall(traceId, request, null, route, routeIndex + 1, false, statusCodeOf(ex),
|
System.currentTimeMillis() - start, quota ? "quota" : "error", ex);
|
if (!canSwitch) {
|
break;
|
}
|
log.warn("AI 网关切换到下一路由, route={}, reason={}", route.getName(), ex.getMessage());
|
}
|
}
|
throw new IllegalStateException(lastError == null ? "未找到可用 AI 路由" : lastError.getMessage(), lastError);
|
}
|
|
private List<LlmRouteConfig> resolveRoutes() {
|
List<LlmRouteConfig> routes = new ArrayList<>(llmRoutingService.listAvailableRoutes());
|
if (routes.isEmpty() && !isBlank(fallbackBaseUrl) && !isBlank(fallbackApiKey) && !isBlank(fallbackModel)) {
|
routes.add(fallbackRoute());
|
}
|
return routes;
|
}
|
|
private LlmRouteConfig fallbackRoute() {
|
LlmRouteConfig route = new LlmRouteConfig();
|
route.setName("fallback-yml");
|
route.setProviderType("OPENAI_COMPATIBLE");
|
route.setProtocolType("OPENAI_CHAT_COMPLETIONS");
|
route.setAuthType("BEARER");
|
route.setBaseUrl(fallbackBaseUrl);
|
route.setApiKey(fallbackApiKey);
|
route.setModel(fallbackModel);
|
route.setThinking(isFallbackThinkingEnabled() ? (short) 1 : (short) 0);
|
route.setSwitchOnError((short) 1);
|
route.setSwitchOnQuota((short) 1);
|
route.setCooldownSeconds(300);
|
return route;
|
}
|
|
private boolean isValidResponse(AiResponse response) {
|
if (response == null) {
|
return false;
|
}
|
if (!isBlank(response.getText())) {
|
return true;
|
}
|
return response.getMessage() != null
|
&& response.getMessage().getToolCalls() != null
|
&& !response.getMessage().getToolCalls().isEmpty();
|
}
|
|
private boolean shouldSwitch(LlmRouteConfig route, boolean quota) {
|
if (quota) {
|
return route.getSwitchOnQuota() == null || route.getSwitchOnQuota() == 1;
|
}
|
return route.getSwitchOnError() == null || route.getSwitchOnError() == 1;
|
}
|
|
private void markSuccess(LlmRouteConfig route) {
|
if (route.getId() != null) {
|
llmRoutingService.markSuccess(route.getId());
|
}
|
}
|
|
private void markFailure(LlmRouteConfig route, Throwable ex, boolean enterCooldown) {
|
if (route.getId() != null) {
|
llmRoutingService.markFailure(route.getId(), ex.getMessage(), enterCooldown, route.getCooldownSeconds());
|
}
|
}
|
|
private boolean isQuotaExhausted(Throwable ex) {
|
Integer status = statusCodeOf(ex);
|
if (status != null && status == 429) {
|
return true;
|
}
|
String text = responseBodyOf(ex);
|
text = text == null ? "" : text.toLowerCase(Locale.ROOT);
|
return text.contains("insufficient_quota")
|
|| text.contains("quota")
|
|| text.contains("余额")
|
|| text.contains("用量")
|
|| text.contains("超限")
|
|| text.contains("rate limit");
|
}
|
|
private Integer statusCodeOf(Throwable ex) {
|
if (ex instanceof AiProviderException) {
|
return ((AiProviderException) ex).getStatusCode();
|
}
|
return null;
|
}
|
|
private String responseBodyOf(Throwable ex) {
|
if (ex instanceof AiProviderException) {
|
return ((AiProviderException) ex).getResponseBody();
|
}
|
return ex == null ? null : ex.getMessage();
|
}
|
|
private void recordCall(String traceId,
|
AiRequest request,
|
AiResponse response,
|
LlmRouteConfig route,
|
int attemptNo,
|
boolean success,
|
Integer httpStatus,
|
long latencyMs,
|
String switchMode,
|
Throwable err) {
|
LlmCallLog item = new LlmCallLog();
|
item.setTraceId(cut(traceId, 64));
|
item.setScene(cut(request == null ? null : request.getScene(), 64));
|
item.setStream((short) (request != null && Boolean.TRUE.equals(request.getStream()) ? 1 : 0));
|
item.setAttemptNo(attemptNo);
|
if (route != null) {
|
item.setRouteId(route.getId());
|
item.setRouteName(cut(route.getName(), 128));
|
item.setBaseUrl(cut(route.getBaseUrl(), 255));
|
item.setModel(cut(route.getModel(), 128));
|
}
|
item.setSuccess((short) (success ? 1 : 0));
|
item.setHttpStatus(httpStatus);
|
item.setLatencyMs(latencyMs < 0 ? 0 : latencyMs);
|
item.setSwitchMode(cut(switchMode, 32));
|
item.setRequestContent(cut(JSON.toJSONString(request), LOG_TEXT_LIMIT));
|
item.setResponseContent(cut(responseText(response), LOG_TEXT_LIMIT));
|
item.setErrorType(err == null ? null : cut(err.getClass().getSimpleName(), 128));
|
item.setErrorMessage(err == null ? null : cut(err.getMessage(), 1024));
|
item.setExtra(cut(extraPayload(route, response), 512));
|
item.setCreateTime(new Date());
|
llmCallLogService.saveIgnoreError(item);
|
|
// 累加 token 到独立存储
|
if (success && response != null && response.getUsage() != null) {
|
aiTokenUsageService.incrementTokens(
|
response.getUsage().getInputTokens() == null ? 0 : response.getUsage().getInputTokens(),
|
response.getUsage().getOutputTokens() == null ? 0 : response.getUsage().getOutputTokens(),
|
response.getUsage().getTotalTokens() == null ? 0 : response.getUsage().getTotalTokens(),
|
1);
|
}
|
}
|
|
private String responseText(AiResponse response) {
|
if (response == null) {
|
return null;
|
}
|
if (!isBlank(response.getText())) {
|
return response.getText();
|
}
|
return JSON.toJSONString(response.getMessage());
|
}
|
|
private String extraPayload(LlmRouteConfig route, AiResponse response) {
|
HashMap<String, Object> payload = new HashMap<>();
|
if (route != null) {
|
payload.put("providerType", route.getProviderType());
|
payload.put("protocolType", route.getProtocolType());
|
}
|
if (response != null && response.getUsage() != null) {
|
payload.put("inputTokens", response.getUsage().getInputTokens());
|
payload.put("outputTokens", response.getUsage().getOutputTokens());
|
payload.put("totalTokens", response.getUsage().getTotalTokens());
|
}
|
return payload.isEmpty() ? null : JSON.toJSONString(payload);
|
}
|
|
private String cut(String text, int maxLen) {
|
if (text == null) {
|
return null;
|
}
|
String clean = text.replace("\r", " ");
|
return clean.length() > maxLen ? clean.substring(0, maxLen) : clean;
|
}
|
|
private boolean isFallbackThinkingEnabled() {
|
String value = fallbackThinking == null ? "" : fallbackThinking.trim();
|
return "true".equalsIgnoreCase(value) || "1".equals(value) || "enable".equalsIgnoreCase(value);
|
}
|
|
private String nextTraceId() {
|
return UUID.randomUUID().toString().replace("-", "");
|
}
|
|
private boolean isBlank(String value) {
|
return value == null || value.trim().isEmpty();
|
}
|
}
|