zhou zhou
8 小时以前 82624affb0251b75b62b35567d3eb260c06efe78
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
package com.vincent.rsf.server.ai.service.impl.chat;
 
import com.vincent.rsf.framework.common.Cools;
import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.dto.AiChatMemoryDto;
import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
import com.vincent.rsf.server.ai.entity.AiPrompt;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
 
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
 
@Component
public class AiPromptMessageBuilder {
 
    public List<Message> buildPromptMessages(AiChatMemoryDto memory, List<AiChatMessageDto> sourceMessages, AiPrompt aiPrompt,
                                             Map<String, Object> metadata) {
        if (Cools.isEmpty(sourceMessages)) {
            throw new CoolException("对话消息不能为空");
        }
        List<Message> messages = new ArrayList<>();
        if (StringUtils.hasText(aiPrompt.getSystemPrompt())) {
            messages.add(new SystemMessage(aiPrompt.getSystemPrompt()));
        }
        if (memory != null && StringUtils.hasText(memory.getMemorySummary())) {
            messages.add(new SystemMessage("历史摘要:\n" + memory.getMemorySummary()));
        }
        if (memory != null && StringUtils.hasText(memory.getMemoryFacts())) {
            messages.add(new SystemMessage("关键事实:\n" + memory.getMemoryFacts()));
        }
        int lastUserIndex = -1;
        for (int i = 0; i < sourceMessages.size(); i++) {
            AiChatMessageDto item = sourceMessages.get(i);
            if (item != null && "user".equalsIgnoreCase(item.getRole())) {
                lastUserIndex = i;
            }
        }
        for (int i = 0; i < sourceMessages.size(); i++) {
            AiChatMessageDto item = sourceMessages.get(i);
            if (item == null || !StringUtils.hasText(item.getContent())) {
                continue;
            }
            String role = item.getRole() == null ? "user" : item.getRole().toLowerCase();
            if ("system".equals(role)) {
                continue;
            }
            String content = item.getContent();
            if ("user".equals(role) && i == lastUserIndex) {
                content = renderUserPrompt(aiPrompt.getUserPromptTemplate(), content, metadata);
            }
            if ("assistant".equals(role)) {
                messages.add(new AssistantMessage(content));
            } else {
                messages.add(new UserMessage(content));
            }
        }
        if (messages.stream().noneMatch(item -> item instanceof UserMessage)) {
            throw new CoolException("至少需要一条用户消息");
        }
        return messages;
    }
 
    public List<AiChatMessageDto> mergeMessages(List<AiChatMessageDto> persistedMessages, List<AiChatMessageDto> memoryMessages) {
        List<AiChatMessageDto> merged = new ArrayList<>();
        if (!Cools.isEmpty(persistedMessages)) {
            merged.addAll(persistedMessages);
        }
        if (!Cools.isEmpty(memoryMessages)) {
            merged.addAll(memoryMessages);
        }
        if (merged.isEmpty()) {
            throw new CoolException("对话消息不能为空");
        }
        return merged;
    }
 
    public String resolveTitleSeed(List<AiChatMessageDto> messages) {
        if (Cools.isEmpty(messages)) {
            throw new CoolException("对话消息不能为空");
        }
        for (int i = messages.size() - 1; i >= 0; i--) {
            AiChatMessageDto item = messages.get(i);
            if (item != null && "user".equalsIgnoreCase(item.getRole()) && StringUtils.hasText(item.getContent())) {
                return item.getContent();
            }
        }
        throw new CoolException("至少需要一条用户消息");
    }
 
    private String renderUserPrompt(String userPromptTemplate, String content, Map<String, Object> metadata) {
        if (!StringUtils.hasText(userPromptTemplate)) {
            return content;
        }
        String rendered = userPromptTemplate
                .replace("{{input}}", content)
                .replace("{input}", content);
        if (metadata != null) {
            for (Map.Entry<String, Object> entry : metadata.entrySet()) {
                String value = entry.getValue() == null ? "" : String.valueOf(entry.getValue());
                rendered = rendered.replace("{{" + entry.getKey() + "}}", value);
                rendered = rendered.replace("{" + entry.getKey() + "}", value);
            }
        }
        if (Objects.equals(rendered, userPromptTemplate)) {
            return userPromptTemplate + "\n\n" + content;
        }
        return rendered;
    }
}