package com.vincent.rsf.server.ai.service.impl.conversation;
|
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.vincent.rsf.framework.common.Cools;
import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
import com.vincent.rsf.server.ai.entity.AiChatMessage;
import com.vincent.rsf.server.ai.entity.AiChatSession;
import com.vincent.rsf.server.ai.mapper.AiChatMessageMapper;
import com.vincent.rsf.server.ai.mapper.AiChatSessionMapper;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;

import java.util.Date;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.RejectedExecutionException;
|
|
@Slf4j
|
@Service
|
@RequiredArgsConstructor
|
public class AiMemoryProfileService {
|
|
private final AiConversationQueryService aiConversationQueryService;
|
private final AiChatSessionMapper aiChatSessionMapper;
|
private final AiChatMessageMapper aiChatMessageMapper;
|
@Qualifier("aiMemoryTaskExecutor")
|
private final Executor aiMemoryTaskExecutor;
|
|
private final Set<Long> refreshingSessionIds = ConcurrentHashMap.newKeySet();
|
private final Set<Long> pendingRefreshSessionIds = ConcurrentHashMap.newKeySet();
|
|
public void scheduleMemoryProfileRefresh(Long sessionId, Long userId, Long tenantId) {
|
if (sessionId == null) {
|
return;
|
}
|
if (!refreshingSessionIds.add(sessionId)) {
|
pendingRefreshSessionIds.add(sessionId);
|
return;
|
}
|
aiMemoryTaskExecutor.execute(() -> runMemoryProfileRefreshLoop(sessionId, userId, tenantId));
|
}
|
|
private void runMemoryProfileRefreshLoop(Long sessionId, Long userId, Long tenantId) {
|
try {
|
boolean shouldContinue;
|
do {
|
pendingRefreshSessionIds.remove(sessionId);
|
try {
|
refreshMemoryProfile(sessionId, userId);
|
aiConversationQueryService.evictConversationCaches(tenantId, userId);
|
} catch (Exception e) {
|
log.warn("AI memory profile refresh failed, sessionId={}, userId={}, tenantId={}, message={}",
|
sessionId, userId, tenantId, e.getMessage(), e);
|
}
|
shouldContinue = pendingRefreshSessionIds.remove(sessionId);
|
} while (shouldContinue);
|
} finally {
|
refreshingSessionIds.remove(sessionId);
|
if (pendingRefreshSessionIds.remove(sessionId) && refreshingSessionIds.add(sessionId)) {
|
aiMemoryTaskExecutor.execute(() -> runMemoryProfileRefreshLoop(sessionId, userId, tenantId));
|
}
|
}
|
}
|
|
private void refreshMemoryProfile(Long sessionId, Long userId) {
|
List<AiChatMessageDto> messages = aiConversationQueryService.listMessages(sessionId);
|
List<AiChatMessageDto> shortMemoryMessages = aiConversationQueryService.tailMessagesByRounds(messages, AiDefaults.MEMORY_RECENT_ROUNDS);
|
List<AiChatMessageDto> historyMessages = messages.size() > shortMemoryMessages.size()
|
? messages.subList(0, messages.size() - shortMemoryMessages.size())
|
: List.of();
|
String memorySummary = historyMessages.size() >= AiDefaults.MEMORY_SUMMARY_TRIGGER_MESSAGES
|
? buildMemorySummary(historyMessages)
|
: null;
|
String memoryFacts = buildMemoryFacts(messages);
|
AiChatMessage lastMessage = aiChatMessageMapper.selectOne(new LambdaQueryWrapper<AiChatMessage>()
|
.eq(AiChatMessage::getSessionId, sessionId)
|
.eq(AiChatMessage::getDeleted, 0)
|
.orderByDesc(AiChatMessage::getSeqNo)
|
.orderByDesc(AiChatMessage::getId)
|
.last("limit 1"));
|
aiChatSessionMapper.updateById(new AiChatSession()
|
.setId(sessionId)
|
.setMemorySummary(memorySummary)
|
.setMemoryFacts(memoryFacts)
|
.setLastMessageTime(lastMessage == null ? null : lastMessage.getCreateTime())
|
.setUpdateBy(userId)
|
.setUpdateTime(new Date()));
|
}
|
|
private String buildMemorySummary(List<AiChatMessageDto> historyMessages) {
|
StringBuilder builder = new StringBuilder("较早对话摘要:\n");
|
for (AiChatMessageDto item : historyMessages) {
|
if (item == null || !StringUtils.hasText(item.getContent())) {
|
continue;
|
}
|
String prefix = "assistant".equalsIgnoreCase(item.getRole()) ? "- AI: " : "- 用户: ";
|
String content = aiConversationQueryService.compactText(item.getContent(), 120);
|
if (!StringUtils.hasText(content)) {
|
continue;
|
}
|
builder.append(prefix).append(content).append("\n");
|
if (builder.length() >= AiDefaults.MEMORY_SUMMARY_MAX_LENGTH) {
|
break;
|
}
|
}
|
return aiConversationQueryService.compactText(builder.toString(), AiDefaults.MEMORY_SUMMARY_MAX_LENGTH);
|
}
|
|
private String buildMemoryFacts(List<AiChatMessageDto> messages) {
|
if (Cools.isEmpty(messages)) {
|
return null;
|
}
|
StringBuilder builder = new StringBuilder("关键事实:\n");
|
int userFacts = 0;
|
for (int i = messages.size() - 1; i >= 0 && userFacts < 4; i--) {
|
AiChatMessageDto item = messages.get(i);
|
if (item == null || !"user".equalsIgnoreCase(item.getRole()) || !StringUtils.hasText(item.getContent())) {
|
continue;
|
}
|
builder.append("- 用户关注: ").append(aiConversationQueryService.compactText(item.getContent(), 100)).append("\n");
|
userFacts++;
|
}
|
return userFacts == 0 ? null : aiConversationQueryService.compactText(builder.toString(), AiDefaults.MEMORY_FACTS_MAX_LENGTH);
|
}
|
}
|