| | |
| | | import com.vincent.rsf.server.ai.dto.AiChatMessageDto; |
| | | import com.vincent.rsf.server.ai.entity.AiPrompt; |
| | | import org.junit.jupiter.api.Test; |
| | | import org.springframework.ai.chat.messages.AssistantMessage; |
| | | import org.springframework.ai.chat.messages.Message; |
| | | import org.springframework.ai.chat.messages.SystemMessage; |
| | | import org.springframework.ai.chat.messages.UserMessage; |
| | | |
| | | import java.util.List; |
| | | import java.util.Map; |
| | | |
| | | import static org.junit.jupiter.api.Assertions.assertEquals; |
| | | import static org.junit.jupiter.api.Assertions.assertInstanceOf; |
| | | import static org.assertj.core.api.Assertions.assertThat; |
| | | |
| | | class AiPromptMessageBuilderTest { |
| | | |
| | | private final AiPromptMessageBuilder builder = new AiPromptMessageBuilder(); |
| | | private final AiPromptMessageBuilder aiPromptMessageBuilder = new AiPromptMessageBuilder(); |
| | | |
| | | @Test |
| | | void shouldBuildPromptMessagesInExpectedOrderAndRenderLastUserPrompt() { |
| | | void shouldMergeAllSystemContextIntoSingleLeadingSystemMessage() { |
| | | AiChatMemoryDto memory = AiChatMemoryDto.builder() |
| | | .memorySummary("summary") |
| | | .memoryFacts("facts") |
| | | .memorySummary("这是摘要") |
| | | .memoryFacts("这是事实") |
| | | .build(); |
| | | AiPrompt prompt = new AiPrompt() |
| | | .setSystemPrompt("system") |
| | | .setUserPromptTemplate("用户问题: {{input}} | 仓库: {{warehouse}}"); |
| | | List<AiChatMessageDto> messages = List.of( |
| | | message("user", "old question"), |
| | | message("assistant", "old answer"), |
| | | message("user", "latest question") |
| | | .setSystemPrompt("你是助手") |
| | | .setUserPromptTemplate("请回答:{{input}}"); |
| | | |
| | | List<Message> messages = aiPromptMessageBuilder.buildPromptMessages( |
| | | memory, |
| | | List.of( |
| | | message("user", "第一问"), |
| | | message("assistant", "第一答"), |
| | | message("user", "第二问") |
| | | ), |
| | | prompt, |
| | | null |
| | | ); |
| | | |
| | | List<Message> built = builder.buildPromptMessages(memory, messages, prompt, Map.of("warehouse", "WH1")); |
| | | |
| | | assertEquals(6, built.size()); |
| | | assertInstanceOf(SystemMessage.class, built.get(0)); |
| | | assertEquals("system", built.get(0).getText()); |
| | | assertEquals("历史摘要:\nsummary", built.get(1).getText()); |
| | | assertEquals("关键事实:\nfacts", built.get(2).getText()); |
| | | assertInstanceOf(UserMessage.class, built.get(3)); |
| | | assertEquals("old question", built.get(3).getText()); |
| | | assertEquals("old answer", built.get(4).getText()); |
| | | assertInstanceOf(UserMessage.class, built.get(5)); |
| | | assertEquals("用户问题: latest question | 仓库: WH1", built.get(5).getText()); |
| | | assertThat(messages).hasSize(4); |
| | | assertThat(messages.get(0)).isInstanceOf(SystemMessage.class); |
| | | assertThat(messages.get(1)).isInstanceOf(UserMessage.class); |
| | | assertThat(messages.get(2)).isInstanceOf(AssistantMessage.class); |
| | | assertThat(messages.get(3)).isInstanceOf(UserMessage.class); |
| | | assertThat(((SystemMessage) messages.get(0)).getText()) |
| | | .contains("你是助手") |
| | | .contains("历史摘要:\n这是摘要") |
| | | .contains("关键事实:\n这是事实"); |
| | | assertThat(((UserMessage) messages.get(3)).getText()).isEqualTo("请回答:第二问"); |
| | | } |
| | | |
| | | @Test |
| | | void shouldMergePersistedAndMemoryMessages() { |
| | | List<AiChatMessageDto> merged = builder.mergeMessages( |
| | | List<AiChatMessageDto> merged = aiPromptMessageBuilder.mergeMessages( |
| | | List.of(message("user", "persisted")), |
| | | List.of(message("assistant", "memory")) |
| | | ); |
| | | |
| | | assertEquals(2, merged.size()); |
| | | assertEquals("persisted", merged.get(0).getContent()); |
| | | assertEquals("memory", merged.get(1).getContent()); |
| | | assertThat(merged).hasSize(2); |
| | | assertThat(merged.get(0).getContent()).isEqualTo("persisted"); |
| | | assertThat(merged.get(1).getContent()).isEqualTo("memory"); |
| | | } |
| | | |
| | | private AiChatMessageDto message(String role, String content) { |