zhou zhou
18 小时以前 82624affb0251b75b62b35567d3eb260c06efe78
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
package com.vincent.rsf.server.ai.service.impl.chat;
 
import com.vincent.rsf.server.ai.dto.AiChatMemoryDto;
import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
import com.vincent.rsf.server.ai.entity.AiPrompt;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
 
import java.util.List;
import java.util.Map;
 
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
 
class AiPromptMessageBuilderTest {
 
    private final AiPromptMessageBuilder builder = new AiPromptMessageBuilder();
 
    @Test
    void shouldBuildPromptMessagesInExpectedOrderAndRenderLastUserPrompt() {
        AiChatMemoryDto memory = AiChatMemoryDto.builder()
                .memorySummary("summary")
                .memoryFacts("facts")
                .build();
        AiPrompt prompt = new AiPrompt()
                .setSystemPrompt("system")
                .setUserPromptTemplate("用户问题: {{input}} | 仓库: {{warehouse}}");
        List<AiChatMessageDto> messages = List.of(
                message("user", "old question"),
                message("assistant", "old answer"),
                message("user", "latest question")
        );
 
        List<Message> built = builder.buildPromptMessages(memory, messages, prompt, Map.of("warehouse", "WH1"));
 
        assertEquals(6, built.size());
        assertInstanceOf(SystemMessage.class, built.get(0));
        assertEquals("system", built.get(0).getText());
        assertEquals("历史摘要:\nsummary", built.get(1).getText());
        assertEquals("关键事实:\nfacts", built.get(2).getText());
        assertInstanceOf(UserMessage.class, built.get(3));
        assertEquals("old question", built.get(3).getText());
        assertEquals("old answer", built.get(4).getText());
        assertInstanceOf(UserMessage.class, built.get(5));
        assertEquals("用户问题: latest question | 仓库: WH1", built.get(5).getText());
    }
 
    @Test
    void shouldMergePersistedAndMemoryMessages() {
        List<AiChatMessageDto> merged = builder.mergeMessages(
                List.of(message("user", "persisted")),
                List.of(message("assistant", "memory"))
        );
 
        assertEquals(2, merged.size());
        assertEquals("persisted", merged.get(0).getContent());
        assertEquals("memory", merged.get(1).getContent());
    }
 
    private AiChatMessageDto message(String role, String content) {
        AiChatMessageDto dto = new AiChatMessageDto();
        dto.setRole(role);
        dto.setContent(content);
        return dto;
    }
}