package com.vincent.rsf.server.ai.service.impl.conversation;
|
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.vincent.rsf.framework.common.Cools;
import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.dto.AiChatMemoryDto;
import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
import com.vincent.rsf.server.ai.dto.AiChatSessionDto;
import com.vincent.rsf.server.ai.entity.AiChatMessage;
import com.vincent.rsf.server.ai.entity.AiChatSession;
import com.vincent.rsf.server.ai.mapper.AiChatMessageMapper;
import com.vincent.rsf.server.ai.mapper.AiChatSessionMapper;
import com.vincent.rsf.server.ai.store.AiConversationCacheStore;
import com.vincent.rsf.server.system.enums.StatusType;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.regex.Pattern;
|
|
@Service
|
@RequiredArgsConstructor
|
public class AiConversationQueryService {
|
|
private final AiChatSessionMapper aiChatSessionMapper;
|
private final AiChatMessageMapper aiChatMessageMapper;
|
private final AiConversationCacheStore aiConversationCacheStore;
|
|
public AiChatMemoryDto getMemory(Long userId, Long tenantId, String promptCode, Long sessionId) {
|
ensureIdentity(userId, tenantId);
|
String resolvedPromptCode = requirePromptCode(promptCode);
|
AiChatMemoryDto cached = aiConversationCacheStore.getMemory(tenantId, userId, resolvedPromptCode, sessionId);
|
if (cached != null) {
|
return cached;
|
}
|
AiChatSession session = sessionId == null
|
? findLatestSession(userId, tenantId, resolvedPromptCode)
|
: getSession(sessionId, userId, tenantId, resolvedPromptCode);
|
AiChatMemoryDto memory;
|
if (session == null) {
|
memory = AiChatMemoryDto.builder()
|
.sessionId(null)
|
.memorySummary(null)
|
.memoryFacts(null)
|
.recentMessageCount(0)
|
.persistedMessages(List.of())
|
.shortMemoryMessages(List.of())
|
.build();
|
aiConversationCacheStore.cacheMemory(tenantId, userId, resolvedPromptCode, sessionId, memory);
|
return memory;
|
}
|
List<AiChatMessageDto> persistedMessages = listMessages(session.getId());
|
List<AiChatMessageDto> shortMemoryMessages = tailMessagesByRounds(persistedMessages, AiDefaults.MEMORY_RECENT_ROUNDS);
|
memory = AiChatMemoryDto.builder()
|
.sessionId(session.getId())
|
.memorySummary(session.getMemorySummary())
|
.memoryFacts(session.getMemoryFacts())
|
.recentMessageCount(shortMemoryMessages.size())
|
.persistedMessages(persistedMessages)
|
.shortMemoryMessages(shortMemoryMessages)
|
.build();
|
aiConversationCacheStore.cacheMemory(tenantId, userId, resolvedPromptCode, session.getId(), memory);
|
if (sessionId == null || !session.getId().equals(sessionId)) {
|
aiConversationCacheStore.cacheMemory(tenantId, userId, resolvedPromptCode, null, memory);
|
}
|
return memory;
|
}
|
|
public List<AiChatSessionDto> listSessions(Long userId, Long tenantId, String promptCode, String keyword) {
|
ensureIdentity(userId, tenantId);
|
String resolvedPromptCode = requirePromptCode(promptCode);
|
List<AiChatSessionDto> cached = aiConversationCacheStore.getSessionList(tenantId, userId, resolvedPromptCode, keyword);
|
if (cached != null) {
|
return cached;
|
}
|
List<AiChatSession> sessions = aiChatSessionMapper.selectList(new LambdaQueryWrapper<AiChatSession>()
|
.eq(AiChatSession::getUserId, userId)
|
.eq(AiChatSession::getTenantId, tenantId)
|
.eq(AiChatSession::getPromptCode, resolvedPromptCode)
|
.eq(AiChatSession::getDeleted, 0)
|
.eq(AiChatSession::getStatus, StatusType.ENABLE.val)
|
.like(StringUtils.hasText(keyword), AiChatSession::getTitle, keyword == null ? null : keyword.trim())
|
.orderByDesc(AiChatSession::getPinned)
|
.orderByDesc(AiChatSession::getLastMessageTime)
|
.orderByDesc(AiChatSession::getId));
|
if (Cools.isEmpty(sessions)) {
|
aiConversationCacheStore.cacheSessionList(tenantId, userId, resolvedPromptCode, keyword, List.of());
|
return List.of();
|
}
|
List<Long> sessionIds = sessions.stream().map(AiChatSession::getId).toList();
|
Map<Long, AiChatMessage> latestMessageMap = new LinkedHashMap<>();
|
for (AiChatMessage message : aiChatMessageMapper.selectLatestMessagesBySessionIds(sessionIds)) {
|
latestMessageMap.put(message.getSessionId(), message);
|
}
|
List<AiChatSessionDto> result = new ArrayList<>();
|
for (AiChatSession session : sessions) {
|
result.add(buildSessionDto(session, latestMessageMap.get(session.getId())));
|
}
|
aiConversationCacheStore.cacheSessionList(tenantId, userId, resolvedPromptCode, keyword, result);
|
return result;
|
}
|
|
public void evictConversationCaches(Long tenantId, Long userId) {
|
aiConversationCacheStore.evictUserConversationCaches(tenantId, userId);
|
}
|
|
public AiChatSession findLatestSession(Long userId, Long tenantId, String promptCode) {
|
return aiChatSessionMapper.selectOne(new LambdaQueryWrapper<AiChatSession>()
|
.eq(AiChatSession::getUserId, userId)
|
.eq(AiChatSession::getTenantId, tenantId)
|
.eq(AiChatSession::getPromptCode, promptCode)
|
.eq(AiChatSession::getDeleted, 0)
|
.eq(AiChatSession::getStatus, StatusType.ENABLE.val)
|
.orderByDesc(AiChatSession::getLastMessageTime)
|
.orderByDesc(AiChatSession::getId)
|
.last("limit 1"));
|
}
|
|
public AiChatSession getSession(Long sessionId, Long userId, Long tenantId, String promptCode) {
|
AiChatSession session = aiChatSessionMapper.selectOne(new LambdaQueryWrapper<AiChatSession>()
|
.eq(AiChatSession::getId, sessionId)
|
.eq(AiChatSession::getUserId, userId)
|
.eq(AiChatSession::getTenantId, tenantId)
|
.eq(AiChatSession::getPromptCode, promptCode)
|
.eq(AiChatSession::getDeleted, 0)
|
.eq(AiChatSession::getStatus, StatusType.ENABLE.val)
|
.last("limit 1"));
|
if (session == null) {
|
throw new CoolException("AI 会话不存在或无权访问");
|
}
|
return session;
|
}
|
|
public AiChatSession requireOwnedSession(Long sessionId, Long userId, Long tenantId) {
|
if (sessionId == null) {
|
throw new CoolException("AI 会话 ID 不能为空");
|
}
|
AiChatSession session = aiChatSessionMapper.selectOne(new LambdaQueryWrapper<AiChatSession>()
|
.eq(AiChatSession::getId, sessionId)
|
.eq(AiChatSession::getUserId, userId)
|
.eq(AiChatSession::getTenantId, tenantId)
|
.eq(AiChatSession::getDeleted, 0)
|
.eq(AiChatSession::getStatus, StatusType.ENABLE.val)
|
.last("limit 1"));
|
if (session == null) {
|
throw new CoolException("AI 会话不存在或无权访问");
|
}
|
return session;
|
}
|
|
public List<AiChatMessageDto> listMessages(Long sessionId) {
|
List<AiChatMessage> records = listMessageRecords(sessionId);
|
if (Cools.isEmpty(records)) {
|
return List.of();
|
}
|
List<AiChatMessageDto> messages = new ArrayList<>();
|
for (AiChatMessage record : records) {
|
if (!StringUtils.hasText(record.getContent())) {
|
continue;
|
}
|
AiChatMessageDto item = new AiChatMessageDto();
|
item.setRole(record.getRole());
|
item.setContent(record.getContent());
|
messages.add(item);
|
}
|
return messages;
|
}
|
|
public List<AiChatMessage> listMessageRecords(Long sessionId) {
|
return aiChatMessageMapper.selectList(new LambdaQueryWrapper<AiChatMessage>()
|
.eq(AiChatMessage::getSessionId, sessionId)
|
.eq(AiChatMessage::getDeleted, 0)
|
.orderByAsc(AiChatMessage::getSeqNo)
|
.orderByAsc(AiChatMessage::getId));
|
}
|
|
public int findNextSeqNo(Long sessionId) {
|
AiChatMessage lastMessage = aiChatMessageMapper.selectOne(new LambdaQueryWrapper<AiChatMessage>()
|
.eq(AiChatMessage::getSessionId, sessionId)
|
.eq(AiChatMessage::getDeleted, 0)
|
.orderByDesc(AiChatMessage::getSeqNo)
|
.orderByDesc(AiChatMessage::getId)
|
.last("limit 1"));
|
return lastMessage == null || lastMessage.getSeqNo() == null ? 1 : lastMessage.getSeqNo() + 1;
|
}
|
|
public String resolveUpdatedTitle(String currentTitle, List<AiChatMessageDto> memoryMessages) {
|
if (StringUtils.hasText(currentTitle)) {
|
return currentTitle;
|
}
|
for (AiChatMessageDto item : memoryMessages) {
|
if ("user".equals(item.getRole()) && StringUtils.hasText(item.getContent())) {
|
return buildSessionTitle(item.getContent());
|
}
|
}
|
return null;
|
}
|
|
public String buildSessionTitle(String titleSeed) {
|
if (!StringUtils.hasText(titleSeed)) {
|
throw new CoolException("AI 会话标题不能为空");
|
}
|
String title = titleSeed.trim()
|
.replace("\r", " ")
|
.replace("\n", " ")
|
.replaceAll("\\s+", " ");
|
int punctuationIndex = findSummaryBreakIndex(title);
|
if (punctuationIndex > 0) {
|
title = title.substring(0, punctuationIndex).trim();
|
}
|
return title.length() > 48 ? title.substring(0, 48) : title;
|
}
|
|
public List<AiChatMessageDto> tailMessagesByRounds(List<AiChatMessageDto> source, int rounds) {
|
if (Cools.isEmpty(source) || rounds <= 0) {
|
return List.of();
|
}
|
int userCount = 0;
|
int startIndex = source.size();
|
for (int i = source.size() - 1; i >= 0; i--) {
|
AiChatMessageDto item = source.get(i);
|
startIndex = i;
|
if (item != null && "user".equalsIgnoreCase(item.getRole())) {
|
userCount++;
|
if (userCount >= rounds) {
|
break;
|
}
|
}
|
}
|
return new ArrayList<>(source.subList(Math.max(0, startIndex), source.size()));
|
}
|
|
public List<AiChatMessage> tailMessageRecordsByRounds(List<AiChatMessage> source, int rounds) {
|
if (Cools.isEmpty(source) || rounds <= 0) {
|
return List.of();
|
}
|
int userCount = 0;
|
int startIndex = source.size();
|
for (int i = source.size() - 1; i >= 0; i--) {
|
AiChatMessage item = source.get(i);
|
startIndex = i;
|
if (item != null && "user".equalsIgnoreCase(item.getRole())) {
|
userCount++;
|
if (userCount >= rounds) {
|
break;
|
}
|
}
|
}
|
return new ArrayList<>(source.subList(Math.max(0, startIndex), source.size()));
|
}
|
|
public String compactText(String content, int maxLength) {
|
if (!StringUtils.hasText(content)) {
|
return null;
|
}
|
String normalized = content.trim()
|
.replace("\r", " ")
|
.replace("\n", " ")
|
.replaceAll("\\s+", " ");
|
return normalized.length() > maxLength ? normalized.substring(0, maxLength) : normalized;
|
}
|
|
public String requirePromptCode(String promptCode) {
|
if (!StringUtils.hasText(promptCode)) {
|
throw new CoolException("Prompt 编码不能为空");
|
}
|
return promptCode;
|
}
|
|
public void ensureIdentity(Long userId, Long tenantId) {
|
if (userId == null) {
|
throw new CoolException("当前登录用户不存在");
|
}
|
if (tenantId == null) {
|
throw new CoolException("当前租户不存在");
|
}
|
}
|
|
private int findSummaryBreakIndex(String title) {
|
String[] separators = {"。", "!", "?", ".", "!", "?"};
|
int result = -1;
|
for (String separator : separators) {
|
int index = title.indexOf(separator);
|
if (index > 0 && (result < 0 || index < result)) {
|
result = index;
|
}
|
}
|
return result;
|
}
|
|
private AiChatSessionDto buildSessionDto(AiChatSession session, AiChatMessage lastMessage) {
|
return AiChatSessionDto.builder()
|
.sessionId(session.getId())
|
.title(session.getTitle())
|
.promptCode(session.getPromptCode())
|
.pinned(session.getPinned() != null && session.getPinned() == 1)
|
.lastMessagePreview(buildLastMessagePreview(lastMessage))
|
.lastMessageTime(session.getLastMessageTime())
|
.build();
|
}
|
|
private String buildLastMessagePreview(AiChatMessage message) {
|
if (message == null || !StringUtils.hasText(message.getContent())) {
|
return null;
|
}
|
String preview = message.getContent().trim()
|
.replace("\r", " ")
|
.replace("\n", " ")
|
.replaceAll("\\s+", " ");
|
String prefix = "assistant".equalsIgnoreCase(message.getRole()) ? "AI: " : "你: ";
|
String normalized = prefix + preview;
|
return normalized.length() > 80 ? normalized.substring(0, 80) : normalized;
|
}
|
}
|