zhou zhou
11 小时以前 3d81df739dc45599c257d8cdefe0996f66ccdeae
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatMemoryServiceImpl.java
@@ -3,8 +3,11 @@
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import com.vincent.rsf.framework.common.Cools;
import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.dto.AiChatMemoryDto;
import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest;
import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest;
import com.vincent.rsf.server.ai.dto.AiChatSessionDto;
import com.vincent.rsf.server.ai.entity.AiChatMessage;
import com.vincent.rsf.server.ai.entity.AiChatSession;
import java.util.function.Function;
import java.util.stream.Collectors;
@@ -37,17 +40,27 @@
        if (session == null) {
            return AiChatMemoryDto.builder()
                    .sessionId(null)
                    .memorySummary(null)
                    .memoryFacts(null)
                    .recentMessageCount(0)
                    .persistedMessages(List.of())
                    .shortMemoryMessages(List.of())
                    .build();
        }
        List<AiChatMessageDto> persistedMessages = listMessages(session.getId());
        List<AiChatMessageDto> shortMemoryMessages = tailMessagesByRounds(persistedMessages, AiDefaults.MEMORY_RECENT_ROUNDS);
        return AiChatMemoryDto.builder()
                .sessionId(session.getId())
                .persistedMessages(listMessages(session.getId()))
                .memorySummary(session.getMemorySummary())
                .memoryFacts(session.getMemoryFacts())
                .recentMessageCount(shortMemoryMessages.size())
                .persistedMessages(persistedMessages)
                .shortMemoryMessages(shortMemoryMessages)
                .build();
    }
    @Override
    public List<AiChatSessionDto> listSessions(Long userId, Long tenantId, String promptCode) {
    public List<AiChatSessionDto> listSessions(Long userId, Long tenantId, String promptCode, String keyword) {
        ensureIdentity(userId, tenantId);
        String resolvedPromptCode = requirePromptCode(promptCode);
        List<AiChatSession> sessions = aiChatSessionMapper.selectList(new LambdaQueryWrapper<AiChatSession>()
@@ -56,6 +69,8 @@
                .eq(AiChatSession::getPromptCode, resolvedPromptCode)
                .eq(AiChatSession::getDeleted, 0)
                .eq(AiChatSession::getStatus, StatusType.ENABLE.val)
                .like(StringUtils.hasText(keyword), AiChatSession::getTitle, keyword == null ? null : keyword.trim())
                .orderByDesc(AiChatSession::getPinned)
                .orderByDesc(AiChatSession::getLastMessageTime)
                .orderByDesc(AiChatSession::getId));
        if (Cools.isEmpty(sessions)) {
@@ -63,12 +78,7 @@
        }
        List<AiChatSessionDto> result = new ArrayList<>();
        for (AiChatSession session : sessions) {
            result.add(AiChatSessionDto.builder()
                    .sessionId(session.getId())
                    .title(session.getTitle())
                    .promptCode(session.getPromptCode())
                    .lastMessageTime(session.getLastMessageTime())
                    .build());
            result.add(buildSessionDto(session));
        }
        return result;
    }
@@ -87,6 +97,7 @@
                .setUserId(userId)
                .setTenantId(tenantId)
                .setLastMessageTime(now)
                .setPinned(0)
                .setStatus(StatusType.ENABLE.val)
                .setDeleted(0)
                .setCreateBy(userId)
@@ -122,6 +133,7 @@
                .setUpdateBy(userId)
                .setUpdateTime(now);
        aiChatSessionMapper.updateById(update);
        refreshMemoryProfile(session.getId(), userId);
    }
    @Override
