#AI
zhou zhou
4 小时以前 51877df13075ad10ef51107f15bcd21f1661febe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
package com.vincent.rsf.server.ai.service.impl;
 
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.vincent.rsf.server.ai.mapper.AiChatMessageMapper;
import com.vincent.rsf.server.ai.mapper.AiChatSessionMapper;
import com.vincent.rsf.server.ai.model.AiChatMessage;
import com.vincent.rsf.server.ai.model.AiChatSession;
import com.vincent.rsf.server.ai.service.AiRuntimeConfigService;
import com.vincent.rsf.server.ai.service.AiSessionService;
import org.springframework.stereotype.Service;
 
import javax.annotation.PostConstruct;
import javax.annotation.Resource;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.ResultSet;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Date;
import java.util.List;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
 
@Service
public class AiSessionServiceImpl implements AiSessionService {
 
    private static final ConcurrentMap<String, List<AiChatSession>> LOCAL_SESSION_CACHE = new ConcurrentHashMap<>();
    private static final ConcurrentMap<String, List<AiChatMessage>> LOCAL_MESSAGE_CACHE = new ConcurrentHashMap<>();
    private static final ConcurrentMap<String, String> LOCAL_STOP_CACHE = new ConcurrentHashMap<>();
    private static final String SESSION_TABLE_NAME = "sys_ai_chat_session";
    private static final String MESSAGE_TABLE_NAME = "sys_ai_chat_message";
 
    @Resource
    private AiRuntimeConfigService aiRuntimeConfigService;
    @Resource
    private AiChatSessionMapper aiChatSessionMapper;
    @Resource
    private AiChatMessageMapper aiChatMessageMapper;
    @Resource
    private DataSource dataSource;
 
    private volatile boolean storageReady;
 
    @PostConstruct
    /**
     * 启动时探测聊天存储表是否已创建。
     * 如果表存在则走数据库持久化,否则回退到本地内存缓存,保证开发和缺表场景可继续运行。
     */
    public void initStorageMode() {
        storageReady = detectStorageTables();
    }
 
    @Override
    /**
     * 读取用户会话列表。
     * 数据库存储模式直接查表,内存模式则从本地缓存取出并按最近更新时间排序。
     */
    public synchronized List<AiChatSession> listSessions(Long tenantId, Long userId) {
        if (useDatabaseStorage()) {
            return aiChatSessionMapper.selectList(new LambdaQueryWrapper<AiChatSession>()
                    .eq(AiChatSession::getTenantId, tenantId)
                    .eq(AiChatSession::getUserId, userId)
                    .orderByDesc(AiChatSession::getUpdateTime, AiChatSession::getCreateTime));
        }
        List<AiChatSession> sessions = getSessions(tenantId, userId);
        sessions.sort(Comparator.comparing(AiChatSession::getUpdateTime, Comparator.nullsLast(Date::compareTo)).reversed());
        return sessions;
    }
 
    @Override
    /**
     * 创建新会话,并初始化标题、模型和时间戳。
     */
    public synchronized AiChatSession createSession(Long tenantId, Long userId, String title, String modelCode) {
        List<AiChatSession> sessions = useDatabaseStorage() ? listSessions(tenantId, userId) : getSessions(tenantId, userId);
        Date now = new Date();
        AiChatSession session = new AiChatSession()
                .setId(UUID.randomUUID().toString().replace("-", ""))
                .setTenantId(tenantId)
                .setUserId(userId)
                .setTitle(resolveTitle(title, sessions.size() + 1))
                .setModelCode(resolveModelCode(modelCode))
                .setCreateTime(now)
                .setUpdateTime(now)
                .setLastMessageAt(now)
                .setStatus(1)
                .setDeleted(0);
        if (useDatabaseStorage()) {
            aiChatSessionMapper.insert(session);
            return session;
        }
        sessions.add(0, session);
        saveSessions(tenantId, userId, sessions);
        saveMessages(session.getId(), new ArrayList<>());
        return session;
    }
 
    @Override
    /**
     * 确保会话存在;如果会话已存在但模型发生变化,会同步更新会话记录。
     */
    public synchronized AiChatSession ensureSession(Long tenantId, Long userId, String sessionId, String modelCode) {
        AiChatSession session = getSession(tenantId, userId, sessionId);
        if (session == null) {
            return createSession(tenantId, userId, null, modelCode);
        }
        if (modelCode != null && !modelCode.trim().isEmpty() && !modelCode.equals(session.getModelCode())) {
            session.setModelCode(modelCode);
            session.setUpdateTime(new Date());
            refreshSession(tenantId, userId, session);
        }
        return session;
    }
 
