package com.vincent.rsf.server.ai.service.impl.chat;
|
|
import com.vincent.rsf.framework.common.Cools;
import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.dto.AiChatMemoryDto;
import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
import com.vincent.rsf.server.ai.entity.AiPrompt;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
|
|
@Component
|
public class AiPromptMessageBuilder {
|
|
public List<Message> buildPromptMessages(AiChatMemoryDto memory, List<AiChatMessageDto> sourceMessages, AiPrompt aiPrompt,
|
Map<String, Object> metadata) {
|
if (Cools.isEmpty(sourceMessages)) {
|
throw new CoolException("对话消息不能为空");
|
}
|
List<Message> messages = new ArrayList<>();
|
if (StringUtils.hasText(aiPrompt.getSystemPrompt())) {
|
messages.add(new SystemMessage(aiPrompt.getSystemPrompt()));
|
}
|
if (memory != null && StringUtils.hasText(memory.getMemorySummary())) {
|
messages.add(new SystemMessage("历史摘要:\n" + memory.getMemorySummary()));
|
}
|
if (memory != null && StringUtils.hasText(memory.getMemoryFacts())) {
|
messages.add(new SystemMessage("关键事实:\n" + memory.getMemoryFacts()));
|
}
|
int lastUserIndex = -1;
|
for (int i = 0; i < sourceMessages.size(); i++) {
|
AiChatMessageDto item = sourceMessages.get(i);
|
if (item != null && "user".equalsIgnoreCase(item.getRole())) {
|
lastUserIndex = i;
|
}
|
}
|
for (int i = 0; i < sourceMessages.size(); i++) {
|
AiChatMessageDto item = sourceMessages.get(i);
|
if (item == null || !StringUtils.hasText(item.getContent())) {
|
continue;
|
}
|
String role = item.getRole() == null ? "user" : item.getRole().toLowerCase();
|
if ("system".equals(role)) {
|
continue;
|
}
|
String content = item.getContent();
|
if ("user".equals(role) && i == lastUserIndex) {
|
content = renderUserPrompt(aiPrompt.getUserPromptTemplate(), content, metadata);
|
}
|
if ("assistant".equals(role)) {
|
messages.add(new AssistantMessage(content));
|
} else {
|
messages.add(new UserMessage(content));
|
}
|
}
|
if (messages.stream().noneMatch(item -> item instanceof UserMessage)) {
|
throw new CoolException("至少需要一条用户消息");
|
}
|
return messages;
|
}
|
|
public List<AiChatMessageDto> mergeMessages(List<AiChatMessageDto> persistedMessages, List<AiChatMessageDto> memoryMessages) {
|
List<AiChatMessageDto> merged = new ArrayList<>();
|
if (!Cools.isEmpty(persistedMessages)) {
|
merged.addAll(persistedMessages);
|
}
|
if (!Cools.isEmpty(memoryMessages)) {
|
merged.addAll(memoryMessages);
|
}
|
if (merged.isEmpty()) {
|
throw new CoolException("对话消息不能为空");
|
}
|
return merged;
|
}
|
|
public String resolveTitleSeed(List<AiChatMessageDto> messages) {
|
if (Cools.isEmpty(messages)) {
|
throw new CoolException("对话消息不能为空");
|
}
|
for (int i = messages.size() - 1; i >= 0; i--) {
|
AiChatMessageDto item = messages.get(i);
|
if (item != null && "user".equalsIgnoreCase(item.getRole()) && StringUtils.hasText(item.getContent())) {
|
return item.getContent();
|
}
|
}
|
throw new CoolException("至少需要一条用户消息");
|
}
|
|
private String renderUserPrompt(String userPromptTemplate, String content, Map<String, Object> metadata) {
|
if (!StringUtils.hasText(userPromptTemplate)) {
|
return content;
|
}
|
String rendered = userPromptTemplate
|
.replace("{{input}}", content)
|
.replace("{input}", content);
|
if (metadata != null) {
|
for (Map.Entry<String, Object> entry : metadata.entrySet()) {
|
String value = entry.getValue() == null ? "" : String.valueOf(entry.getValue());
|
rendered = rendered.replace("{{" + entry.getKey() + "}}", value);
|
rendered = rendered.replace("{" + entry.getKey() + "}", value);
|
}
|
}
|
if (Objects.equals(rendered, userPromptTemplate)) {
|
return userPromptTemplate + "\n\n" + content;
|
}
|
return rendered;
|
}
|
}
|