@@ -157,6 +169,81 @@
        }
    }
    @Override
    public AiChatSessionDto renameSession(Long userId, Long tenantId, Long sessionId, AiChatSessionRenameRequest request) {
        ensureIdentity(userId, tenantId);
        if (request == null || !StringUtils.hasText(request.getTitle())) {
            throw new CoolException("会话标题不能为空");
        }
        AiChatSession session = requireOwnedSession(sessionId, userId, tenantId);
        Date now = new Date();
        AiChatSession update = new AiChatSession()
                .setId(sessionId)
                .setTitle(buildSessionTitle(request.getTitle()))
                .setUpdateBy(userId)
                .setUpdateTime(now);
        aiChatSessionMapper.updateById(update);
        return buildSessionDto(requireOwnedSession(sessionId, userId, tenantId));
    }
    @Override
    public AiChatSessionDto pinSession(Long userId, Long tenantId, Long sessionId, AiChatSessionPinRequest request) {
        ensureIdentity(userId, tenantId);
        if (request == null || request.getPinned() == null) {
            throw new CoolException("置顶状态不能为空");
        }
        AiChatSession session = requireOwnedSession(sessionId, userId, tenantId);
        Date now = new Date();
        AiChatSession update = new AiChatSession()
                .setId(sessionId)
                .setPinned(Boolean.TRUE.equals(request.getPinned()) ? 1 : 0)
                .setUpdateBy(userId)
                .setUpdateTime(now);
        aiChatSessionMapper.updateById(update);
        return buildSessionDto(requireOwnedSession(sessionId, userId, tenantId));
    }
    @Override
    public void clearSessionMemory(Long userId, Long tenantId, Long sessionId) {
        ensureIdentity(userId, tenantId);
        AiChatSession session = requireOwnedSession(sessionId, userId, tenantId);
        List<AiChatMessage> messages = aiChatMessageMapper.selectList(new LambdaQueryWrapper<AiChatMessage>()
                .eq(AiChatMessage::getSessionId, sessionId)
                .eq(AiChatMessage::getDeleted, 0));
        for (AiChatMessage message : messages) {
            aiChatMessageMapper.updateById(new AiChatMessage()
                    .setId(message.getId())
                    .setDeleted(1));
        }
        aiChatSessionMapper.updateById(new AiChatSession()
                .setId(sessionId)
                .setMemorySummary(null)
                .setMemoryFacts(null)
                .setUpdateBy(userId)
                .setUpdateTime(new Date())
                .setLastMessageTime(session.getCreateTime()));
    }
    @Override
    public void retainLatestRound(Long userId, Long tenantId, Long sessionId) {
        ensureIdentity(userId, tenantId);
        requireOwnedSession(sessionId, userId, tenantId);
        List<AiChatMessage> records = listMessageRecords(sessionId);
        if (records.isEmpty()) {
            return;
        }
        List<AiChatMessage> retained = tailMessageRecordsByRounds(records, 1);
        for (AiChatMessage message : records) {
            boolean shouldKeep = retained.stream().anyMatch(item -> item.getId().equals(message.getId()));
            if (!shouldKeep) {
                aiChatMessageMapper.updateById(new AiChatMessage()
                        .setId(message.getId())
                        .setDeleted(1));
            }
        }
        refreshMemoryProfile(sessionId, userId);
    }
    private AiChatSession findLatestSession(Long userId, Long tenantId, String promptCode) {
        return aiChatSessionMapper.selectOne(new LambdaQueryWrapper<AiChatSession>()
                .eq(AiChatSession::getUserId, userId)
@@ -184,12 +271,25 @@
        return session;
    }
    private AiChatSession requireOwnedSession(Long sessionId, Long userId, Long tenantId) {
        if (sessionId == null) {
            throw new CoolException("AI 会话 ID 不能为空");
        }
        AiChatSession session = aiChatSessionMapper.selectOne(new LambdaQueryWrapper<AiChatSession>()
                .eq(AiChatSession::getId, sessionId)
                .eq(AiChatSession::getUserId, userId)
                .eq(AiChatSession::getTenantId, tenantId)
                .eq(AiChatSession::getDeleted, 0)
                .eq(AiChatSession::getStatus, StatusType.ENABLE.val)
                .last("limit 1"));
        if (session == null) {
            throw new CoolException("AI 会话不存在或无权访问");
        }
        return session;
    }
    private List<AiChatMessageDto> listMessages(Long sessionId) {
        List<AiChatMessage> records = aiChatMessageMapper.selectList(new LambdaQueryWrapper<AiChatMessage>()
                .eq(AiChatMessage::getSessionId, sessionId)
                .eq(AiChatMessage::getDeleted, 0)
                .orderByAsc(AiChatMessage::getSeqNo)
                .orderByAsc(AiChatMessage::getId));
        List<AiChatMessage> records = listMessageRecords(sessionId);
        if (Cools.isEmpty(records)) {
            return List.of();
        }
@@ -227,6 +327,14 @@
        return normalized;
    }
    private List<AiChatMessage> listMessageRecords(Long sessionId) {
        return aiChatMessageMapper.selectList(new LambdaQueryWrapper<AiChatMessage>()
                .eq(AiChatMessage::getSessionId, sessionId)
                .eq(AiChatMessage::getDeleted, 0)
                .orderByAsc(AiChatMessage::getSeqNo)
                .orderByAsc(AiChatMessage::getId));
    }
    private int findNextSeqNo(Long sessionId) {
        AiChatMessage lastMessage = aiChatMessageMapper.selectOne(new LambdaQueryWrapper<AiChatMessage>()
                .eq(AiChatMessage::getSessionId, sessionId)
@@ -243,6 +351,7 @@
                .setSeqNo(seqNo)
                .setRole(role)
                .setContent(content)
                .setContentLength(content == null ? 0 : content.length())
                .setUserId(userId)
                .setTenantId(tenantId)
                .setDeleted(0)
@@ -266,8 +375,167 @@
        if (!StringUtils.hasText(titleSeed)) {
            throw new CoolException("AI 会话标题不能为空");
        }
        String title = titleSeed.trim().replace("\r", " ").replace("\n", " ");
        return title.length() > 60 ? title.substring(0, 60) : title;
        String title = titleSeed.trim()
                .replace("\r", " ")
                .replace("\n", " ")
                .replaceAll("\\s+", " ");
        int punctuationIndex = findSummaryBreakIndex(title);
        if (punctuationIndex > 0) {
            title = title.substring(0, punctuationIndex).trim();
        }
        return title.length() > 48 ? title.substring(0, 48) : title;
    }
    private int findSummaryBreakIndex(String title) {
        String[] separators = {"。", "!", "?", ".", "!", "?"};
        int result = -1;
        for (String separator : separators) {
            int index = title.indexOf(separator);
            if (index > 0 && (result < 0 || index < result)) {
                result = index;
            }
        }
        return result;
    }
    private AiChatSessionDto buildSessionDto(AiChatSession session) {
        AiChatMessage lastMessage = aiChatMessageMapper.selectOne(new LambdaQueryWrapper<AiChatMessage>()
                .eq(AiChatMessage::getSessionId, session.getId())
                .eq(AiChatMessage::getDeleted, 0)
                .orderByDesc(AiChatMessage::getSeqNo)
                .orderByDesc(AiChatMessage::getId)
                .last("limit 1"));
        return AiChatSessionDto.builder()
                .sessionId(session.getId())
                .title(session.getTitle())
                .promptCode(session.getPromptCode())
                .pinned(session.getPinned() != null && session.getPinned() == 1)
                .lastMessagePreview(buildLastMessagePreview(lastMessage))
                .lastMessageTime(session.getLastMessageTime())
                .build();
    }
    private String buildLastMessagePreview(AiChatMessage message) {
        if (message == null || !StringUtils.hasText(message.getContent())) {
            return null;
        }
        String preview = message.getContent().trim()
                .replace("\r", " ")
                .replace("\n", " ")
                .replaceAll("\\s+", " ");
        String prefix = "assistant".equalsIgnoreCase(message.getRole()) ? "AI: " : "你: ";
        String normalized = prefix + preview;
        return normalized.length() > 80 ? normalized.substring(0, 80) : normalized;
    }
    private void refreshMemoryProfile(Long sessionId, Long userId) {
        List<AiChatMessageDto> messages = listMessages(sessionId);
        List<AiChatMessageDto> shortMemoryMessages = tailMessagesByRounds(messages, AiDefaults.MEMORY_RECENT_ROUNDS);
        List<AiChatMessageDto> historyMessages = messages.size() > shortMemoryMessages.size()
                ? messages.subList(0, messages.size() - shortMemoryMessages.size())
                : List.of();
        String memorySummary = historyMessages.size() >= AiDefaults.MEMORY_SUMMARY_TRIGGER_MESSAGES
                ? buildMemorySummary(historyMessages)
                : null;
        String memoryFacts = buildMemoryFacts(messages);
        AiChatMessage lastMessage = aiChatMessageMapper.selectOne(new LambdaQueryWrapper<AiChatMessage>()
                .eq(AiChatMessage::getSessionId, sessionId)
                .eq(AiChatMessage::getDeleted, 0)
                .orderByDesc(AiChatMessage::getSeqNo)
                .orderByDesc(AiChatMessage::getId)
                .last("limit 1"));
        aiChatSessionMapper.updateById(new AiChatSession()
                .setId(sessionId)
                .setMemorySummary(memorySummary)
                .setMemoryFacts(memoryFacts)
                .setLastMessageTime(lastMessage == null ? null : lastMessage.getCreateTime())
                .setUpdateBy(userId)
                .setUpdateTime(new Date()));
    }
    private List<AiChatMessageDto> tailMessagesByRounds(List<AiChatMessageDto> source, int rounds) {
        if (Cools.isEmpty(source) || rounds <= 0) {
            return List.of();
        }
        int userCount = 0;
        int startIndex = source.size();
        for (int i = source.size() - 1; i >= 0; i--) {
            AiChatMessageDto item = source.get(i);
            startIndex = i;
            if (item != null && "user".equalsIgnoreCase(item.getRole())) {
                userCount++;
                if (userCount >= rounds) {
                    break;
                }
            }
        }
        return new ArrayList<>(source.subList(Math.max(0, startIndex), source.size()));
    }
    private List<AiChatMessage> tailMessageRecordsByRounds(List<AiChatMessage> source, int rounds) {
        if (Cools.isEmpty(source) || rounds <= 0) {
            return List.of();
        }
        int userCount = 0;
        int startIndex = source.size();
        for (int i = source.size() - 1; i >= 0; i--) {
            AiChatMessage item = source.get(i);
            startIndex = i;
            if (item != null && "user".equalsIgnoreCase(item.getRole())) {
                userCount++;
                if (userCount >= rounds) {
                    break;
                }
            }
        }
        return new ArrayList<>(source.subList(Math.max(0, startIndex), source.size()));
    }
    private String buildMemorySummary(List<AiChatMessageDto> historyMessages) {
        StringBuilder builder = new StringBuilder("较早对话摘要:\n");
        for (AiChatMessageDto item : historyMessages) {
            if (item == null || !StringUtils.hasText(item.getContent())) {
                continue;
            }
            String prefix = "assistant".equalsIgnoreCase(item.getRole()) ? "- AI: " : "- 用户: ";
            String content = compactText(item.getContent(), 120);
            if (!StringUtils.hasText(content)) {
                continue;
            }
            builder.append(prefix).append(content).append("\n");
            if (builder.length() >= AiDefaults.MEMORY_SUMMARY_MAX_LENGTH) {
                break;
            }
        }
        return compactText(builder.toString(), AiDefaults.MEMORY_SUMMARY_MAX_LENGTH);
    }
    private String buildMemoryFacts(List<AiChatMessageDto> messages) {
        if (Cools.isEmpty(messages)) {
            return null;
        }
        StringBuilder builder = new StringBuilder("关键事实:\n");
        int userFacts = 0;
        for (int i = messages.size() - 1; i >= 0 && userFacts < 4; i--) {
            AiChatMessageDto item = messages.get(i);
            if (item == null || !"user".equalsIgnoreCase(item.getRole()) || !StringUtils.hasText(item.getContent())) {
                continue;
            }
            builder.append("- 用户关注: ").append(compactText(item.getContent(), 100)).append("\n");
            userFacts++;
        }
        return userFacts == 0 ? null : compactText(builder.toString(), AiDefaults.MEMORY_FACTS_MAX_LENGTH);
    }
    private String compactText(String content, int maxLength) {
        if (!StringUtils.hasText(content)) {
            return null;
        }
        String normalized = content.trim()
                .replace("\r", " ")
                .replace("\n", " ")
                .replaceAll("\\s+", " ");
        return normalized.length() > maxLength ? normalized.substring(0, maxLength) : normalized;
    }
    private void ensureIdentity(Long userId, Long tenantId) {