package com.vincent.rsf.server.ai.service.impl.conversation;
|
|
import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
import com.vincent.rsf.server.ai.entity.AiChatMessage;
import com.vincent.rsf.server.ai.entity.AiChatSession;
import com.vincent.rsf.server.ai.mapper.AiChatMessageMapper;
import com.vincent.rsf.server.ai.mapper.AiChatSessionMapper;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.transaction.support.TransactionSynchronization;
import org.springframework.transaction.support.TransactionSynchronizationManager;

import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
|
|
@ExtendWith(MockitoExtension.class)
|
class AiConversationCommandServiceTest {
|
|
@Mock
|
private AiChatSessionMapper aiChatSessionMapper;
|
@Mock
|
private AiChatMessageMapper aiChatMessageMapper;
|
@Mock
|
private AiConversationQueryService aiConversationQueryService;
|
@Mock
|
private AiMemoryProfileService aiMemoryProfileService;
|
|
@InjectMocks
|
private AiConversationCommandService aiConversationCommandService;
|
|
@Test
|
void shouldBatchInsertMessagesAndScheduleRefreshWhenSavingRound() {
|
AiChatSession session = new AiChatSession().setId(10L).setTitle(null);
|
when(aiConversationQueryService.findNextSeqNo(10L)).thenReturn(3);
|
when(aiConversationQueryService.resolveUpdatedTitle(eq(null), any())).thenReturn("new title");
|
|
aiConversationCommandService.saveRound(session, 7L, 8L,
|
List.of(message("user", "hello"), message("assistant", "draft answer")),
|
"final answer");
|
|
ArgumentCaptor<List<AiChatMessage>> captor = ArgumentCaptor.forClass(List.class);
|
verify(aiChatMessageMapper).insertBatch(captor.capture());
|
assertEquals(3, captor.getValue().size());
|
assertEquals(3, captor.getValue().get(0).getSeqNo());
|
assertEquals(5, captor.getValue().get(2).getSeqNo());
|
verify(aiMemoryProfileService).scheduleMemoryProfileRefresh(10L, 7L, 8L);
|
verify(aiConversationQueryService).evictConversationCaches(8L, 7L);
|
}
|
|
@Test
|
void shouldDeferConversationSideEffectsUntilAfterCommitWhenSynchronizationActive() {
|
AiChatSession session = new AiChatSession().setId(10L).setTitle(null);
|
when(aiConversationQueryService.findNextSeqNo(10L)).thenReturn(1);
|
when(aiConversationQueryService.resolveUpdatedTitle(eq(null), any())).thenReturn("new title");
|
|
TransactionSynchronizationManager.initSynchronization();
|
try {
|
aiConversationCommandService.saveRound(session, 7L, 8L, List.of(message("user", "hello")), "final answer");
|
|
verify(aiConversationQueryService, never()).evictConversationCaches(8L, 7L);
|
verify(aiMemoryProfileService, never()).scheduleMemoryProfileRefresh(10L, 7L, 8L);
|
|
for (TransactionSynchronization synchronization : TransactionSynchronizationManager.getSynchronizations()) {
|
synchronization.afterCommit();
|
}
|
|
verify(aiConversationQueryService).evictConversationCaches(8L, 7L);
|
verify(aiMemoryProfileService).scheduleMemoryProfileRefresh(10L, 7L, 8L);
|
} finally {
|
TransactionSynchronizationManager.clearSynchronization();
|
}
|
}
|
|
@Test
|
void shouldSoftDeleteOnlyOlderMessagesWhenRetainingLatestRound() {
|
when(aiConversationQueryService.listMessageRecords(10L)).thenReturn(List.of(
|
record(1L, "user"),
|
record(2L, "assistant"),
|
record(3L, "user"),
|
record(4L, "assistant")
|
));
|
when(aiConversationQueryService.tailMessageRecordsByRounds(any(), eq(1))).thenReturn(List.of(
|
record(3L, "user"),
|
record(4L, "assistant")
|
));
|
|
aiConversationCommandService.retainLatestRound(7L, 8L, 10L);
|
|
verify(aiChatMessageMapper).softDeleteByIds(List.of(1L, 2L));
|
verify(aiMemoryProfileService).scheduleMemoryProfileRefresh(10L, 7L, 8L);
|
}
|
|
private AiChatMessageDto message(String role, String content) {
|
AiChatMessageDto dto = new AiChatMessageDto();
|
dto.setRole(role);
|
dto.setContent(content);
|
return dto;
|
}
|
|
private AiChatMessage record(Long id, String role) {
|
return new AiChatMessage().setId(id).setRole(role);
|
}
|
}
|