zhou zhou
13 小时以前 82624affb0251b75b62b35567d3eb260c06efe78
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatMemoryServiceImpl.java
@@ -1,389 +1,68 @@
package com.vincent.rsf.server.ai.service.impl;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.vincent.rsf.framework.common.Cools;
import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.dto.AiChatMemoryDto;
import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
import com.vincent.rsf.server.ai.dto.AiChatSessionDto;
import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest;
import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest;
import com.vincent.rsf.server.ai.dto.AiChatSessionDto;
import com.vincent.rsf.server.ai.entity.AiChatMessage;
import com.vincent.rsf.server.ai.entity.AiChatSession;
import com.vincent.rsf.server.ai.mapper.AiChatMessageMapper;
import com.vincent.rsf.server.ai.mapper.AiChatSessionMapper;
import com.vincent.rsf.server.ai.service.AiChatMemoryService;
import com.vincent.rsf.server.system.enums.StatusType;
import com.vincent.rsf.server.ai.service.impl.conversation.AiConversationCommandService;
import com.vincent.rsf.server.ai.service.impl.conversation.AiConversationQueryService;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Locale;
@Service
@RequiredArgsConstructor
public class AiChatMemoryServiceImpl implements AiChatMemoryService {
    private final AiChatSessionMapper aiChatSessionMapper;
    private final AiChatMessageMapper aiChatMessageMapper;
    private final AiConversationQueryService aiConversationQueryService;
    private final AiConversationCommandService aiConversationCommandService;
    @Override
    public AiChatMemoryDto getMemory(Long userId, Long tenantId, String promptCode, Long sessionId) {
        ensureIdentity(userId, tenantId);
        String resolvedPromptCode = requirePromptCode(promptCode);
        AiChatSession session = sessionId == null
                ? findLatestSession(userId, tenantId, resolvedPromptCode)
                : getSession(sessionId, userId, tenantId, resolvedPromptCode);
        if (session == null) {
            return AiChatMemoryDto.builder()
                    .sessionId(null)
                    .persistedMessages(List.of())
                    .build();
        }
        return AiChatMemoryDto.builder()
                .sessionId(session.getId())
                .persistedMessages(listMessages(session.getId()))
                .build();
        return aiConversationQueryService.getMemory(userId, tenantId, promptCode, sessionId);
    }
    @Override
    public List<AiChatSessionDto> listSessions(Long userId, Long tenantId, String promptCode, String keyword) {
        ensureIdentity(userId, tenantId);
        String resolvedPromptCode = requirePromptCode(promptCode);
        List<AiChatSession> sessions = aiChatSessionMapper.selectList(new LambdaQueryWrapper<AiChatSession>()
                .eq(AiChatSession::getUserId, userId)
                .eq(AiChatSession::getTenantId, tenantId)
                .eq(AiChatSession::getPromptCode, resolvedPromptCode)
                .eq(AiChatSession::getDeleted, 0)
                .eq(AiChatSession::getStatus, StatusType.ENABLE.val)
                .like(StringUtils.hasText(keyword), AiChatSession::getTitle, keyword == null ? null : keyword.trim())
                .orderByDesc(AiChatSession::getPinned)
                .orderByDesc(AiChatSession::getLastMessageTime)
                .orderByDesc(AiChatSession::getId));
        if (Cools.isEmpty(sessions)) {
            return List.of();
        }
        List<AiChatSessionDto> result = new ArrayList<>();
        for (AiChatSession session : sessions) {
            result.add(buildSessionDto(session));
        }
        return result;
        return aiConversationQueryService.listSessions(userId, tenantId, promptCode, keyword);
    }
    @Override
    public AiChatSession resolveSession(Long userId, Long tenantId, String promptCode, Long sessionId, String titleSeed) {
        ensureIdentity(userId, tenantId);
        String resolvedPromptCode = requirePromptCode(promptCode);
        if (sessionId != null) {
            return getSession(sessionId, userId, tenantId, resolvedPromptCode);
        }
        Date now = new Date();
        AiChatSession session = new AiChatSession()
                .setTitle(buildSessionTitle(titleSeed))
                .setPromptCode(resolvedPromptCode)
                .setUserId(userId)
                .setTenantId(tenantId)
                .setLastMessageTime(now)
                .setPinned(0)
                .setStatus(StatusType.ENABLE.val)
                .setDeleted(0)
                .setCreateBy(userId)
                .setCreateTime(now)
                .setUpdateBy(userId)
                .setUpdateTime(now);
        aiChatSessionMapper.insert(session);
        return session;
        return aiConversationCommandService.resolveSession(userId, tenantId, promptCode, sessionId, titleSeed);
    }
    @Override
    public void saveRound(AiChatSession session, Long userId, Long tenantId, List<AiChatMessageDto> memoryMessages, String assistantContent) {
        if (session == null || session.getId() == null) {
            throw new CoolException("AI 会话不存在");
        }
        ensureIdentity(userId, tenantId);
        List<AiChatMessageDto> normalizedMessages = normalizeMessages(memoryMessages);
        if (normalizedMessages.isEmpty()) {
            throw new CoolException("本轮没有可保存的对话消息");
        }
        int nextSeqNo = findNextSeqNo(session.getId());
        Date now = new Date();
        for (AiChatMessageDto message : normalizedMessages) {
            aiChatMessageMapper.insert(buildMessageEntity(session.getId(), nextSeqNo++, message.getRole(), message.getContent(), userId, tenantId, now));
        }
        if (StringUtils.hasText(assistantContent)) {
            aiChatMessageMapper.insert(buildMessageEntity(session.getId(), nextSeqNo, "assistant", assistantContent, userId, tenantId, now));
        }
        AiChatSession update = new AiChatSession()
                .setId(session.getId())
                .setTitle(resolveUpdatedTitle(session.getTitle(), normalizedMessages))
                .setLastMessageTime(now)
                .setUpdateBy(userId)
                .setUpdateTime(now);
        aiChatSessionMapper.updateById(update);
        aiConversationCommandService.saveRound(session, userId, tenantId, memoryMessages, assistantContent);
    }
    @Override
    public void removeSession(Long userId, Long tenantId, Long sessionId) {
        ensureIdentity(userId, tenantId);
        if (sessionId == null) {
            throw new CoolException("AI 会话 ID 不能为空");
        }
        AiChatSession session = aiChatSessionMapper.selectOne(new LambdaQueryWrapper<AiChatSession>()
                .eq(AiChatSession::getId, sessionId)
                .eq(AiChatSession::getUserId, userId)
                .eq(AiChatSession::getTenantId, tenantId)
                .eq(AiChatSession::getDeleted, 0)
                .last("limit 1"));
        if (session == null) {
            throw new CoolException("AI 会话不存在或无权删除");
        }
        Date now = new Date();
        AiChatSession updateSession = new AiChatSession()
                .setId(sessionId)
                .setDeleted(1)
                .setUpdateBy(userId)
                .setUpdateTime(now);
        aiChatSessionMapper.updateById(updateSession);
        List<AiChatMessage> messages = aiChatMessageMapper.selectList(new LambdaQueryWrapper<AiChatMessage>()
                .eq(AiChatMessage::getSessionId, sessionId)
                .eq(AiChatMessage::getDeleted, 0));
        for (AiChatMessage message : messages) {
            AiChatMessage updateMessage = new AiChatMessage()
                    .setId(message.getId())
                    .setDeleted(1);
            aiChatMessageMapper.updateById(updateMessage);
        }
        aiConversationCommandService.removeSession(userId, tenantId, sessionId);
    }
    @Override
    public AiChatSessionDto renameSession(Long userId, Long tenantId, Long sessionId, AiChatSessionRenameRequest request) {
        ensureIdentity(userId, tenantId);
        if (request == null || !StringUtils.hasText(request.getTitle())) {
            throw new CoolException("会话标题不能为空");
        }
        AiChatSession session = requireOwnedSession(sessionId, userId, tenantId);
        Date now = new Date();
        AiChatSession update = new AiChatSession()
                .setId(sessionId)
                .setTitle(buildSessionTitle(request.getTitle()))
                .setUpdateBy(userId)
                .setUpdateTime(now);
        aiChatSessionMapper.updateById(update);
        return buildSessionDto(requireOwnedSession(sessionId, userId, tenantId));
        return aiConversationCommandService.renameSession(userId, tenantId, sessionId, request);
    }
    @Override
    public AiChatSessionDto pinSession(Long userId, Long tenantId, Long sessionId, AiChatSessionPinRequest request) {
        ensureIdentity(userId, tenantId);
        if (request == null || request.getPinned() == null) {
            throw new CoolException("置顶状态不能为空");
        }
        AiChatSession session = requireOwnedSession(sessionId, userId, tenantId);
        Date now = new Date();
        AiChatSession update = new AiChatSession()
                .setId(sessionId)
                .setPinned(Boolean.TRUE.equals(request.getPinned()) ? 1 : 0)
                .setUpdateBy(userId)
                .setUpdateTime(now);
        aiChatSessionMapper.updateById(update);
        return buildSessionDto(requireOwnedSession(sessionId, userId, tenantId));
        return aiConversationCommandService.pinSession(userId, tenantId, sessionId, request);
    }
    private AiChatSession findLatestSession(Long userId, Long tenantId, String promptCode) {
        return aiChatSessionMapper.selectOne(new LambdaQueryWrapper<AiChatSession>()
                .eq(AiChatSession::getUserId, userId)
                .eq(AiChatSession::getTenantId, tenantId)
                .eq(AiChatSession::getPromptCode, promptCode)
                .eq(AiChatSession::getDeleted, 0)
                .eq(AiChatSession::getStatus, StatusType.ENABLE.val)
                .orderByDesc(AiChatSession::getLastMessageTime)
                .orderByDesc(AiChatSession::getId)
                .last("limit 1"));
    @Override
    public void clearSessionMemory(Long userId, Long tenantId, Long sessionId) {
        aiConversationCommandService.clearSessionMemory(userId, tenantId, sessionId);
    }
    private AiChatSession getSession(Long sessionId, Long userId, Long tenantId, String promptCode) {
        AiChatSession session = aiChatSessionMapper.selectOne(new LambdaQueryWrapper<AiChatSession>()
                .eq(AiChatSession::getId, sessionId)
                .eq(AiChatSession::getUserId, userId)
                .eq(AiChatSession::getTenantId, tenantId)
                .eq(AiChatSession::getPromptCode, promptCode)
                .eq(AiChatSession::getDeleted, 0)
                .eq(AiChatSession::getStatus, StatusType.ENABLE.val)
                .last("limit 1"));
        if (session == null) {
            throw new CoolException("AI 会话不存在或无权访问");
        }
        return session;
    }
    private AiChatSession requireOwnedSession(Long sessionId, Long userId, Long tenantId) {
        if (sessionId == null) {
            throw new CoolException("AI 会话 ID 不能为空");
        }
        AiChatSession session = aiChatSessionMapper.selectOne(new LambdaQueryWrapper<AiChatSession>()
                .eq(AiChatSession::getId, sessionId)
                .eq(AiChatSession::getUserId, userId)
                .eq(AiChatSession::getTenantId, tenantId)
                .eq(AiChatSession::getDeleted, 0)
                .eq(AiChatSession::getStatus, StatusType.ENABLE.val)
                .last("limit 1"));
        if (session == null) {
            throw new CoolException("AI 会话不存在或无权访问");
        }
        return session;
    }
    private List<AiChatMessageDto> listMessages(Long sessionId) {
        List<AiChatMessage> records = aiChatMessageMapper.selectList(new LambdaQueryWrapper<AiChatMessage>()
                .eq(AiChatMessage::getSessionId, sessionId)
                .eq(AiChatMessage::getDeleted, 0)
                .orderByAsc(AiChatMessage::getSeqNo)
                .orderByAsc(AiChatMessage::getId));
        if (Cools.isEmpty(records)) {
            return List.of();
        }
        List<AiChatMessageDto> messages = new ArrayList<>();
        for (AiChatMessage record : records) {
            if (!StringUtils.hasText(record.getContent())) {
                continue;
            }
            AiChatMessageDto item = new AiChatMessageDto();
            item.setRole(record.getRole());
            item.setContent(record.getContent());
            messages.add(item);
        }
        return messages;
    }
    private List<AiChatMessageDto> normalizeMessages(List<AiChatMessageDto> memoryMessages) {
        List<AiChatMessageDto> normalized = new ArrayList<>();
        if (Cools.isEmpty(memoryMessages)) {
            return normalized;
        }
        for (AiChatMessageDto item : memoryMessages) {
            if (item == null || !StringUtils.hasText(item.getContent())) {
                continue;
            }
            String role = item.getRole() == null ? "user" : item.getRole().toLowerCase();
            if ("system".equals(role)) {
                continue;
            }
            AiChatMessageDto normalizedItem = new AiChatMessageDto();
            normalizedItem.setRole("assistant".equals(role) ? "assistant" : "user");
            normalizedItem.setContent(item.getContent().trim());
            normalized.add(normalizedItem);
        }
        return normalized;
    }
    private int findNextSeqNo(Long sessionId) {
        AiChatMessage lastMessage = aiChatMessageMapper.selectOne(new LambdaQueryWrapper<AiChatMessage>()
                .eq(AiChatMessage::getSessionId, sessionId)
                .eq(AiChatMessage::getDeleted, 0)
                .orderByDesc(AiChatMessage::getSeqNo)
                .orderByDesc(AiChatMessage::getId)
                .last("limit 1"));
        return lastMessage == null || lastMessage.getSeqNo() == null ? 1 : lastMessage.getSeqNo() + 1;
    }
    private AiChatMessage buildMessageEntity(Long sessionId, int seqNo, String role, String content, Long userId, Long tenantId, Date createTime) {
        return new AiChatMessage()
                .setSessionId(sessionId)
                .setSeqNo(seqNo)
                .setRole(role)
                .setContent(content)
                .setUserId(userId)
                .setTenantId(tenantId)
                .setDeleted(0)
                .setCreateBy(userId)
                .setCreateTime(createTime);
    }
    private String resolveUpdatedTitle(String currentTitle, List<AiChatMessageDto> memoryMessages) {
        if (StringUtils.hasText(currentTitle)) {
            return currentTitle;
        }
        for (AiChatMessageDto item : memoryMessages) {
            if ("user".equals(item.getRole()) && StringUtils.hasText(item.getContent())) {
                return buildSessionTitle(item.getContent());
            }
        }
        return null;
    }
    private String buildSessionTitle(String titleSeed) {
        if (!StringUtils.hasText(titleSeed)) {
            throw new CoolException("AI 会话标题不能为空");
        }
        String title = titleSeed.trim()
                .replace("\r", " ")
                .replace("\n", " ")
                .replaceAll("\\s+", " ");
        int punctuationIndex = findSummaryBreakIndex(title);
        if (punctuationIndex > 0) {
            title = title.substring(0, punctuationIndex).trim();
        }
        return title.length() > 48 ? title.substring(0, 48) : title;
    }
    private int findSummaryBreakIndex(String title) {
        String[] separators = {"。", "!", "?", ".", "!", "?"};
        int result = -1;
        for (String separator : separators) {
            int index = title.indexOf(separator);
            if (index > 0 && (result < 0 || index < result)) {
                result = index;
            }
        }
        return result;
    }
    private AiChatSessionDto buildSessionDto(AiChatSession session) {
        AiChatMessage lastMessage = aiChatMessageMapper.selectOne(new LambdaQueryWrapper<AiChatMessage>()
                .eq(AiChatMessage::getSessionId, session.getId())
                .eq(AiChatMessage::getDeleted, 0)
                .orderByDesc(AiChatMessage::getSeqNo)
                .orderByDesc(AiChatMessage::getId)
                .last("limit 1"));
        return AiChatSessionDto.builder()
                .sessionId(session.getId())
                .title(session.getTitle())
                .promptCode(session.getPromptCode())
                .pinned(session.getPinned() != null && session.getPinned() == 1)
                .lastMessagePreview(buildLastMessagePreview(lastMessage))
                .lastMessageTime(session.getLastMessageTime())
                .build();
    }
    private String buildLastMessagePreview(AiChatMessage message) {
        if (message == null || !StringUtils.hasText(message.getContent())) {
            return null;
        }
        String preview = message.getContent().trim()
                .replace("\r", " ")
                .replace("\n", " ")
                .replaceAll("\\s+", " ");
        String prefix = "assistant".equalsIgnoreCase(message.getRole()) ? "AI: " : "你: ";
        String normalized = prefix + preview;
        return normalized.length() > 80 ? normalized.substring(0, 80) : normalized;
    }
    private void ensureIdentity(Long userId, Long tenantId) {
        if (userId == null) {
            throw new CoolException("当前登录用户不存在");
        }
        if (tenantId == null) {
            throw new CoolException("当前租户不存在");
        }
    }
    private String requirePromptCode(String promptCode) {
        if (!StringUtils.hasText(promptCode)) {
            throw new CoolException("Prompt 编码不能为空");
        }
        return promptCode;
    @Override
    public void retainLatestRound(Long userId, Long tenantId, Long sessionId) {
        aiConversationCommandService.retainLatestRound(userId, tenantId, sessionId);
    }
}