    @Override
    /**
     * 安全读取会话,并校验租户与用户归属。
     */
    public synchronized AiChatSession getSession(Long tenantId, Long userId, String sessionId) {
        if (sessionId == null || sessionId.trim().isEmpty()) {
            return null;
        }
        if (useDatabaseStorage()) {
            AiChatSession session = aiChatSessionMapper.selectById(sessionId);
            if (session == null) {
                return null;
            }
            if (!Objects.equals(tenantId, session.getTenantId()) || !Objects.equals(userId, session.getUserId())) {
                return null;
            }
            return session;
        }
        for (AiChatSession session : getSessions(tenantId, userId)) {
            if (sessionId.equals(session.getId())) {
                return session;
            }
        }
        return null;
    }
 
    @Override
    /**
     * 更新会话标题。
     */
    public synchronized AiChatSession renameSession(Long tenantId, Long userId, String sessionId, String title) {
        AiChatSession session = getSession(tenantId, userId, sessionId);
        if (session == null) {
            return null;
        }
        session.setTitle(resolveTitle(title, 1));
        session.setUpdateTime(new Date());
        refreshSession(tenantId, userId, session);
        return session;
    }
 
    @Override
    /**
     * 删除会话及其关联消息,同时清理停止标记缓存。
     */
    public synchronized void removeSession(Long tenantId, Long userId, String sessionId) {
        if (useDatabaseStorage()) {
            AiChatSession session = getSession(tenantId, userId, sessionId);
            if (session != null) {
                aiChatMessageMapper.delete(new LambdaQueryWrapper<AiChatMessage>()
                        .eq(AiChatMessage::getSessionId, sessionId));
                aiChatSessionMapper.deleteById(sessionId);
            }
            LOCAL_STOP_CACHE.remove(sessionId);
            return;
        }
        List<AiChatSession> sessions = getSessions(tenantId, userId);
        sessions.removeIf(session -> sessionId.equals(session.getId()));
        saveSessions(tenantId, userId, sessions);
        LOCAL_MESSAGE_CACHE.remove(sessionId);
        LOCAL_STOP_CACHE.remove(sessionId);
    }
 
    @Override
    /**
     * 查询会话的完整消息历史。
     */
    public synchronized List<AiChatMessage> listMessages(Long tenantId, Long userId, String sessionId) {
        AiChatSession session = getSession(tenantId, userId, sessionId);
        if (session == null) {
            return new ArrayList<>();
        }
        if (useDatabaseStorage()) {
            return aiChatMessageMapper.selectList(new LambdaQueryWrapper<AiChatMessage>()
                    .eq(AiChatMessage::getSessionId, sessionId)
                    .orderByAsc(AiChatMessage::getCreateTime, AiChatMessage::getId));
        }
        return getMessages(sessionId);
    }
 
    @Override
    /**
     * 截取最近若干条消息作为模型上下文,避免每次都把完整历史发送给模型。
     */
    public synchronized List<AiChatMessage> listContextMessages(Long tenantId, Long userId, String sessionId, int maxCount) {
        List<AiChatMessage> messages = listMessages(tenantId, userId, sessionId);
        if (messages.size() <= maxCount) {
            return messages;
        }
        return new ArrayList<>(messages.subList(messages.size() - maxCount, messages.size()));
    }
 
    @Override
    /**
     * 追加一条消息,并同步刷新会话摘要、活跃时间和默认标题。
     */
    public synchronized AiChatMessage appendMessage(Long tenantId, Long userId, String sessionId, String role, String content, String modelCode) {
        AiChatSession session = getSession(tenantId, userId, sessionId);
        if (session == null) {
            return null;
        }
        List<AiChatMessage> messages = getMessages(sessionId);
        AiChatMessage message = new AiChatMessage()
                .setId(UUID.randomUUID().toString().replace("-", ""))
                .setTenantId(tenantId)
                .setUserId(userId)
                .setSessionId(sessionId)
                .setRole(role)
                .setContent(content)
                .setModelCode(resolveModelCode(modelCode))
                .setCreateTime(new Date())
                .setStatus(1)
                .setDeleted(0);
        if (useDatabaseStorage()) {
            aiChatMessageMapper.insert(message);
        } else {
            messages.add(message);
            saveMessages(sessionId, messages);
        }
        session.setLastMessage(buildPreview(content));
        session.setLastMessageAt(message.getCreateTime());
        session.setUpdateTime(message.getCreateTime());
        if (modelCode != null && !modelCode.trim().isEmpty()) {
            session.setModelCode(modelCode);
        }
        if ((session.getTitle() == null || session.getTitle().startsWith("新对话")) && "user".equalsIgnoreCase(role)) {
            session.setTitle(buildPreview(content));
        }
        refreshSession(tenantId, userId, session);
        return message;
    }
 
    @Override
    /**
     * 清除停止生成标记。
     */
    public void clearStopFlag(String sessionId) {
        LOCAL_STOP_CACHE.remove(sessionId);
    }
 
