package com.vincent.rsf.server.ai.service.impl.chat; import com.vincent.rsf.server.ai.dto.AiChatRequest; import com.vincent.rsf.server.ai.dto.AiChatMessageDto; import com.vincent.rsf.server.ai.dto.AiResolvedConfig; import com.vincent.rsf.server.ai.entity.AiParam; import com.vincent.rsf.server.ai.entity.AiPrompt; import com.vincent.rsf.server.ai.service.AiCallLogService; import com.vincent.rsf.server.ai.service.AiChatMemoryService; import com.vincent.rsf.server.ai.service.AiConfigResolverService; import com.vincent.rsf.server.ai.service.AiParamService; import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory; import com.vincent.rsf.server.ai.store.AiChatRateLimiter; import com.vincent.rsf.server.ai.store.AiStreamStateStore; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import java.util.List; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) class AiChatOrchestratorTest { @Mock private AiConfigResolverService aiConfigResolverService; @Mock private AiChatMemoryService aiChatMemoryService; @Mock private AiParamService aiParamService; @Mock private McpMountRuntimeFactory mcpMountRuntimeFactory; @Mock private AiCallLogService aiCallLogService; @Mock private AiChatRateLimiter aiChatRateLimiter; @Mock private AiStreamStateStore aiStreamStateStore; @Mock private AiChatRuntimeAssembler aiChatRuntimeAssembler; @Mock private AiPromptMessageBuilder aiPromptMessageBuilder; @Mock private AiOpenAiChatModelFactory aiOpenAiChatModelFactory; @Mock private AiToolObservationService aiToolObservationService; @Mock private AiSseEventPublisher aiSseEventPublisher; @Mock private AiChatFailureHandler aiChatFailureHandler; @Test void shouldShortCircuitWhenRateLimited() { AiChatOrchestrator aiChatOrchestrator = new AiChatOrchestrator( aiConfigResolverService, aiChatMemoryService, aiParamService, mcpMountRuntimeFactory, aiCallLogService, aiChatRateLimiter, aiStreamStateStore, aiChatRuntimeAssembler, aiPromptMessageBuilder, aiOpenAiChatModelFactory, aiToolObservationService, aiSseEventPublisher, aiChatFailureHandler ); AiChatRequest request = new AiChatRequest(); request.setRequestId("req-1"); request.setPromptCode("home.default"); request.setMessages(List.of(message("user", "hello"))); AiResolvedConfig config = AiResolvedConfig.builder() .promptCode("home.default") .aiParam(new AiParam().setModel("gpt-test")) .prompt(new AiPrompt().setName("default")) .mcpMounts(List.of()) .build(); when(aiConfigResolverService.resolve("home.default", 1L, null)).thenReturn(config); when(aiParamService.listChatModelOptions(1L)).thenReturn(List.of()); when(aiChatRateLimiter.allowChatRequest(1L, 2L, "home.default")).thenReturn(false); when(aiChatFailureHandler.buildAiException(eq("AI_RATE_LIMITED"), any(), eq("RATE_LIMIT"), any(), eq(null))) .thenCallRealMethod(); aiChatOrchestrator.executeStream(request, 2L, 1L, new SseEmitter(1000L)); verify(aiChatFailureHandler).handleStreamFailure(any(), eq("req-1"), eq(null), eq(null), any(Long.class), eq(null), any(), eq(null), eq(0L), eq(0L), eq(null), eq(1L), eq(2L), eq("home.default")); verify(aiCallLogService, never()).startCallLog(any(), any(), any(), any(), any(), any(), any(), any(), any(), any()); } private AiChatMessageDto message(String role, String content) { AiChatMessageDto dto = new AiChatMessageDto(); dto.setRole(role); dto.setContent(content); return dto; } }