#AI
zhou zhou
11 小时以前 51877df13075ad10ef51107f15bcd21f1661febe
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiSessionServiceImpl.java
@@ -1,16 +1,24 @@
package com.vincent.rsf.server.ai.service.impl;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.vincent.rsf.server.ai.mapper.AiChatMessageMapper;
import com.vincent.rsf.server.ai.mapper.AiChatSessionMapper;
import com.vincent.rsf.server.ai.model.AiChatMessage;
import com.vincent.rsf.server.ai.model.AiChatSession;
import com.vincent.rsf.server.ai.service.AiRuntimeConfigService;
import com.vincent.rsf.server.ai.service.AiSessionService;
import org.springframework.stereotype.Service;
import javax.annotation.PostConstruct;
import javax.annotation.Resource;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.ResultSet;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Date;
import java.util.List;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
@@ -21,28 +29,68 @@
    private static final ConcurrentMap<String, List<AiChatSession>> LOCAL_SESSION_CACHE = new ConcurrentHashMap<>();
    private static final ConcurrentMap<String, List<AiChatMessage>> LOCAL_MESSAGE_CACHE = new ConcurrentHashMap<>();
    private static final ConcurrentMap<String, String> LOCAL_STOP_CACHE = new ConcurrentHashMap<>();
    private static final String SESSION_TABLE_NAME = "sys_ai_chat_session";
    private static final String MESSAGE_TABLE_NAME = "sys_ai_chat_message";
    @Resource
    private AiRuntimeConfigService aiRuntimeConfigService;
    @Resource
    private AiChatSessionMapper aiChatSessionMapper;
    @Resource
    private AiChatMessageMapper aiChatMessageMapper;
    @Resource
    private DataSource dataSource;
    private volatile boolean storageReady;
    @PostConstruct
    /**
     * 启动时探测聊天存储表是否已创建。
     * 如果表存在则走数据库持久化,否则回退到本地内存缓存,保证开发和缺表场景可继续运行。
     */
    public void initStorageMode() {
        storageReady = detectStorageTables();
    }
    @Override
    /**
     * 读取用户会话列表。
     * 数据库存储模式直接查表,内存模式则从本地缓存取出并按最近更新时间排序。
     */
    public synchronized List<AiChatSession> listSessions(Long tenantId, Long userId) {
        if (useDatabaseStorage()) {
            return aiChatSessionMapper.selectList(new LambdaQueryWrapper<AiChatSession>()
                    .eq(AiChatSession::getTenantId, tenantId)
                    .eq(AiChatSession::getUserId, userId)
                    .orderByDesc(AiChatSession::getUpdateTime, AiChatSession::getCreateTime));
        }
        List<AiChatSession> sessions = getSessions(tenantId, userId);
        sessions.sort(Comparator.comparing(AiChatSession::getUpdateTime, Comparator.nullsLast(Date::compareTo)).reversed());
        return sessions;
    }
    @Override
    /**
     * 创建新会话,并初始化标题、模型和时间戳。
     */
    public synchronized AiChatSession createSession(Long tenantId, Long userId, String title, String modelCode) {
        List<AiChatSession> sessions = getSessions(tenantId, userId);
        List<AiChatSession> sessions = useDatabaseStorage() ? listSessions(tenantId, userId) : getSessions(tenantId, userId);
        Date now = new Date();
        AiChatSession session = new AiChatSession()
                .setId(UUID.randomUUID().toString().replace("-", ""))
                .setTenantId(tenantId)
                .setUserId(userId)
                .setTitle(resolveTitle(title, sessions.size() + 1))
                .setModelCode(resolveModelCode(modelCode))
                .setCreateTime(now)
                .setUpdateTime(now)
                .setLastMessageAt(now);
                .setLastMessageAt(now)
                .setStatus(1)
                .setDeleted(0);
        if (useDatabaseStorage()) {
            aiChatSessionMapper.insert(session);
            return session;
        }
        sessions.add(0, session);
        saveSessions(tenantId, userId, sessions);
        saveMessages(session.getId(), new ArrayList<>());
@@ -50,6 +98,9 @@
    }
    @Override
    /**
     * 确保会话存在;如果会话已存在但模型发生变化,会同步更新会话记录。
     */
    public synchronized AiChatSession ensureSession(Long tenantId, Long userId, String sessionId, String modelCode) {
        AiChatSession session = getSession(tenantId, userId, sessionId);
        if (session == null) {
@@ -64,9 +115,22 @@
    }
    @Override
    /**
     * 安全读取会话,并校验租户与用户归属。
     */
    public synchronized AiChatSession getSession(Long tenantId, Long userId, String sessionId) {
        if (sessionId == null || sessionId.trim().isEmpty()) {
            return null;
        }
        if (useDatabaseStorage()) {
            AiChatSession session = aiChatSessionMapper.selectById(sessionId);
            if (session == null) {
                return null;
            }
            if (!Objects.equals(tenantId, session.getTenantId()) || !Objects.equals(userId, session.getUserId())) {
                return null;
            }
            return session;
        }
        for (AiChatSession session : getSessions(tenantId, userId)) {
            if (sessionId.equals(session.getId())) {
@@ -77,6 +141,9 @@
    }
    @Override
    /**
     * 更新会话标题。
     */
    public synchronized AiChatSession renameSession(Long tenantId, Long userId, String sessionId, String title) {
        AiChatSession session = getSession(tenantId, userId, sessionId);
        if (session == null) {
@@ -89,7 +156,20 @@
    }
    @Override
    /**
     * 删除会话及其关联消息,同时清理停止标记缓存。
     */
    public synchronized void removeSession(Long tenantId, Long userId, String sessionId) {
        if (useDatabaseStorage()) {
            AiChatSession session = getSession(tenantId, userId, sessionId);
            if (session != null) {
                aiChatMessageMapper.delete(new LambdaQueryWrapper<AiChatMessage>()
                        .eq(AiChatMessage::getSessionId, sessionId));
                aiChatSessionMapper.deleteById(sessionId);
            }
            LOCAL_STOP_CACHE.remove(sessionId);
            return;
        }
        List<AiChatSession> sessions = getSessions(tenantId, userId);
        sessions.removeIf(session -> sessionId.equals(session.getId()));
        saveSessions(tenantId, userId, sessions);
@@ -98,15 +178,26 @@
    }
    @Override
    /**
     * 查询会话的完整消息历史。
     */
    public synchronized List<AiChatMessage> listMessages(Long tenantId, Long userId, String sessionId) {
        AiChatSession session = getSession(tenantId, userId, sessionId);
        if (session == null) {
            return new ArrayList<>();
        }
        if (useDatabaseStorage()) {
            return aiChatMessageMapper.selectList(new LambdaQueryWrapper<AiChatMessage>()
                    .eq(AiChatMessage::getSessionId, sessionId)
                    .orderByAsc(AiChatMessage::getCreateTime, AiChatMessage::getId));
        }
        return getMessages(sessionId);
    }
    @Override
    /**
     * 截取最近若干条消息作为模型上下文,避免每次都把完整历史发送给模型。
     */
    public synchronized List<AiChatMessage> listContextMessages(Long tenantId, Long userId, String sessionId, int maxCount) {
        List<AiChatMessage> messages = listMessages(tenantId, userId, sessionId);
        if (messages.size() <= maxCount) {
@@ -116,6 +207,9 @@
    }
    @Override
    /**
     * 追加一条消息,并同步刷新会话摘要、活跃时间和默认标题。
     */
    public synchronized AiChatMessage appendMessage(Long tenantId, Long userId, String sessionId, String role, String content, String modelCode) {
        AiChatSession session = getSession(tenantId, userId, sessionId);
        if (session == null) {
@@ -124,13 +218,21 @@
        List<AiChatMessage> messages = getMessages(sessionId);
        AiChatMessage message = new AiChatMessage()
                .setId(UUID.randomUUID().toString().replace("-", ""))
                .setTenantId(tenantId)
                .setUserId(userId)
                .setSessionId(sessionId)
                .setRole(role)
                .setContent(content)
                .setModelCode(resolveModelCode(modelCode))
                .setCreateTime(new Date());
        messages.add(message);
        saveMessages(sessionId, messages);
                .setCreateTime(new Date())
                .setStatus(1)
                .setDeleted(0);
        if (useDatabaseStorage()) {
            aiChatMessageMapper.insert(message);
        } else {
            messages.add(message);
            saveMessages(sessionId, messages);
        }
        session.setLastMessage(buildPreview(content));
        session.setLastMessageAt(message.getCreateTime());
        session.setUpdateTime(message.getCreateTime());
@@ -145,44 +247,72 @@
    }
    @Override
    /**
     * 清除停止生成标记。
     */
    public void clearStopFlag(String sessionId) {
        LOCAL_STOP_CACHE.remove(sessionId);
    }
    @Override
    /**
     * 标记会话需要停止生成。
     */
    public void requestStop(String sessionId) {
        LOCAL_STOP_CACHE.put(sessionId, "1");
    }
    @Override
    /**
     * 读取停止生成标记。
     */
    public boolean isStopRequested(String sessionId) {
        String stopFlag = LOCAL_STOP_CACHE.get(sessionId);
        return "1".equals(stopFlag);
    }
    /**
     * 从内存缓存中读取当前用户的会话列表。
     */
    private List<AiChatSession> getSessions(Long tenantId, Long userId) {
        String ownerKey = buildOwnerKey(tenantId, userId);
        List<AiChatSession> sessions = LOCAL_SESSION_CACHE.get(ownerKey);
        return sessions == null ? new ArrayList<>() : new ArrayList<>(sessions);
    }
    /**
     * 将会话列表写回本地缓存。
     */
    private void saveSessions(Long tenantId, Long userId, List<AiChatSession> sessions) {
        String ownerKey = buildOwnerKey(tenantId, userId);
        List<AiChatSession> cachedSessions = new ArrayList<>(sessions);
        LOCAL_SESSION_CACHE.put(ownerKey, cachedSessions);
    }
    /**
     * 从内存缓存中读取指定会话的消息列表。
     */
    private List<AiChatMessage> getMessages(String sessionId) {
        List<AiChatMessage> messages = LOCAL_MESSAGE_CACHE.get(sessionId);
        return messages == null ? new ArrayList<>() : new ArrayList<>(messages);
    }
    /**
     * 将消息列表写回本地缓存。
     */
    private void saveMessages(String sessionId, List<AiChatMessage> messages) {
        List<AiChatMessage> cachedMessages = new ArrayList<>(messages);
        LOCAL_MESSAGE_CACHE.put(sessionId, cachedMessages);
    }
    /**
     * 按存储模式刷新单个会话记录。
     */
    private void refreshSession(Long tenantId, Long userId, AiChatSession target) {
        if (useDatabaseStorage()) {
            aiChatSessionMapper.updateById(target);
            return;
        }
        List<AiChatSession> sessions = getSessions(tenantId, userId);
        for (int i = 0; i < sessions.size(); i++) {
            if (target.getId().equals(sessions.get(i).getId())) {
@@ -195,14 +325,23 @@
        saveSessions(tenantId, userId, sessions);
    }
    /**
     * 组装租户与用户维度的本地缓存 key。
     */
    private String buildOwnerKey(Long tenantId, Long userId) {
        return String.valueOf(tenantId) + ":" + String.valueOf(userId);
    }
    /**
     * 解析本次消息使用的模型编码;为空时回退到系统默认模型。
     */
    private String resolveModelCode(String modelCode) {
        return modelCode == null || modelCode.trim().isEmpty() ? aiRuntimeConfigService.resolveDefaultModelCode() : modelCode;
    }
    /**
     * 生成会话标题,未显式传标题时使用“新对话 N”。
     */
    private String resolveTitle(String title, int index) {
        if (title == null || title.trim().isEmpty()) {
            return "新对话 " + index;
@@ -210,6 +349,9 @@
        return buildPreview(title);
    }
    /**
     * 将用户输入压缩成适合作为标题或最后消息预览的短文本。
     */
    private String buildPreview(String content) {
        if (content == null || content.trim().isEmpty()) {
            return "新对话";
@@ -218,4 +360,41 @@
        return normalized.length() > 24 ? normalized.substring(0, 24) : normalized;
    }
    /**
     * 判断当前是否可以使用数据库持久化聊天数据。
     */
    private boolean useDatabaseStorage() {
        return storageReady || (storageReady = detectStorageTables());
    }
    /**
     * 检查聊天存储所需表是否已经存在。
     */
    private boolean detectStorageTables() {
        try (Connection connection = dataSource.getConnection()) {
            return tableExists(connection, SESSION_TABLE_NAME) && tableExists(connection, MESSAGE_TABLE_NAME);
        } catch (Exception ignore) {
            return false;
        }
    }
    /**
     * 判断指定表名是否在当前数据库中存在。
     */
    private boolean tableExists(Connection connection, String tableName) throws Exception {
        if (tableName == null || tableName.trim().isEmpty()) {
            return false;
        }
        String[] candidates = new String[]{tableName, tableName.toUpperCase(), tableName.toLowerCase()};
        for (String candidate : candidates) {
            try (ResultSet resultSet = connection.getMetaData().getTables(connection.getCatalog(), null, candidate, null)) {
                if (resultSet.next()) {
                    return true;
                }
            }
        }
        return false;
    }
}