    @Override
    /**
     * 标记会话需要停止生成。
     */
    public void requestStop(String sessionId) {
        LOCAL_STOP_CACHE.put(sessionId, "1");
    }
 
    @Override
    /**
     * 读取停止生成标记。
     */
    public boolean isStopRequested(String sessionId) {
        String stopFlag = LOCAL_STOP_CACHE.get(sessionId);
        return "1".equals(stopFlag);
    }
 
    /**
     * 从内存缓存中读取当前用户的会话列表。
     */
    private List<AiChatSession> getSessions(Long tenantId, Long userId) {
        String ownerKey = buildOwnerKey(tenantId, userId);
        List<AiChatSession> sessions = LOCAL_SESSION_CACHE.get(ownerKey);
        return sessions == null ? new ArrayList<>() : new ArrayList<>(sessions);
    }
 
    /**
     * 将会话列表写回本地缓存。
     */
    private void saveSessions(Long tenantId, Long userId, List<AiChatSession> sessions) {
        String ownerKey = buildOwnerKey(tenantId, userId);
        List<AiChatSession> cachedSessions = new ArrayList<>(sessions);
        LOCAL_SESSION_CACHE.put(ownerKey, cachedSessions);
    }
 
    /**
     * 从内存缓存中读取指定会话的消息列表。
     */
    private List<AiChatMessage> getMessages(String sessionId) {
        List<AiChatMessage> messages = LOCAL_MESSAGE_CACHE.get(sessionId);
        return messages == null ? new ArrayList<>() : new ArrayList<>(messages);
    }
 
    /**
     * 将消息列表写回本地缓存。
     */
    private void saveMessages(String sessionId, List<AiChatMessage> messages) {
        List<AiChatMessage> cachedMessages = new ArrayList<>(messages);
        LOCAL_MESSAGE_CACHE.put(sessionId, cachedMessages);
    }
 
    /**
     * 按存储模式刷新单个会话记录。
     */
    private void refreshSession(Long tenantId, Long userId, AiChatSession target) {
        if (useDatabaseStorage()) {
            aiChatSessionMapper.updateById(target);
            return;
        }
        List<AiChatSession> sessions = getSessions(tenantId, userId);
        for (int i = 0; i < sessions.size(); i++) {
            if (target.getId().equals(sessions.get(i).getId())) {
                sessions.set(i, target);
                saveSessions(tenantId, userId, sessions);
                return;
            }
        }
        sessions.add(target);
        saveSessions(tenantId, userId, sessions);
    }
 
    /**
     * 组装租户与用户维度的本地缓存 key。
     */
    private String buildOwnerKey(Long tenantId, Long userId) {
        return String.valueOf(tenantId) + ":" + String.valueOf(userId);
    }
 
    /**
     * 解析本次消息使用的模型编码;为空时回退到系统默认模型。
     */
    private String resolveModelCode(String modelCode) {
        return modelCode == null || modelCode.trim().isEmpty() ? aiRuntimeConfigService.resolveDefaultModelCode() : modelCode;
    }
 
    /**
     * 生成会话标题,未显式传标题时使用“新对话 N”。
     */
    private String resolveTitle(String title, int index) {
        if (title == null || title.trim().isEmpty()) {
            return "新对话 " + index;
        }
        return buildPreview(title);
    }
 
    /**
     * 将用户输入压缩成适合作为标题或最后消息预览的短文本。
     */
    private String buildPreview(String content) {
        if (content == null || content.trim().isEmpty()) {
            return "新对话";
        }
        String normalized = content.replace("\r", " ").replace("\n", " ").trim();
        return normalized.length() > 24 ? normalized.substring(0, 24) : normalized;
    }
 
    /**
     * 判断当前是否可以使用数据库持久化聊天数据。
     */
    private boolean useDatabaseStorage() {
        return storageReady || (storageReady = detectStorageTables());
    }
 
    /**
     * 检查聊天存储所需表是否已经存在。
     */
    private boolean detectStorageTables() {
        try (Connection connection = dataSource.getConnection()) {
            return tableExists(connection, SESSION_TABLE_NAME) && tableExists(connection, MESSAGE_TABLE_NAME);
        } catch (Exception ignore) {
            return false;
        }
    }
 
    /**
     * 判断指定表名是否在当前数据库中存在。
     */
    private boolean tableExists(Connection connection, String tableName) throws Exception {
        if (tableName == null || tableName.trim().isEmpty()) {
            return false;
        }
        String[] candidates = new String[]{tableName, tableName.toUpperCase(), tableName.toLowerCase()};
        for (String candidate : candidates) {
            try (ResultSet resultSet = connection.getMetaData().getTables(connection.getCatalog(), null, candidate, null)) {
                if (resultSet.next()) {
                    return true;
                }
            }
        }
        return false;
    }
 
}