| | |
| | | import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; |
| | | import com.vincent.rsf.framework.common.Cools; |
| | | import com.vincent.rsf.framework.exception.CoolException; |
| | | import com.vincent.rsf.server.ai.config.AiDefaults; |
| | | import com.vincent.rsf.server.ai.dto.AiChatMemoryDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatMessageDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest; |
| | |
| | | import java.util.ArrayList; |
| | | import java.util.Date; |
| | | import java.util.List; |
| | | import java.util.Locale; |
| | | |
| | | @Service |
| | | @RequiredArgsConstructor |
| | |
| | | if (session == null) { |
| | | return AiChatMemoryDto.builder() |
| | | .sessionId(null) |
| | | .memorySummary(null) |
| | | .memoryFacts(null) |
| | | .recentMessageCount(0) |
| | | .persistedMessages(List.of()) |
| | | .shortMemoryMessages(List.of()) |
| | | .build(); |
| | | } |
| | | List<AiChatMessageDto> persistedMessages = listMessages(session.getId()); |
| | | List<AiChatMessageDto> shortMemoryMessages = tailMessagesByRounds(persistedMessages, AiDefaults.MEMORY_RECENT_ROUNDS); |
| | | return AiChatMemoryDto.builder() |
| | | .sessionId(session.getId()) |
| | | .persistedMessages(listMessages(session.getId())) |
| | | .memorySummary(session.getMemorySummary()) |
| | | .memoryFacts(session.getMemoryFacts()) |
| | | .recentMessageCount(shortMemoryMessages.size()) |
| | | .persistedMessages(persistedMessages) |
| | | .shortMemoryMessages(shortMemoryMessages) |
| | | .build(); |
| | | } |
| | | |
| | |
| | | .setUpdateBy(userId) |
| | | .setUpdateTime(now); |
| | | aiChatSessionMapper.updateById(update); |
| | | refreshMemoryProfile(session.getId(), userId); |
| | | } |
| | | |
| | | @Override |
| | |
| | | return buildSessionDto(requireOwnedSession(sessionId, userId, tenantId)); |
| | | } |
| | | |
| | | @Override |
| | | public void clearSessionMemory(Long userId, Long tenantId, Long sessionId) { |
| | | ensureIdentity(userId, tenantId); |
| | | AiChatSession session = requireOwnedSession(sessionId, userId, tenantId); |
| | | List<AiChatMessage> messages = aiChatMessageMapper.selectList(new LambdaQueryWrapper<AiChatMessage>() |
| | | .eq(AiChatMessage::getSessionId, sessionId) |
| | | .eq(AiChatMessage::getDeleted, 0)); |
| | | for (AiChatMessage message : messages) { |
| | | aiChatMessageMapper.updateById(new AiChatMessage() |
| | | .setId(message.getId()) |
| | | .setDeleted(1)); |
| | | } |
| | | aiChatSessionMapper.updateById(new AiChatSession() |
| | | .setId(sessionId) |
| | | .setMemorySummary(null) |
| | | .setMemoryFacts(null) |
| | | .setUpdateBy(userId) |
| | | .setUpdateTime(new Date()) |
| | | .setLastMessageTime(session.getCreateTime())); |
| | | } |
| | | |
| | | @Override |
| | | public void retainLatestRound(Long userId, Long tenantId, Long sessionId) { |
| | | ensureIdentity(userId, tenantId); |
| | | requireOwnedSession(sessionId, userId, tenantId); |
| | | List<AiChatMessage> records = listMessageRecords(sessionId); |
| | | if (records.isEmpty()) { |
| | | return; |
| | | } |
| | | List<AiChatMessage> retained = tailMessageRecordsByRounds(records, 1); |
| | | for (AiChatMessage message : records) { |
| | | boolean shouldKeep = retained.stream().anyMatch(item -> item.getId().equals(message.getId())); |
| | | if (!shouldKeep) { |
| | | aiChatMessageMapper.updateById(new AiChatMessage() |
| | | .setId(message.getId()) |
| | | .setDeleted(1)); |
| | | } |
| | | } |
| | | refreshMemoryProfile(sessionId, userId); |
| | | } |
| | | |
| | | private AiChatSession findLatestSession(Long userId, Long tenantId, String promptCode) { |
| | | return aiChatSessionMapper.selectOne(new LambdaQueryWrapper<AiChatSession>() |
| | | .eq(AiChatSession::getUserId, userId) |
| | |
| | | } |
| | | |
| | | private List<AiChatMessageDto> listMessages(Long sessionId) { |
| | | List<AiChatMessage> records = aiChatMessageMapper.selectList(new LambdaQueryWrapper<AiChatMessage>() |
| | | .eq(AiChatMessage::getSessionId, sessionId) |
| | | .eq(AiChatMessage::getDeleted, 0) |
| | | .orderByAsc(AiChatMessage::getSeqNo) |
| | | .orderByAsc(AiChatMessage::getId)); |
| | | List<AiChatMessage> records = listMessageRecords(sessionId); |
| | | if (Cools.isEmpty(records)) { |
| | | return List.of(); |
| | | } |
| | |
| | | return normalized; |
| | | } |
| | | |
| | | private List<AiChatMessage> listMessageRecords(Long sessionId) { |
| | | return aiChatMessageMapper.selectList(new LambdaQueryWrapper<AiChatMessage>() |
| | | .eq(AiChatMessage::getSessionId, sessionId) |
| | | .eq(AiChatMessage::getDeleted, 0) |
| | | .orderByAsc(AiChatMessage::getSeqNo) |
| | | .orderByAsc(AiChatMessage::getId)); |
| | | } |
| | | |
| | | private int findNextSeqNo(Long sessionId) { |
| | | AiChatMessage lastMessage = aiChatMessageMapper.selectOne(new LambdaQueryWrapper<AiChatMessage>() |
| | | .eq(AiChatMessage::getSessionId, sessionId) |
| | |
| | | .setSeqNo(seqNo) |
| | | .setRole(role) |
| | | .setContent(content) |
| | | .setContentLength(content == null ? 0 : content.length()) |
| | | .setUserId(userId) |
| | | .setTenantId(tenantId) |
| | | .setDeleted(0) |
| | |
| | | return normalized.length() > 80 ? normalized.substring(0, 80) : normalized; |
| | | } |
| | | |
| | | private void refreshMemoryProfile(Long sessionId, Long userId) { |
| | | List<AiChatMessageDto> messages = listMessages(sessionId); |
| | | List<AiChatMessageDto> shortMemoryMessages = tailMessagesByRounds(messages, AiDefaults.MEMORY_RECENT_ROUNDS); |
| | | List<AiChatMessageDto> historyMessages = messages.size() > shortMemoryMessages.size() |
| | | ? messages.subList(0, messages.size() - shortMemoryMessages.size()) |
| | | : List.of(); |
| | | String memorySummary = historyMessages.size() >= AiDefaults.MEMORY_SUMMARY_TRIGGER_MESSAGES |
| | | ? buildMemorySummary(historyMessages) |
| | | : null; |
| | | String memoryFacts = buildMemoryFacts(messages); |
| | | AiChatMessage lastMessage = aiChatMessageMapper.selectOne(new LambdaQueryWrapper<AiChatMessage>() |
| | | .eq(AiChatMessage::getSessionId, sessionId) |
| | | .eq(AiChatMessage::getDeleted, 0) |
| | | .orderByDesc(AiChatMessage::getSeqNo) |
| | | .orderByDesc(AiChatMessage::getId) |
| | | .last("limit 1")); |
| | | aiChatSessionMapper.updateById(new AiChatSession() |
| | | .setId(sessionId) |
| | | .setMemorySummary(memorySummary) |
| | | .setMemoryFacts(memoryFacts) |
| | | .setLastMessageTime(lastMessage == null ? null : lastMessage.getCreateTime()) |
| | | .setUpdateBy(userId) |
| | | .setUpdateTime(new Date())); |
| | | } |
| | | |
| | | private List<AiChatMessageDto> tailMessagesByRounds(List<AiChatMessageDto> source, int rounds) { |
| | | if (Cools.isEmpty(source) || rounds <= 0) { |
| | | return List.of(); |
| | | } |
| | | int userCount = 0; |
| | | int startIndex = source.size(); |
| | | for (int i = source.size() - 1; i >= 0; i--) { |
| | | AiChatMessageDto item = source.get(i); |
| | | startIndex = i; |
| | | if (item != null && "user".equalsIgnoreCase(item.getRole())) { |
| | | userCount++; |
| | | if (userCount >= rounds) { |
| | | break; |
| | | } |
| | | } |
| | | } |
| | | return new ArrayList<>(source.subList(Math.max(0, startIndex), source.size())); |
| | | } |
| | | |
| | | private List<AiChatMessage> tailMessageRecordsByRounds(List<AiChatMessage> source, int rounds) { |
| | | if (Cools.isEmpty(source) || rounds <= 0) { |
| | | return List.of(); |
| | | } |
| | | int userCount = 0; |
| | | int startIndex = source.size(); |
| | | for (int i = source.size() - 1; i >= 0; i--) { |
| | | AiChatMessage item = source.get(i); |
| | | startIndex = i; |
| | | if (item != null && "user".equalsIgnoreCase(item.getRole())) { |
| | | userCount++; |
| | | if (userCount >= rounds) { |
| | | break; |
| | | } |
| | | } |
| | | } |
| | | return new ArrayList<>(source.subList(Math.max(0, startIndex), source.size())); |
| | | } |
| | | |
| | | private String buildMemorySummary(List<AiChatMessageDto> historyMessages) { |
| | | StringBuilder builder = new StringBuilder("较早对话摘要:\n"); |
| | | for (AiChatMessageDto item : historyMessages) { |
| | | if (item == null || !StringUtils.hasText(item.getContent())) { |
| | | continue; |
| | | } |
| | | String prefix = "assistant".equalsIgnoreCase(item.getRole()) ? "- AI: " : "- 用户: "; |
| | | String content = compactText(item.getContent(), 120); |
| | | if (!StringUtils.hasText(content)) { |
| | | continue; |
| | | } |
| | | builder.append(prefix).append(content).append("\n"); |
| | | if (builder.length() >= AiDefaults.MEMORY_SUMMARY_MAX_LENGTH) { |
| | | break; |
| | | } |
| | | } |
| | | return compactText(builder.toString(), AiDefaults.MEMORY_SUMMARY_MAX_LENGTH); |
| | | } |
| | | |
| | | private String buildMemoryFacts(List<AiChatMessageDto> messages) { |
| | | if (Cools.isEmpty(messages)) { |
| | | return null; |
| | | } |
| | | StringBuilder builder = new StringBuilder("关键事实:\n"); |
| | | int userFacts = 0; |
| | | for (int i = messages.size() - 1; i >= 0 && userFacts < 4; i--) { |
| | | AiChatMessageDto item = messages.get(i); |
| | | if (item == null || !"user".equalsIgnoreCase(item.getRole()) || !StringUtils.hasText(item.getContent())) { |
| | | continue; |
| | | } |
| | | builder.append("- 用户关注: ").append(compactText(item.getContent(), 100)).append("\n"); |
| | | userFacts++; |
| | | } |
| | | return userFacts == 0 ? null : compactText(builder.toString(), AiDefaults.MEMORY_FACTS_MAX_LENGTH); |
| | | } |
| | | |
| | | private String compactText(String content, int maxLength) { |
| | | if (!StringUtils.hasText(content)) { |
| | | return null; |
| | | } |
| | | String normalized = content.trim() |
| | | .replace("\r", " ") |
| | | .replace("\n", " ") |
| | | .replaceAll("\\s+", " "); |
| | | return normalized.length() > maxLength ? normalized.substring(0, maxLength) : normalized; |
| | | } |
| | | |
| | | private void ensureIdentity(Long userId, Long tenantId) { |
| | | if (userId == null) { |
| | | throw new CoolException("当前登录用户不存在"); |