package com.vincent.rsf.server.ai.service.impl.chat; import com.vincent.rsf.framework.common.Cools; import com.vincent.rsf.framework.exception.CoolException; import com.vincent.rsf.server.ai.dto.AiChatMemoryDto; import com.vincent.rsf.server.ai.dto.AiChatMessageDto; import com.vincent.rsf.server.ai.entity.AiPrompt; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.stereotype.Component; import org.springframework.util.StringUtils; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Objects; @Component public class AiPromptMessageBuilder { public List buildPromptMessages(AiChatMemoryDto memory, List sourceMessages, AiPrompt aiPrompt, Map metadata) { if (Cools.isEmpty(sourceMessages)) { throw new CoolException("对话消息不能为空"); } List messages = new ArrayList<>(); if (StringUtils.hasText(aiPrompt.getSystemPrompt())) { messages.add(new SystemMessage(aiPrompt.getSystemPrompt())); } if (memory != null && StringUtils.hasText(memory.getMemorySummary())) { messages.add(new SystemMessage("历史摘要:\n" + memory.getMemorySummary())); } if (memory != null && StringUtils.hasText(memory.getMemoryFacts())) { messages.add(new SystemMessage("关键事实:\n" + memory.getMemoryFacts())); } int lastUserIndex = -1; for (int i = 0; i < sourceMessages.size(); i++) { AiChatMessageDto item = sourceMessages.get(i); if (item != null && "user".equalsIgnoreCase(item.getRole())) { lastUserIndex = i; } } for (int i = 0; i < sourceMessages.size(); i++) { AiChatMessageDto item = sourceMessages.get(i); if (item == null || !StringUtils.hasText(item.getContent())) { continue; } String role = item.getRole() == null ? "user" : item.getRole().toLowerCase(); if ("system".equals(role)) { continue; } String content = item.getContent(); if ("user".equals(role) && i == lastUserIndex) { content = renderUserPrompt(aiPrompt.getUserPromptTemplate(), content, metadata); } if ("assistant".equals(role)) { messages.add(new AssistantMessage(content)); } else { messages.add(new UserMessage(content)); } } if (messages.stream().noneMatch(item -> item instanceof UserMessage)) { throw new CoolException("至少需要一条用户消息"); } return messages; } public List mergeMessages(List persistedMessages, List memoryMessages) { List merged = new ArrayList<>(); if (!Cools.isEmpty(persistedMessages)) { merged.addAll(persistedMessages); } if (!Cools.isEmpty(memoryMessages)) { merged.addAll(memoryMessages); } if (merged.isEmpty()) { throw new CoolException("对话消息不能为空"); } return merged; } public String resolveTitleSeed(List messages) { if (Cools.isEmpty(messages)) { throw new CoolException("对话消息不能为空"); } for (int i = messages.size() - 1; i >= 0; i--) { AiChatMessageDto item = messages.get(i); if (item != null && "user".equalsIgnoreCase(item.getRole()) && StringUtils.hasText(item.getContent())) { return item.getContent(); } } throw new CoolException("至少需要一条用户消息"); } private String renderUserPrompt(String userPromptTemplate, String content, Map metadata) { if (!StringUtils.hasText(userPromptTemplate)) { return content; } String rendered = userPromptTemplate .replace("{{input}}", content) .replace("{input}", content); if (metadata != null) { for (Map.Entry entry : metadata.entrySet()) { String value = entry.getValue() == null ? "" : String.valueOf(entry.getValue()); rendered = rendered.replace("{{" + entry.getKey() + "}}", value); rendered = rendered.replace("{" + entry.getKey() + "}", value); } } if (Objects.equals(rendered, userPromptTemplate)) { return userPromptTemplate + "\n\n" + content; } return rendered; } }