package com.vincent.rsf.server.ai.service.impl.chat;
|
|
import com.vincent.rsf.server.ai.dto.AiChatRequest;
|
import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
|
import com.vincent.rsf.server.ai.dto.AiResolvedConfig;
|
import com.vincent.rsf.server.ai.entity.AiParam;
|
import com.vincent.rsf.server.ai.entity.AiPrompt;
|
import com.vincent.rsf.server.ai.service.AiCallLogService;
|
import com.vincent.rsf.server.ai.service.AiChatMemoryService;
|
import com.vincent.rsf.server.ai.service.AiConfigResolverService;
|
import com.vincent.rsf.server.ai.service.AiParamService;
|
import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
|
import com.vincent.rsf.server.ai.store.AiChatRateLimiter;
|
import com.vincent.rsf.server.ai.store.AiStreamStateStore;
|
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.extension.ExtendWith;
|
import org.mockito.Mock;
|
import org.mockito.junit.jupiter.MockitoExtension;
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
|
import java.util.List;
|
|
import static org.mockito.ArgumentMatchers.any;
|
import static org.mockito.ArgumentMatchers.eq;
|
import static org.mockito.Mockito.never;
|
import static org.mockito.Mockito.verify;
|
import static org.mockito.Mockito.when;
|
|
@ExtendWith(MockitoExtension.class)
|
class AiChatOrchestratorTest {
|
|
@Mock
|
private AiConfigResolverService aiConfigResolverService;
|
@Mock
|
private AiChatMemoryService aiChatMemoryService;
|
@Mock
|
private AiParamService aiParamService;
|
@Mock
|
private McpMountRuntimeFactory mcpMountRuntimeFactory;
|
@Mock
|
private AiCallLogService aiCallLogService;
|
@Mock
|
private AiChatRateLimiter aiChatRateLimiter;
|
@Mock
|
private AiStreamStateStore aiStreamStateStore;
|
@Mock
|
private AiChatRuntimeAssembler aiChatRuntimeAssembler;
|
@Mock
|
private AiPromptMessageBuilder aiPromptMessageBuilder;
|
@Mock
|
private AiOpenAiChatModelFactory aiOpenAiChatModelFactory;
|
@Mock
|
private AiToolObservationService aiToolObservationService;
|
@Mock
|
private AiSseEventPublisher aiSseEventPublisher;
|
@Mock
|
private AiChatFailureHandler aiChatFailureHandler;
|
|
@Test
|
void shouldShortCircuitWhenRateLimited() {
|
AiChatOrchestrator aiChatOrchestrator = new AiChatOrchestrator(
|
aiConfigResolverService,
|
aiChatMemoryService,
|
aiParamService,
|
mcpMountRuntimeFactory,
|
aiCallLogService,
|
aiChatRateLimiter,
|
aiStreamStateStore,
|
aiChatRuntimeAssembler,
|
aiPromptMessageBuilder,
|
aiOpenAiChatModelFactory,
|
aiToolObservationService,
|
aiSseEventPublisher,
|
aiChatFailureHandler
|
);
|
AiChatRequest request = new AiChatRequest();
|
request.setRequestId("req-1");
|
request.setPromptCode("home.default");
|
request.setMessages(List.of(message("user", "hello")));
|
AiResolvedConfig config = AiResolvedConfig.builder()
|
.promptCode("home.default")
|
.aiParam(new AiParam().setModel("gpt-test"))
|
.prompt(new AiPrompt().setName("default"))
|
.mcpMounts(List.of())
|
.build();
|
when(aiConfigResolverService.resolve("home.default", 1L, null)).thenReturn(config);
|
when(aiParamService.listChatModelOptions(1L)).thenReturn(List.of());
|
when(aiChatRateLimiter.allowChatRequest(1L, 2L, "home.default")).thenReturn(false);
|
when(aiChatFailureHandler.buildAiException(eq("AI_RATE_LIMITED"), any(), eq("RATE_LIMIT"), any(), eq(null)))
|
.thenCallRealMethod();
|
|
aiChatOrchestrator.executeStream(request, 2L, 1L, new SseEmitter(1000L));
|
|
verify(aiChatFailureHandler).handleStreamFailure(any(), eq("req-1"), eq(null), eq(null), any(Long.class),
|
eq(null), any(), eq(null), eq(0L), eq(0L), eq(null), eq(1L), eq(2L), eq("home.default"));
|
verify(aiCallLogService, never()).startCallLog(any(), any(), any(), any(), any(), any(), any(), any(), any(), any());
|
}
|
|
private AiChatMessageDto message(String role, String content) {
|
AiChatMessageDto dto = new AiChatMessageDto();
|
dto.setRole(role);
|
dto.setContent(content);
|
return dto;
|
}
|
}
|