package com.zy.ai.gateway;

import com.zy.ai.entity.LlmCallLog;
import com.zy.ai.entity.LlmRouteConfig;
import com.zy.ai.gateway.adapter.AiProviderAdapter;
import com.zy.ai.gateway.adapter.AiProviderAdapterRegistry;
import com.zy.ai.gateway.adapter.AiProviderException;
import com.zy.ai.gateway.model.AiMessage;
import com.zy.ai.gateway.model.AiRequest;
import com.zy.ai.gateway.model.AiResponse;
import com.zy.ai.gateway.model.AiStreamEvent;
import com.zy.ai.service.LlmCallLogService;
import com.zy.ai.service.LlmRoutingService;

import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import reactor.core.publisher.Flux;

import java.util.List;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
|
class AiGatewayServiceTest {
|
|
@Test
|
void generateSwitchesToNextRouteWhenFirstRouteFails() {
|
LlmRouteConfig firstRoute = route(1L, "first");
|
LlmRouteConfig secondRoute = route(2L, "second");
|
LlmRoutingService routingService = mock(LlmRoutingService.class);
|
when(routingService.listAvailableRoutes()).thenReturn(List.of(firstRoute, secondRoute));
|
|
AiProviderAdapter adapter = new FailingThenSuccessAdapter();
|
LlmCallLogService callLogService = mock(LlmCallLogService.class);
|
AiGatewayService gatewayService = new AiGatewayService(
|
routingService,
|
new AiProviderAdapterRegistry(List.of(adapter)),
|
callLogService
|
);
|
|
AiResponse response = gatewayService.generate(new AiRequest());
|
|
assertEquals("ok", response.getText());
|
verify(routingService).markFailure(1L, "first failed", true, 300);
|
verify(routingService).markSuccess(2L);
|
ArgumentCaptor<LlmCallLog> logCaptor = ArgumentCaptor.forClass(LlmCallLog.class);
|
verify(callLogService, times(2)).saveIgnoreError(logCaptor.capture());
|
List<LlmCallLog> logs = logCaptor.getAllValues();
|
assertNotNull(logs.get(0).getTraceId());
|
assertFalse(logs.get(0).getTraceId().isBlank());
|
assertEquals(logs.get(0).getTraceId(), logs.get(1).getTraceId());
|
}
|
|
@Test
|
void generateAcceptsToolOnlyResponse() {
|
LlmRouteConfig route = route(1L, "tool-route");
|
LlmRoutingService routingService = mock(LlmRoutingService.class);
|
when(routingService.listAvailableRoutes()).thenReturn(List.of(route));
|
LlmCallLogService callLogService = mock(LlmCallLogService.class);
|
|
AiGatewayService gatewayService = new AiGatewayService(
|
routingService,
|
new AiProviderAdapterRegistry(List.of(new ToolOnlyAdapter())),
|
callLogService
|
);
|
|
AiResponse response = gatewayService.generate(new AiRequest());
|
|
assertNotNull(response.getMessage().getToolCalls());
|
assertEquals("call_1", ((Map<?, ?>) response.getMessage().getToolCalls().get(0)).get("id"));
|
verify(routingService).markSuccess(1L);
|
}
|
|
@Test
|
void generateUsesSwitchOnQuotaForQuotaFailures() {
|
LlmRouteConfig firstRoute = route(1L, "first");
|
firstRoute.setSwitchOnError((short) 0);
|
firstRoute.setSwitchOnQuota((short) 1);
|
LlmRouteConfig secondRoute = route(2L, "second");
|
|
LlmRoutingService routingService = mock(LlmRoutingService.class);
|
when(routingService.listAvailableRoutes()).thenReturn(List.of(firstRoute, secondRoute));
|
LlmCallLogService callLogService = mock(LlmCallLogService.class);
|
|
AiGatewayService gatewayService = new AiGatewayService(
|
routingService,
|
new AiProviderAdapterRegistry(List.of(new QuotaThenSuccessAdapter())),
|
callLogService
|
);
|
|
AiResponse response = gatewayService.generate(new AiRequest());
|
|
assertEquals("ok", response.getText());
|
verify(routingService).markFailure(1L, "quota exceeded", true, 300);
|
verify(routingService).markSuccess(2L);
|
}
|
|
private LlmRouteConfig route(Long id, String name) {
|
LlmRouteConfig route = new LlmRouteConfig();
|
route.setId(id);
|
route.setName(name);
|
route.setBaseUrl("https://example.com/v1");
|
route.setApiKey("sk-test");
|
route.setModel("gpt-test");
|
route.setProviderType("OPENAI_COMPATIBLE");
|
route.setProtocolType("OPENAI_CHAT_COMPLETIONS");
|
route.setSwitchOnError((short) 1);
|
route.setCooldownSeconds(300);
|
return route;
|
}
|
|
private static class FailingThenSuccessAdapter implements AiProviderAdapter {
|
|
@Override
|
public boolean supports(LlmRouteConfig routeConfig, AiRequest request) {
|
return true;
|
}
|
|
@Override
|
public AiResponse generate(LlmRouteConfig routeConfig, AiRequest request) {
|
if ("first".equals(routeConfig.getName())) {
|
throw new IllegalStateException("first failed");
|
}
|
AiResponse response = new AiResponse();
|
response.setText("ok");
|
return response;
|
}
|
|
@Override
|
public Flux<AiStreamEvent> stream(LlmRouteConfig routeConfig, AiRequest request) {
|
return Flux.empty();
|
}
|
}
|
|
private static class ToolOnlyAdapter implements AiProviderAdapter {
|
|
@Override
|
public boolean supports(LlmRouteConfig routeConfig, AiRequest request) {
|
return true;
|
}
|
|
@Override
|
public AiResponse generate(LlmRouteConfig routeConfig, AiRequest request) {
|
AiMessage message = new AiMessage();
|
message.setRole("assistant");
|
message.setToolCalls(List.of(Map.of("id", "call_1")));
|
|
AiResponse response = new AiResponse();
|
response.setMessage(message);
|
return response;
|
}
|
|
@Override
|
public Flux<AiStreamEvent> stream(LlmRouteConfig routeConfig, AiRequest request) {
|
return Flux.empty();
|
}
|
}
|
|
private static class QuotaThenSuccessAdapter implements AiProviderAdapter {
|
|
@Override
|
public boolean supports(LlmRouteConfig routeConfig, AiRequest request) {
|
return true;
|
}
|
|
@Override
|
public AiResponse generate(LlmRouteConfig routeConfig, AiRequest request) {
|
if ("first".equals(routeConfig.getName())) {
|
throw new AiProviderException("quota exceeded", routeConfig.getId(), 429, "insufficient_quota");
|
}
|
AiResponse response = new AiResponse();
|
response.setText("ok");
|
return response;
|
}
|
|
@Override
|
public Flux<AiStreamEvent> stream(LlmRouteConfig routeConfig, AiRequest request) {
|
return Flux.empty();
|
}
|
}
|
}
|