package com.vincent.rsf.server.ai.service.impl;
|
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.vincent.rsf.server.ai.mapper.AiChatMessageMapper;
import com.vincent.rsf.server.ai.mapper.AiChatSessionMapper;
import com.vincent.rsf.server.ai.model.AiChatMessage;
import com.vincent.rsf.server.ai.model.AiChatSession;
import com.vincent.rsf.server.ai.service.AiRuntimeConfigService;
import com.vincent.rsf.server.ai.service.AiSessionService;
import org.springframework.stereotype.Service;

import javax.annotation.PostConstruct;
import javax.annotation.Resource;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Date;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
|
|
@Service
|
public class AiSessionServiceImpl implements AiSessionService {
|
|
private static final ConcurrentMap<String, List<AiChatSession>> LOCAL_SESSION_CACHE = new ConcurrentHashMap<>();
|
private static final ConcurrentMap<String, List<AiChatMessage>> LOCAL_MESSAGE_CACHE = new ConcurrentHashMap<>();
|
private static final ConcurrentMap<String, String> LOCAL_STOP_CACHE = new ConcurrentHashMap<>();
|
private static final String SESSION_TABLE_NAME = "sys_ai_chat_session";
|
private static final String MESSAGE_TABLE_NAME = "sys_ai_chat_message";
|
|
@Resource
|
private AiRuntimeConfigService aiRuntimeConfigService;
|
@Resource
|
private AiChatSessionMapper aiChatSessionMapper;
|
@Resource
|
private AiChatMessageMapper aiChatMessageMapper;
|
@Resource
|
private DataSource dataSource;
|
|
private volatile boolean storageReady;
|
|
@PostConstruct
|
/**
|
* 启动时探测聊天存储表是否已创建。
|
* 如果表存在则走数据库持久化,否则回退到本地内存缓存,保证开发和缺表场景可继续运行。
|
*/
|
public void initStorageMode() {
|
storageReady = detectStorageTables();
|
}
|
|
@Override
|
/**
|
* 读取用户会话列表。
|
* 数据库存储模式直接查表,内存模式则从本地缓存取出并按最近更新时间排序。
|
*/
|
public synchronized List<AiChatSession> listSessions(Long tenantId, Long userId) {
|
if (useDatabaseStorage()) {
|
return aiChatSessionMapper.selectList(new LambdaQueryWrapper<AiChatSession>()
|
.eq(AiChatSession::getTenantId, tenantId)
|
.eq(AiChatSession::getUserId, userId)
|
.orderByDesc(AiChatSession::getUpdateTime, AiChatSession::getCreateTime));
|
}
|
List<AiChatSession> sessions = getSessions(tenantId, userId);
|
sessions.sort(Comparator.comparing(AiChatSession::getUpdateTime, Comparator.nullsLast(Date::compareTo)).reversed());
|
return sessions;
|
}
|
|
@Override
|
/**
|
* 创建新会话,并初始化标题、模型和时间戳。
|
*/
|
public synchronized AiChatSession createSession(Long tenantId, Long userId, String title, String modelCode) {
|
List<AiChatSession> sessions = useDatabaseStorage() ? listSessions(tenantId, userId) : getSessions(tenantId, userId);
|
Date now = new Date();
|
AiChatSession session = new AiChatSession()
|
.setId(UUID.randomUUID().toString().replace("-", ""))
|
.setTenantId(tenantId)
|
.setUserId(userId)
|
.setTitle(resolveTitle(title, sessions.size() + 1))
|
.setModelCode(resolveModelCode(modelCode))
|
.setCreateTime(now)
|
.setUpdateTime(now)
|
.setLastMessageAt(now)
|
.setStatus(1)
|
.setDeleted(0);
|
if (useDatabaseStorage()) {
|
aiChatSessionMapper.insert(session);
|
return session;
|
}
|
sessions.add(0, session);
|
saveSessions(tenantId, userId, sessions);
|
saveMessages(session.getId(), new ArrayList<>());
|
return session;
|
}
|
|
@Override
|
/**
|
* 确保会话存在;如果会话已存在但模型发生变化,会同步更新会话记录。
|
*/
|
public synchronized AiChatSession ensureSession(Long tenantId, Long userId, String sessionId, String modelCode) {
|
AiChatSession session = getSession(tenantId, userId, sessionId);
|
if (session == null) {
|
return createSession(tenantId, userId, null, modelCode);
|
}
|
if (modelCode != null && !modelCode.trim().isEmpty() && !modelCode.equals(session.getModelCode())) {
|
session.setModelCode(modelCode);
|
session.setUpdateTime(new Date());
|
refreshSession(tenantId, userId, session);
|
}
|
return session;
|
}
|
|
@Override
|
/**
|
* 安全读取会话,并校验租户与用户归属。
|
*/
|
public synchronized AiChatSession getSession(Long tenantId, Long userId, String sessionId) {
|
if (sessionId == null || sessionId.trim().isEmpty()) {
|
return null;
|
}
|
if (useDatabaseStorage()) {
|
AiChatSession session = aiChatSessionMapper.selectById(sessionId);
|
if (session == null) {
|
return null;
|
}
|
if (!Objects.equals(tenantId, session.getTenantId()) || !Objects.equals(userId, session.getUserId())) {
|
return null;
|
}
|
return session;
|
}
|
for (AiChatSession session : getSessions(tenantId, userId)) {
|
if (sessionId.equals(session.getId())) {
|
return session;
|
}
|
}
|
return null;
|
}
|
|
@Override
|
/**
|
* 更新会话标题。
|
*/
|
public synchronized AiChatSession renameSession(Long tenantId, Long userId, String sessionId, String title) {
|
AiChatSession session = getSession(tenantId, userId, sessionId);
|
if (session == null) {
|
return null;
|
}
|
session.setTitle(resolveTitle(title, 1));
|
session.setUpdateTime(new Date());
|
refreshSession(tenantId, userId, session);
|
return session;
|
}
|
|
@Override
|
/**
|
* 删除会话及其关联消息,同时清理停止标记缓存。
|
*/
|
public synchronized void removeSession(Long tenantId, Long userId, String sessionId) {
|
if (useDatabaseStorage()) {
|
AiChatSession session = getSession(tenantId, userId, sessionId);
|
if (session != null) {
|
aiChatMessageMapper.delete(new LambdaQueryWrapper<AiChatMessage>()
|
.eq(AiChatMessage::getSessionId, sessionId));
|
aiChatSessionMapper.deleteById(sessionId);
|
}
|
LOCAL_STOP_CACHE.remove(sessionId);
|
return;
|
}
|
List<AiChatSession> sessions = getSessions(tenantId, userId);
|
sessions.removeIf(session -> sessionId.equals(session.getId()));
|
saveSessions(tenantId, userId, sessions);
|
LOCAL_MESSAGE_CACHE.remove(sessionId);
|
LOCAL_STOP_CACHE.remove(sessionId);
|
}
|
|
@Override
|
/**
|
* 查询会话的完整消息历史。
|
*/
|
public synchronized List<AiChatMessage> listMessages(Long tenantId, Long userId, String sessionId) {
|
AiChatSession session = getSession(tenantId, userId, sessionId);
|
if (session == null) {
|
return new ArrayList<>();
|
}
|
if (useDatabaseStorage()) {
|
return aiChatMessageMapper.selectList(new LambdaQueryWrapper<AiChatMessage>()
|
.eq(AiChatMessage::getSessionId, sessionId)
|
.orderByAsc(AiChatMessage::getCreateTime, AiChatMessage::getId));
|
}
|
return getMessages(sessionId);
|
}
|
|
@Override
|
/**
|
* 截取最近若干条消息作为模型上下文,避免每次都把完整历史发送给模型。
|
*/
|
public synchronized List<AiChatMessage> listContextMessages(Long tenantId, Long userId, String sessionId, int maxCount) {
|
List<AiChatMessage> messages = listMessages(tenantId, userId, sessionId);
|
if (messages.size() <= maxCount) {
|
return messages;
|
}
|
return new ArrayList<>(messages.subList(messages.size() - maxCount, messages.size()));
|
}
|
|
@Override
|
/**
|
* 追加一条消息,并同步刷新会话摘要、活跃时间和默认标题。
|
*/
|
public synchronized AiChatMessage appendMessage(Long tenantId, Long userId, String sessionId, String role, String content, String modelCode) {
|
AiChatSession session = getSession(tenantId, userId, sessionId);
|
if (session == null) {
|
return null;
|
}
|
List<AiChatMessage> messages = getMessages(sessionId);
|
AiChatMessage message = new AiChatMessage()
|
.setId(UUID.randomUUID().toString().replace("-", ""))
|
.setTenantId(tenantId)
|
.setUserId(userId)
|
.setSessionId(sessionId)
|
.setRole(role)
|
.setContent(content)
|
.setModelCode(resolveModelCode(modelCode))
|
.setCreateTime(new Date())
|
.setStatus(1)
|
.setDeleted(0);
|
if (useDatabaseStorage()) {
|
aiChatMessageMapper.insert(message);
|
} else {
|
messages.add(message);
|
saveMessages(sessionId, messages);
|
}
|
session.setLastMessage(buildPreview(content));
|
session.setLastMessageAt(message.getCreateTime());
|
session.setUpdateTime(message.getCreateTime());
|
if (modelCode != null && !modelCode.trim().isEmpty()) {
|
session.setModelCode(modelCode);
|
}
|
if ((session.getTitle() == null || session.getTitle().startsWith("新对话")) && "user".equalsIgnoreCase(role)) {
|
session.setTitle(buildPreview(content));
|
}
|
refreshSession(tenantId, userId, session);
|
return message;
|
}
|
|
@Override
|
/**
|
* 清除停止生成标记。
|
*/
|
public void clearStopFlag(String sessionId) {
|
LOCAL_STOP_CACHE.remove(sessionId);
|
}
|
|
@Override
|
/**
|
* 标记会话需要停止生成。
|
*/
|
public void requestStop(String sessionId) {
|
LOCAL_STOP_CACHE.put(sessionId, "1");
|
}
|
|
@Override
|
/**
|
* 读取停止生成标记。
|
*/
|
public boolean isStopRequested(String sessionId) {
|
String stopFlag = LOCAL_STOP_CACHE.get(sessionId);
|
return "1".equals(stopFlag);
|
}
|
|
/**
|
* 从内存缓存中读取当前用户的会话列表。
|
*/
|
private List<AiChatSession> getSessions(Long tenantId, Long userId) {
|
String ownerKey = buildOwnerKey(tenantId, userId);
|
List<AiChatSession> sessions = LOCAL_SESSION_CACHE.get(ownerKey);
|
return sessions == null ? new ArrayList<>() : new ArrayList<>(sessions);
|
}
|
|
/**
|
* 将会话列表写回本地缓存。
|
*/
|
private void saveSessions(Long tenantId, Long userId, List<AiChatSession> sessions) {
|
String ownerKey = buildOwnerKey(tenantId, userId);
|
List<AiChatSession> cachedSessions = new ArrayList<>(sessions);
|
LOCAL_SESSION_CACHE.put(ownerKey, cachedSessions);
|
}
|
|
/**
|
* 从内存缓存中读取指定会话的消息列表。
|
*/
|
private List<AiChatMessage> getMessages(String sessionId) {
|
List<AiChatMessage> messages = LOCAL_MESSAGE_CACHE.get(sessionId);
|
return messages == null ? new ArrayList<>() : new ArrayList<>(messages);
|
}
|
|
/**
|
* 将消息列表写回本地缓存。
|
*/
|
private void saveMessages(String sessionId, List<AiChatMessage> messages) {
|
List<AiChatMessage> cachedMessages = new ArrayList<>(messages);
|
LOCAL_MESSAGE_CACHE.put(sessionId, cachedMessages);
|
}
|
|
/**
|
* 按存储模式刷新单个会话记录。
|
*/
|
private void refreshSession(Long tenantId, Long userId, AiChatSession target) {
|
if (useDatabaseStorage()) {
|
aiChatSessionMapper.updateById(target);
|
return;
|
}
|
List<AiChatSession> sessions = getSessions(tenantId, userId);
|
for (int i = 0; i < sessions.size(); i++) {
|
if (target.getId().equals(sessions.get(i).getId())) {
|
sessions.set(i, target);
|
saveSessions(tenantId, userId, sessions);
|
return;
|
}
|
}
|
sessions.add(target);
|
saveSessions(tenantId, userId, sessions);
|
}
|
|
/**
|
* 组装租户与用户维度的本地缓存 key。
|
*/
|
private String buildOwnerKey(Long tenantId, Long userId) {
|
return String.valueOf(tenantId) + ":" + String.valueOf(userId);
|
}
|
|
/**
|
* 解析本次消息使用的模型编码;为空时回退到系统默认模型。
|
*/
|
private String resolveModelCode(String modelCode) {
|
return modelCode == null || modelCode.trim().isEmpty() ? aiRuntimeConfigService.resolveDefaultModelCode() : modelCode;
|
}
|
|
/**
|
* 生成会话标题,未显式传标题时使用“新对话 N”。
|
*/
|
private String resolveTitle(String title, int index) {
|
if (title == null || title.trim().isEmpty()) {
|
return "新对话 " + index;
|
}
|
return buildPreview(title);
|
}
|
|
/**
|
* 将用户输入压缩成适合作为标题或最后消息预览的短文本。
|
*/
|
private String buildPreview(String content) {
|
if (content == null || content.trim().isEmpty()) {
|
return "新对话";
|
}
|
String normalized = content.replace("\r", " ").replace("\n", " ").trim();
|
return normalized.length() > 24 ? normalized.substring(0, 24) : normalized;
|
}
|
|
/**
|
* 判断当前是否可以使用数据库持久化聊天数据。
|
*/
|
private boolean useDatabaseStorage() {
|
return storageReady || (storageReady = detectStorageTables());
|
}
|
|
/**
|
* 检查聊天存储所需表是否已经存在。
|
*/
|
private boolean detectStorageTables() {
|
try (Connection connection = dataSource.getConnection()) {
|
return tableExists(connection, SESSION_TABLE_NAME) && tableExists(connection, MESSAGE_TABLE_NAME);
|
} catch (Exception ignore) {
|
return false;
|
}
|
}
|
|
/**
|
* 判断指定表名是否在当前数据库中存在。
|
*/
|
private boolean tableExists(Connection connection, String tableName) throws Exception {
|
if (tableName == null || tableName.trim().isEmpty()) {
|
return false;
|
}
|
String[] candidates = new String[]{tableName, tableName.toUpperCase(), tableName.toLowerCase()};
|
for (String candidate : candidates) {
|
try (ResultSet resultSet = connection.getMetaData().getTables(connection.getCatalog(), null, candidate, null)) {
|
if (resultSet.next()) {
|
return true;
|
}
|
}
|
}
|
return false;
|
}
|
|
}
|