| | |
| | | package com.vincent.rsf.server.ai.service.impl; |
| | | |
| | | import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; |
| | | import com.vincent.rsf.server.ai.mapper.AiChatMessageMapper; |
| | | import com.vincent.rsf.server.ai.mapper.AiChatSessionMapper; |
| | | import com.vincent.rsf.server.ai.model.AiChatMessage; |
| | | import com.vincent.rsf.server.ai.model.AiChatSession; |
| | | import com.vincent.rsf.server.ai.service.AiRuntimeConfigService; |
| | | import com.vincent.rsf.server.ai.service.AiSessionService; |
| | | import org.springframework.stereotype.Service; |
| | | |
| | | import javax.annotation.PostConstruct; |
| | | import javax.annotation.Resource; |
| | | import javax.sql.DataSource; |
| | | import java.sql.Connection; |
| | | import java.sql.ResultSet; |
| | | import java.util.ArrayList; |
| | | import java.util.Comparator; |
| | | import java.util.Date; |
| | | import java.util.List; |
| | | import java.util.Objects; |
| | | import java.util.UUID; |
| | | import java.util.concurrent.ConcurrentHashMap; |
| | | import java.util.concurrent.ConcurrentMap; |
| | |
| | | private static final ConcurrentMap<String, List<AiChatSession>> LOCAL_SESSION_CACHE = new ConcurrentHashMap<>(); |
| | | private static final ConcurrentMap<String, List<AiChatMessage>> LOCAL_MESSAGE_CACHE = new ConcurrentHashMap<>(); |
| | | private static final ConcurrentMap<String, String> LOCAL_STOP_CACHE = new ConcurrentHashMap<>(); |
| | | private static final String SESSION_TABLE_NAME = "sys_ai_chat_session"; |
| | | private static final String MESSAGE_TABLE_NAME = "sys_ai_chat_message"; |
| | | |
| | | @Resource |
| | | private AiRuntimeConfigService aiRuntimeConfigService; |
| | | @Resource |
| | | private AiChatSessionMapper aiChatSessionMapper; |
| | | @Resource |
| | | private AiChatMessageMapper aiChatMessageMapper; |
| | | @Resource |
| | | private DataSource dataSource; |
| | | |
| | | private volatile boolean storageReady; |
| | | |
| | | @PostConstruct |
| | | /** |
| | | * 启动时探测聊天存储表是否已创建。 |
| | | * 如果表存在则走数据库持久化,否则回退到本地内存缓存,保证开发和缺表场景可继续运行。 |
| | | */ |
| | | public void initStorageMode() { |
| | | storageReady = detectStorageTables(); |
| | | } |
| | | |
| | | @Override |
| | | /** |
| | | * 读取用户会话列表。 |
| | | * 数据库存储模式直接查表,内存模式则从本地缓存取出并按最近更新时间排序。 |
| | | */ |
| | | public synchronized List<AiChatSession> listSessions(Long tenantId, Long userId) { |
| | | if (useDatabaseStorage()) { |
| | | return aiChatSessionMapper.selectList(new LambdaQueryWrapper<AiChatSession>() |
| | | .eq(AiChatSession::getTenantId, tenantId) |
| | | .eq(AiChatSession::getUserId, userId) |
| | | .orderByDesc(AiChatSession::getUpdateTime, AiChatSession::getCreateTime)); |
| | | } |
| | | List<AiChatSession> sessions = getSessions(tenantId, userId); |
| | | sessions.sort(Comparator.comparing(AiChatSession::getUpdateTime, Comparator.nullsLast(Date::compareTo)).reversed()); |
| | | return sessions; |
| | | } |
| | | |
| | | @Override |
| | | /** |
| | | * 创建新会话,并初始化标题、模型和时间戳。 |
| | | */ |
| | | public synchronized AiChatSession createSession(Long tenantId, Long userId, String title, String modelCode) { |
| | | List<AiChatSession> sessions = getSessions(tenantId, userId); |
| | | List<AiChatSession> sessions = useDatabaseStorage() ? listSessions(tenantId, userId) : getSessions(tenantId, userId); |
| | | Date now = new Date(); |
| | | AiChatSession session = new AiChatSession() |
| | | .setId(UUID.randomUUID().toString().replace("-", "")) |
| | | .setTenantId(tenantId) |
| | | .setUserId(userId) |
| | | .setTitle(resolveTitle(title, sessions.size() + 1)) |
| | | .setModelCode(resolveModelCode(modelCode)) |
| | | .setCreateTime(now) |
| | | .setUpdateTime(now) |
| | | .setLastMessageAt(now); |
| | | .setLastMessageAt(now) |
| | | .setStatus(1) |
| | | .setDeleted(0); |
| | | if (useDatabaseStorage()) { |
| | | aiChatSessionMapper.insert(session); |
| | | return session; |
| | | } |
| | | sessions.add(0, session); |
| | | saveSessions(tenantId, userId, sessions); |
| | | saveMessages(session.getId(), new ArrayList<>()); |
| | |
| | | } |
| | | |
| | | @Override |
| | | /** |
| | | * 确保会话存在;如果会话已存在但模型发生变化,会同步更新会话记录。 |
| | | */ |
| | | public synchronized AiChatSession ensureSession(Long tenantId, Long userId, String sessionId, String modelCode) { |
| | | AiChatSession session = getSession(tenantId, userId, sessionId); |
| | | if (session == null) { |
| | |
| | | } |
| | | |
| | | @Override |
| | | /** |
| | | * 安全读取会话,并校验租户与用户归属。 |
| | | */ |
| | | public synchronized AiChatSession getSession(Long tenantId, Long userId, String sessionId) { |
| | | if (sessionId == null || sessionId.trim().isEmpty()) { |
| | | return null; |
| | | } |
| | | if (useDatabaseStorage()) { |
| | | AiChatSession session = aiChatSessionMapper.selectById(sessionId); |
| | | if (session == null) { |
| | | return null; |
| | | } |
| | | if (!Objects.equals(tenantId, session.getTenantId()) || !Objects.equals(userId, session.getUserId())) { |
| | | return null; |
| | | } |
| | | return session; |
| | | } |
| | | for (AiChatSession session : getSessions(tenantId, userId)) { |
| | | if (sessionId.equals(session.getId())) { |
| | |
| | | } |
| | | |
| | | @Override |
| | | /** |
| | | * 更新会话标题。 |
| | | */ |
| | | public synchronized AiChatSession renameSession(Long tenantId, Long userId, String sessionId, String title) { |
| | | AiChatSession session = getSession(tenantId, userId, sessionId); |
| | | if (session == null) { |
| | |
| | | } |
| | | |
| | | @Override |
| | | /** |
| | | * 删除会话及其关联消息,同时清理停止标记缓存。 |
| | | */ |
| | | public synchronized void removeSession(Long tenantId, Long userId, String sessionId) { |
| | | if (useDatabaseStorage()) { |
| | | AiChatSession session = getSession(tenantId, userId, sessionId); |
| | | if (session != null) { |
| | | aiChatMessageMapper.delete(new LambdaQueryWrapper<AiChatMessage>() |
| | | .eq(AiChatMessage::getSessionId, sessionId)); |
| | | aiChatSessionMapper.deleteById(sessionId); |
| | | } |
| | | LOCAL_STOP_CACHE.remove(sessionId); |
| | | return; |
| | | } |
| | | List<AiChatSession> sessions = getSessions(tenantId, userId); |
| | | sessions.removeIf(session -> sessionId.equals(session.getId())); |
| | | saveSessions(tenantId, userId, sessions); |
| | |
| | | } |
| | | |
| | | @Override |
| | | /** |
| | | * 查询会话的完整消息历史。 |
| | | */ |
| | | public synchronized List<AiChatMessage> listMessages(Long tenantId, Long userId, String sessionId) { |
| | | AiChatSession session = getSession(tenantId, userId, sessionId); |
| | | if (session == null) { |
| | | return new ArrayList<>(); |
| | | } |
| | | if (useDatabaseStorage()) { |
| | | return aiChatMessageMapper.selectList(new LambdaQueryWrapper<AiChatMessage>() |
| | | .eq(AiChatMessage::getSessionId, sessionId) |
| | | .orderByAsc(AiChatMessage::getCreateTime, AiChatMessage::getId)); |
| | | } |
| | | return getMessages(sessionId); |
| | | } |
| | | |
| | | @Override |
| | | /** |
| | | * 截取最近若干条消息作为模型上下文,避免每次都把完整历史发送给模型。 |
| | | */ |
| | | public synchronized List<AiChatMessage> listContextMessages(Long tenantId, Long userId, String sessionId, int maxCount) { |
| | | List<AiChatMessage> messages = listMessages(tenantId, userId, sessionId); |
| | | if (messages.size() <= maxCount) { |
| | |
| | | } |
| | | |
| | | @Override |
| | | /** |
| | | * 追加一条消息,并同步刷新会话摘要、活跃时间和默认标题。 |
| | | */ |
| | | public synchronized AiChatMessage appendMessage(Long tenantId, Long userId, String sessionId, String role, String content, String modelCode) { |
| | | AiChatSession session = getSession(tenantId, userId, sessionId); |
| | | if (session == null) { |
| | |
| | | List<AiChatMessage> messages = getMessages(sessionId); |
| | | AiChatMessage message = new AiChatMessage() |
| | | .setId(UUID.randomUUID().toString().replace("-", "")) |
| | | .setTenantId(tenantId) |
| | | .setUserId(userId) |
| | | .setSessionId(sessionId) |
| | | .setRole(role) |
| | | .setContent(content) |
| | | .setModelCode(resolveModelCode(modelCode)) |
| | | .setCreateTime(new Date()); |
| | | messages.add(message); |
| | | saveMessages(sessionId, messages); |
| | | .setCreateTime(new Date()) |
| | | .setStatus(1) |
| | | .setDeleted(0); |
| | | if (useDatabaseStorage()) { |
| | | aiChatMessageMapper.insert(message); |
| | | } else { |
| | | messages.add(message); |
| | | saveMessages(sessionId, messages); |
| | | } |
| | | session.setLastMessage(buildPreview(content)); |
| | | session.setLastMessageAt(message.getCreateTime()); |
| | | session.setUpdateTime(message.getCreateTime()); |
| | |
| | | } |
| | | |
| | | @Override |
| | | /** |
| | | * 清除停止生成标记。 |
| | | */ |
| | | public void clearStopFlag(String sessionId) { |
| | | LOCAL_STOP_CACHE.remove(sessionId); |
| | | } |
| | | |
| | | @Override |
| | | /** |
| | | * 标记会话需要停止生成。 |
| | | */ |
| | | public void requestStop(String sessionId) { |
| | | LOCAL_STOP_CACHE.put(sessionId, "1"); |
| | | } |
| | | |
| | | @Override |
| | | /** |
| | | * 读取停止生成标记。 |
| | | */ |
| | | public boolean isStopRequested(String sessionId) { |
| | | String stopFlag = LOCAL_STOP_CACHE.get(sessionId); |
| | | return "1".equals(stopFlag); |
| | | } |
| | | |
| | | /** |
| | | * 从内存缓存中读取当前用户的会话列表。 |
| | | */ |
| | | private List<AiChatSession> getSessions(Long tenantId, Long userId) { |
| | | String ownerKey = buildOwnerKey(tenantId, userId); |
| | | List<AiChatSession> sessions = LOCAL_SESSION_CACHE.get(ownerKey); |
| | | return sessions == null ? new ArrayList<>() : new ArrayList<>(sessions); |
| | | } |
| | | |
| | | /** |
| | | * 将会话列表写回本地缓存。 |
| | | */ |
| | | private void saveSessions(Long tenantId, Long userId, List<AiChatSession> sessions) { |
| | | String ownerKey = buildOwnerKey(tenantId, userId); |
| | | List<AiChatSession> cachedSessions = new ArrayList<>(sessions); |
| | | LOCAL_SESSION_CACHE.put(ownerKey, cachedSessions); |
| | | } |
| | | |
| | | /** |
| | | * 从内存缓存中读取指定会话的消息列表。 |
| | | */ |
| | | private List<AiChatMessage> getMessages(String sessionId) { |
| | | List<AiChatMessage> messages = LOCAL_MESSAGE_CACHE.get(sessionId); |
| | | return messages == null ? new ArrayList<>() : new ArrayList<>(messages); |
| | | } |
| | | |
| | | /** |
| | | * 将消息列表写回本地缓存。 |
| | | */ |
| | | private void saveMessages(String sessionId, List<AiChatMessage> messages) { |
| | | List<AiChatMessage> cachedMessages = new ArrayList<>(messages); |
| | | LOCAL_MESSAGE_CACHE.put(sessionId, cachedMessages); |
| | | } |
| | | |
| | | /** |
| | | * 按存储模式刷新单个会话记录。 |
| | | */ |
| | | private void refreshSession(Long tenantId, Long userId, AiChatSession target) { |
| | | if (useDatabaseStorage()) { |
| | | aiChatSessionMapper.updateById(target); |
| | | return; |
| | | } |
| | | List<AiChatSession> sessions = getSessions(tenantId, userId); |
| | | for (int i = 0; i < sessions.size(); i++) { |
| | | if (target.getId().equals(sessions.get(i).getId())) { |
| | |
| | | saveSessions(tenantId, userId, sessions); |
| | | } |
| | | |
| | | /** |
| | | * 组装租户与用户维度的本地缓存 key。 |
| | | */ |
| | | private String buildOwnerKey(Long tenantId, Long userId) { |
| | | return String.valueOf(tenantId) + ":" + String.valueOf(userId); |
| | | } |
| | | |
| | | /** |
| | | * 解析本次消息使用的模型编码;为空时回退到系统默认模型。 |
| | | */ |
| | | private String resolveModelCode(String modelCode) { |
| | | return modelCode == null || modelCode.trim().isEmpty() ? aiRuntimeConfigService.resolveDefaultModelCode() : modelCode; |
| | | } |
| | | |
| | | /** |
| | | * 生成会话标题,未显式传标题时使用“新对话 N”。 |
| | | */ |
| | | private String resolveTitle(String title, int index) { |
| | | if (title == null || title.trim().isEmpty()) { |
| | | return "新对话 " + index; |
| | |
| | | return buildPreview(title); |
| | | } |
| | | |
| | | /** |
| | | * 将用户输入压缩成适合作为标题或最后消息预览的短文本。 |
| | | */ |
| | | private String buildPreview(String content) { |
| | | if (content == null || content.trim().isEmpty()) { |
| | | return "新对话"; |
| | |
| | | return normalized.length() > 24 ? normalized.substring(0, 24) : normalized; |
| | | } |
| | | |
| | | /** |
| | | * 判断当前是否可以使用数据库持久化聊天数据。 |
| | | */ |
| | | private boolean useDatabaseStorage() { |
| | | return storageReady || (storageReady = detectStorageTables()); |
| | | } |
| | | |
| | | /** |
| | | * 检查聊天存储所需表是否已经存在。 |
| | | */ |
| | | private boolean detectStorageTables() { |
| | | try (Connection connection = dataSource.getConnection()) { |
| | | return tableExists(connection, SESSION_TABLE_NAME) && tableExists(connection, MESSAGE_TABLE_NAME); |
| | | } catch (Exception ignore) { |
| | | return false; |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * 判断指定表名是否在当前数据库中存在。 |
| | | */ |
| | | private boolean tableExists(Connection connection, String tableName) throws Exception { |
| | | if (tableName == null || tableName.trim().isEmpty()) { |
| | | return false; |
| | | } |
| | | String[] candidates = new String[]{tableName, tableName.toUpperCase(), tableName.toLowerCase()}; |
| | | for (String candidate : candidates) { |
| | | try (ResultSet resultSet = connection.getMetaData().getTables(connection.getCatalog(), null, candidate, null)) { |
| | | if (resultSet.next()) { |
| | | return true; |
| | | } |
| | | } |
| | | } |
| | | return false; |
| | | } |
| | | |
| | | } |
| | | |