1个文件已删除
33个文件已添加
12个文件已修改
| | |
| | | <scope>system</scope> |
| | | <systemPath>${project.basedir}/src/main/resources/lib/RouteUtils.jar</systemPath> |
| | | </dependency> |
| | | <dependency> |
| | | <groupId>org.springframework.boot</groupId> |
| | | <artifactId>spring-boot-starter-test</artifactId> |
| | | <scope>test</scope> |
| | | </dependency> |
| | | <dependency> |
| | | <groupId>com.h2database</groupId> |
| | | <artifactId>h2</artifactId> |
| | | <scope>test</scope> |
| | | </dependency> |
| | | </dependencies> |
| | | |
| | | <build> |
| | |
| | | import com.baomidou.mybatisplus.core.mapper.BaseMapper; |
| | | import com.vincent.rsf.server.ai.entity.AiChatMessage; |
| | | import org.apache.ibatis.annotations.Mapper; |
| | | import org.apache.ibatis.annotations.Param; |
| | | |
| | | import java.util.List; |
| | | |
| | | /** MyBatis mapper for {@code AiChatMessage}; generic CRUD comes from MyBatis-Plus {@code BaseMapper}. The SQL behind the methods declared here is not visible in this file, so the per-method notes below are assumptions to confirm against the mapper XML. */ @Mapper |
| | | public interface AiChatMessageMapper extends BaseMapper<AiChatMessage> { |
| | | |
| | | /** Batch statement over {@code messages}, bound under the parameter name {@code list}; presumably a multi-row insert returning the affected-row count. */ int insertBatch(@Param("list") List<AiChatMessage> messages); |
| | | |
| | | /** Statement over the given message ids, bound as {@code ids}; presumably flags those rows as deleted instead of removing them. NOTE(review): confirm how an empty list is handled. */ int softDeleteByIds(@Param("ids") List<Long> ids); |
| | | |
| | | /** Statement keyed by one session id, bound as {@code sessionId}; presumably flags every message of that session as deleted. */ int softDeleteBySessionId(@Param("sessionId") Long sessionId); |
| | | |
| | | /** Query over several session ids, bound as {@code sessionIds}; presumably returns the most recent message of each session. */ List<AiChatMessage> selectLatestMessagesBySessionIds(@Param("sessionIds") List<Long> sessionIds); |
| | | } |
| | |
| | | import com.vincent.rsf.server.ai.entity.AiMcpCallLog; |
| | | import com.vincent.rsf.server.ai.mapper.AiCallLogMapper; |
| | | import com.vincent.rsf.server.ai.mapper.AiMcpCallLogMapper; |
| | | import com.vincent.rsf.server.ai.store.AiObserveStatsStore; |
| | | import com.vincent.rsf.server.ai.service.AiCallLogService; |
| | | import lombok.RequiredArgsConstructor; |
| | | import org.springframework.stereotype.Service; |
| | |
| | | private static final Pattern BEARER_PATTERN = Pattern.compile("(?i)(bearer\\s+)([a-z0-9._-]+)"); |
| | | |
| | | private final AiMcpCallLogMapper aiMcpCallLogMapper; |
| | | private final AiRedisSupport aiRedisSupport; |
| | | private final AiObserveStatsStore aiObserveStatsStore; |
| | | |
| | | @Override |
| | | public AiCallLog startCallLog(String requestId, Long sessionId, Long userId, Long tenantId, String promptCode, |
| | |
| | | .setCreateTime(now) |
| | | .setUpdateTime(now); |
| | | this.save(callLog); |
| | | aiRedisSupport.recordObserveCallStarted(tenantId); |
| | | aiObserveStatsStore.recordObserveCallStarted(tenantId); |
| | | return callLog; |
| | | } |
| | | |
| | |
| | | .set(AiCallLog::getUpdateTime, new Date())); |
| | | AiCallLog latest = this.getById(callLogId); |
| | | if (latest != null) { |
| | | aiRedisSupport.recordObserveCallFinished(latest.getTenantId(), status, elapsedMs, firstTokenLatencyMs, totalTokens); |
| | | aiObserveStatsStore.recordObserveCallFinished(latest.getTenantId(), status, elapsedMs, firstTokenLatencyMs, totalTokens); |
| | | } |
| | | } |
| | | |
| | |
| | | .set(AiCallLog::getUpdateTime, new Date())); |
| | | AiCallLog latest = this.getById(callLogId); |
| | | if (latest != null) { |
| | | aiRedisSupport.recordObserveCallFinished(latest.getTenantId(), status, elapsedMs, firstTokenLatencyMs, null); |
| | | aiObserveStatsStore.recordObserveCallFinished(latest.getTenantId(), status, elapsedMs, firstTokenLatencyMs, null); |
| | | } |
| | | } |
| | | |
| | |
| | | .setUserId(userId) |
| | | .setTenantId(tenantId) |
| | | .setCreateTime(new Date())); |
| | | aiRedisSupport.recordObserveToolCall(tenantId, toolName, status); |
| | | aiObserveStatsStore.recordObserveToolCall(tenantId, toolName, status); |
| | | } |
| | | |
| | | @Override |
| | | public AiObserveStatsDto getObserveStats(Long tenantId) { |
| | | return aiRedisSupport.getObserveStats(tenantId, () -> loadObserveStatsFromDatabase(tenantId)); |
| | | return aiObserveStatsStore.getObserveStats(tenantId, () -> loadObserveStatsFromDatabase(tenantId)); |
| | | } |
| | | |
| | | private AiObserveStatsDto loadObserveStatsFromDatabase(Long tenantId) { |
| | |
| | | package com.vincent.rsf.server.ai.service.impl; |
| | | |
| | | import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; |
| | | import com.vincent.rsf.framework.common.Cools; |
| | | import com.vincent.rsf.framework.exception.CoolException; |
| | | import com.vincent.rsf.server.ai.config.AiDefaults; |
| | | import com.vincent.rsf.server.ai.dto.AiChatMemoryDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatMessageDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatSessionDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest; |
| | | import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest; |
| | | import com.vincent.rsf.server.ai.dto.AiChatSessionDto; |
| | | import com.vincent.rsf.server.ai.entity.AiChatMessage; |
| | | import com.vincent.rsf.server.ai.entity.AiChatSession; |
| | | import com.vincent.rsf.server.ai.mapper.AiChatMessageMapper; |
| | | import com.vincent.rsf.server.ai.mapper.AiChatSessionMapper; |
| | | import com.vincent.rsf.server.ai.service.AiChatMemoryService; |
| | | import com.vincent.rsf.server.system.enums.StatusType; |
| | | import com.vincent.rsf.server.ai.service.impl.conversation.AiConversationCommandService; |
| | | import com.vincent.rsf.server.ai.service.impl.conversation.AiConversationQueryService; |
| | | import lombok.RequiredArgsConstructor; |
| | | import lombok.extern.slf4j.Slf4j; |
| | | import org.springframework.beans.factory.annotation.Qualifier; |
| | | import org.springframework.stereotype.Service; |
| | | import org.springframework.util.StringUtils; |
| | | |
| | | import java.util.ArrayList; |
| | | import java.util.Date; |
| | | import java.util.List; |
| | | import java.util.Set; |
| | | import java.util.concurrent.Executor; |
| | | import java.util.concurrent.ConcurrentHashMap; |
| | | |
| | | @Service |
| | | @Slf4j |
| | | @RequiredArgsConstructor |
| | | public class AiChatMemoryServiceImpl implements AiChatMemoryService { |
| | | |
| | | private final AiChatSessionMapper aiChatSessionMapper; |
| | | private final AiChatMessageMapper aiChatMessageMapper; |
| | | private final AiRedisSupport aiRedisSupport; |
| | | @Qualifier("aiMemoryTaskExecutor") |
| | | private final Executor aiMemoryTaskExecutor; |
| | | private final AiConversationQueryService aiConversationQueryService; |
| | | private final AiConversationCommandService aiConversationCommandService; |
| | | |
| | | /** |
| | | * 用两个本地集合把“同一个会话的摘要刷新”合并成串行任务,避免连续消息把重复任务塞满线程池。 |
| | | */ |
| | | private final Set<Long> refreshingSessionIds = ConcurrentHashMap.newKeySet(); |
| | | private final Set<Long> pendingRefreshSessionIds = ConcurrentHashMap.newKeySet(); |
| | | |
| | | /** |
| | | * 读取会话记忆快照。 |
| | | * 返回结果同时包含完整落库历史、短期记忆窗口以及摘要/事实记忆, |
| | | * 便于调用方按不同用途选择数据粒度。 |
| | | */ |
| | | @Override |
| | | public AiChatMemoryDto getMemory(Long userId, Long tenantId, String promptCode, Long sessionId) { |
| | | ensureIdentity(userId, tenantId); |
| | | String resolvedPromptCode = requirePromptCode(promptCode); |
| | | // 会话记忆属于典型“读多写少”数据,先走短 TTL 缓存能明显减轻抽屉初始化和切会话压力。 |
| | | AiChatMemoryDto cached = aiRedisSupport.getMemory(tenantId, userId, resolvedPromptCode, sessionId); |
| | | if (cached != null) { |
| | | return cached; |
| | | } |
| | | AiChatSession session = sessionId == null |
| | | ? findLatestSession(userId, tenantId, resolvedPromptCode) |
| | | : getSession(sessionId, userId, tenantId, resolvedPromptCode); |
| | | AiChatMemoryDto memory; |
| | | if (session == null) { |
| | | memory = AiChatMemoryDto.builder() |
| | | .sessionId(null) |
| | | .memorySummary(null) |
| | | .memoryFacts(null) |
| | | .recentMessageCount(0) |
| | | .persistedMessages(List.of()) |
| | | .shortMemoryMessages(List.of()) |
| | | .build(); |
| | | aiRedisSupport.cacheMemory(tenantId, userId, resolvedPromptCode, sessionId, memory); |
| | | return memory; |
| | | } |
| | | List<AiChatMessageDto> persistedMessages = listMessages(session.getId()); |
| | | List<AiChatMessageDto> shortMemoryMessages = tailMessagesByRounds(persistedMessages, AiDefaults.MEMORY_RECENT_ROUNDS); |
| | | memory = AiChatMemoryDto.builder() |
| | | .sessionId(session.getId()) |
| | | .memorySummary(session.getMemorySummary()) |
| | | .memoryFacts(session.getMemoryFacts()) |
| | | .recentMessageCount(shortMemoryMessages.size()) |
| | | .persistedMessages(persistedMessages) |
| | | .shortMemoryMessages(shortMemoryMessages) |
| | | .build(); |
| | | aiRedisSupport.cacheMemory(tenantId, userId, resolvedPromptCode, session.getId(), memory); |
| | | if (sessionId == null || !session.getId().equals(sessionId)) { |
| | | aiRedisSupport.cacheMemory(tenantId, userId, resolvedPromptCode, null, memory); |
| | | } |
| | | return memory; |
| | | return aiConversationQueryService.getMemory(userId, tenantId, promptCode, sessionId); |
| | | } |
| | | |
| | | /** |
| | | * 查询当前用户在某个 Prompt 下的会话列表。 |
| | | * 列表只返回用于侧边栏展示的摘要信息,不返回完整对话内容。 |
| | | */ |
| | | @Override |
| | | public List<AiChatSessionDto> listSessions(Long userId, Long tenantId, String promptCode, String keyword) { |
| | | ensureIdentity(userId, tenantId); |
| | | String resolvedPromptCode = requirePromptCode(promptCode); |
| | | List<AiChatSessionDto> cached = aiRedisSupport.getSessionList(tenantId, userId, resolvedPromptCode, keyword); |
| | | if (cached != null) { |
| | | return cached; |
| | | } |
| | | List<AiChatSession> sessions = aiChatSessionMapper.selectList(new LambdaQueryWrapper<AiChatSession>() |
| | | .eq(AiChatSession::getUserId, userId) |
| | | .eq(AiChatSession::getTenantId, tenantId) |
| | | .eq(AiChatSession::getPromptCode, resolvedPromptCode) |
| | | .eq(AiChatSession::getDeleted, 0) |
| | | .eq(AiChatSession::getStatus, StatusType.ENABLE.val) |
| | | .like(StringUtils.hasText(keyword), AiChatSession::getTitle, keyword == null ? null : keyword.trim()) |
| | | .orderByDesc(AiChatSession::getPinned) |
| | | .orderByDesc(AiChatSession::getLastMessageTime) |
| | | .orderByDesc(AiChatSession::getId)); |
| | | if (Cools.isEmpty(sessions)) { |
| | | aiRedisSupport.cacheSessionList(tenantId, userId, resolvedPromptCode, keyword, List.of()); |
| | | return List.of(); |
| | | } |
| | | List<AiChatSessionDto> result = new ArrayList<>(); |
| | | for (AiChatSession session : sessions) { |
| | | result.add(buildSessionDto(session)); |
| | | } |
| | | aiRedisSupport.cacheSessionList(tenantId, userId, resolvedPromptCode, keyword, result); |
| | | return result; |
| | | return aiConversationQueryService.listSessions(userId, tenantId, promptCode, keyword); |
| | | } |
| | | |
| | | /** |
| | | * 解析本轮请求应该落到哪个会话。 |
| | | * 如果前端带了 sessionId 则做归属校验并复用;否则自动创建新会话。 |
| | | */ |
| | | @Override |
| | | public AiChatSession resolveSession(Long userId, Long tenantId, String promptCode, Long sessionId, String titleSeed) { |
| | | ensureIdentity(userId, tenantId); |
| | | String resolvedPromptCode = requirePromptCode(promptCode); |
| | | if (sessionId != null) { |
| | | return getSession(sessionId, userId, tenantId, resolvedPromptCode); |
| | | } |
| | | Date now = new Date(); |
| | | AiChatSession session = new AiChatSession() |
| | | .setTitle(buildSessionTitle(titleSeed)) |
| | | .setPromptCode(resolvedPromptCode) |
| | | .setUserId(userId) |
| | | .setTenantId(tenantId) |
| | | .setLastMessageTime(now) |
| | | .setPinned(0) |
| | | .setStatus(StatusType.ENABLE.val) |
| | | .setDeleted(0) |
| | | .setCreateBy(userId) |
| | | .setCreateTime(now) |
| | | .setUpdateBy(userId) |
| | | .setUpdateTime(now); |
| | | aiChatSessionMapper.insert(session); |
| | | evictConversationCaches(tenantId, userId); |
| | | return session; |
| | | return aiConversationCommandService.resolveSession(userId, tenantId, promptCode, sessionId, titleSeed); |
| | | } |
| | | |
| | | /** |
| | | * 落库保存一整轮对话。 |
| | | * 这里会顺序写入本轮用户消息和模型回复,并在最后刷新会话标题和活跃时间。 |
| | | * 记忆画像改为后台异步刷新,避免把摘要重算耗时压在用户本轮响应尾部。 |
| | | */ |
| | | @Override |
| | | public void saveRound(AiChatSession session, Long userId, Long tenantId, List<AiChatMessageDto> memoryMessages, String assistantContent) { |
| | | if (session == null || session.getId() == null) { |
| | | throw new CoolException("AI 会话不存在"); |
| | | } |
| | | ensureIdentity(userId, tenantId); |
| | | List<AiChatMessageDto> normalizedMessages = normalizeMessages(memoryMessages); |
| | | if (normalizedMessages.isEmpty()) { |
| | | throw new CoolException("本轮没有可保存的对话消息"); |
| | | } |
| | | int nextSeqNo = findNextSeqNo(session.getId()); |
| | | Date now = new Date(); |
| | | for (AiChatMessageDto message : normalizedMessages) { |
| | | aiChatMessageMapper.insert(buildMessageEntity(session.getId(), nextSeqNo++, message.getRole(), message.getContent(), userId, tenantId, now)); |
| | | } |
| | | if (StringUtils.hasText(assistantContent)) { |
| | | aiChatMessageMapper.insert(buildMessageEntity(session.getId(), nextSeqNo, "assistant", assistantContent, userId, tenantId, now)); |
| | | } |
| | | AiChatSession update = new AiChatSession() |
| | | .setId(session.getId()) |
| | | .setTitle(resolveUpdatedTitle(session.getTitle(), normalizedMessages)) |
| | | .setLastMessageTime(now) |
| | | .setUpdateBy(userId) |
| | | .setUpdateTime(now); |
| | | aiChatSessionMapper.updateById(update); |
| | | evictConversationCaches(tenantId, userId); |
| | | scheduleMemoryProfileRefresh(session.getId(), userId, tenantId); |
| | | aiConversationCommandService.saveRound(session, userId, tenantId, memoryMessages, assistantContent); |
| | | } |
| | | |
| | | /** 删除整个会话及其消息。 */ |
| | | @Override |
| | | public void removeSession(Long userId, Long tenantId, Long sessionId) { |
| | | ensureIdentity(userId, tenantId); |
| | | if (sessionId == null) { |
| | | throw new CoolException("AI 会话 ID 不能为空"); |
| | | } |
| | | AiChatSession session = aiChatSessionMapper.selectOne(new LambdaQueryWrapper<AiChatSession>() |
| | | .eq(AiChatSession::getId, sessionId) |
| | | .eq(AiChatSession::getUserId, userId) |
| | | .eq(AiChatSession::getTenantId, tenantId) |
| | | .eq(AiChatSession::getDeleted, 0) |
| | | .last("limit 1")); |
| | | if (session == null) { |
| | | throw new CoolException("AI 会话不存在或无权删除"); |
| | | } |
| | | Date now = new Date(); |
| | | AiChatSession updateSession = new AiChatSession() |
| | | .setId(sessionId) |
| | | .setDeleted(1) |
| | | .setUpdateBy(userId) |
| | | .setUpdateTime(now); |
| | | aiChatSessionMapper.updateById(updateSession); |
| | | List<AiChatMessage> messages = aiChatMessageMapper.selectList(new LambdaQueryWrapper<AiChatMessage>() |
| | | .eq(AiChatMessage::getSessionId, sessionId) |
| | | .eq(AiChatMessage::getDeleted, 0)); |
| | | for (AiChatMessage message : messages) { |
| | | AiChatMessage updateMessage = new AiChatMessage() |
| | | .setId(message.getId()) |
| | | .setDeleted(1); |
| | | aiChatMessageMapper.updateById(updateMessage); |
| | | } |
| | | evictConversationCaches(tenantId, userId); |
| | | aiConversationCommandService.removeSession(userId, tenantId, sessionId); |
| | | } |
| | | |
| | | /** 更新会话标题并返回最新会话摘要。 */ |
| | | @Override |
| | | public AiChatSessionDto renameSession(Long userId, Long tenantId, Long sessionId, AiChatSessionRenameRequest request) { |
| | | ensureIdentity(userId, tenantId); |
| | | if (request == null || !StringUtils.hasText(request.getTitle())) { |
| | | throw new CoolException("会话标题不能为空"); |
| | | } |
| | | AiChatSession session = requireOwnedSession(sessionId, userId, tenantId); |
| | | Date now = new Date(); |
| | | AiChatSession update = new AiChatSession() |
| | | .setId(sessionId) |
| | | .setTitle(buildSessionTitle(request.getTitle())) |
| | | .setUpdateBy(userId) |
| | | .setUpdateTime(now); |
| | | aiChatSessionMapper.updateById(update); |
| | | AiChatSessionDto sessionDto = buildSessionDto(requireOwnedSession(sessionId, userId, tenantId)); |
| | | evictConversationCaches(tenantId, userId); |
| | | return sessionDto; |
| | | return aiConversationCommandService.renameSession(userId, tenantId, sessionId, request); |
| | | } |
| | | |
| | | /** 更新会话置顶状态。 */ |
| | | @Override |
| | | public AiChatSessionDto pinSession(Long userId, Long tenantId, Long sessionId, AiChatSessionPinRequest request) { |
| | | ensureIdentity(userId, tenantId); |
| | | if (request == null || request.getPinned() == null) { |
| | | throw new CoolException("置顶状态不能为空"); |
| | | } |
| | | AiChatSession session = requireOwnedSession(sessionId, userId, tenantId); |
| | | Date now = new Date(); |
| | | AiChatSession update = new AiChatSession() |
| | | .setId(sessionId) |
| | | .setPinned(Boolean.TRUE.equals(request.getPinned()) ? 1 : 0) |
| | | .setUpdateBy(userId) |
| | | .setUpdateTime(now); |
| | | aiChatSessionMapper.updateById(update); |
| | | AiChatSessionDto sessionDto = buildSessionDto(requireOwnedSession(sessionId, userId, tenantId)); |
| | | evictConversationCaches(tenantId, userId); |
| | | return sessionDto; |
| | | return aiConversationCommandService.pinSession(userId, tenantId, sessionId, request); |
| | | } |
| | | |
| | | /** 清空某个会话的全部消息和派生记忆字段。 */ |
| | | @Override |
| | | public void clearSessionMemory(Long userId, Long tenantId, Long sessionId) { |
| | | ensureIdentity(userId, tenantId); |
| | | AiChatSession session = requireOwnedSession(sessionId, userId, tenantId); |
| | | List<AiChatMessage> messages = aiChatMessageMapper.selectList(new LambdaQueryWrapper<AiChatMessage>() |
| | | .eq(AiChatMessage::getSessionId, sessionId) |
| | | .eq(AiChatMessage::getDeleted, 0)); |
| | | for (AiChatMessage message : messages) { |
| | | aiChatMessageMapper.updateById(new AiChatMessage() |
| | | .setId(message.getId()) |
| | | .setDeleted(1)); |
| | | } |
| | | aiChatSessionMapper.updateById(new AiChatSession() |
| | | .setId(sessionId) |
| | | .setMemorySummary(null) |
| | | .setMemoryFacts(null) |
| | | .setUpdateBy(userId) |
| | | .setUpdateTime(new Date()) |
| | | .setLastMessageTime(session.getCreateTime())); |
| | | evictConversationCaches(tenantId, userId); |
| | | aiConversationCommandService.clearSessionMemory(userId, tenantId, sessionId); |
| | | } |
| | | |
| | | /** 只保留最近一轮问答,用于手动裁剪长会话。 */ |
| | | @Override |
| | | public void retainLatestRound(Long userId, Long tenantId, Long sessionId) { |
| | | ensureIdentity(userId, tenantId); |
| | | requireOwnedSession(sessionId, userId, tenantId); |
| | | List<AiChatMessage> records = listMessageRecords(sessionId); |
| | | if (records.isEmpty()) { |
| | | return; |
| | | } |
| | | List<AiChatMessage> retained = tailMessageRecordsByRounds(records, 1); |
| | | for (AiChatMessage message : records) { |
| | | boolean shouldKeep = retained.stream().anyMatch(item -> item.getId().equals(message.getId())); |
| | | if (!shouldKeep) { |
| | | aiChatMessageMapper.updateById(new AiChatMessage() |
| | | .setId(message.getId()) |
| | | .setDeleted(1)); |
| | | } |
| | | } |
| | | evictConversationCaches(tenantId, userId); |
| | | scheduleMemoryProfileRefresh(sessionId, userId, tenantId); |
| | | } |
| | | |
| | | private void evictConversationCaches(Long tenantId, Long userId) { |
| | | // 会话标题、摘要、最近消息和 runtime 都会互相影响,统一按用户维度一起失效更稳妥。 |
| | | aiRedisSupport.evictUserConversationCaches(tenantId, userId); |
| | | } |
| | | |
| | | private void scheduleMemoryProfileRefresh(Long sessionId, Long userId, Long tenantId) { |
| | | if (sessionId == null) { |
| | | return; |
| | | } |
| | | if (!refreshingSessionIds.add(sessionId)) { |
| | | pendingRefreshSessionIds.add(sessionId); |
| | | return; |
| | | } |
| | | aiMemoryTaskExecutor.execute(() -> runMemoryProfileRefreshLoop(sessionId, userId, tenantId)); |
| | | } |
| | | |
| | | /** Runs memory-profile refresh passes for one session, repeating while a new request was flagged in pendingRefreshSessionIds during the previous pass. A failing pass is logged and does not end the loop. */ private void runMemoryProfileRefreshLoop(Long sessionId, Long userId, Long tenantId) { |
| | | try { |
| | | boolean shouldContinue; |
| | | do { |
| | | /* Clear the pending flag before working so requests arriving during this pass are detected afterwards. */ pendingRefreshSessionIds.remove(sessionId); |
| | | try { |
| | | refreshMemoryProfile(sessionId, userId); |
| | | evictConversationCaches(tenantId, userId); |
| | | } catch (Exception e) { |
| | | /* Best effort: log with the stack trace and keep going. */ log.warn("AI memory profile refresh failed, sessionId={}, userId={}, tenantId={}, message={}", |
| | | sessionId, userId, tenantId, e.getMessage(), e); |
| | | } |
| | | /* remove() returns true only if a request was flagged while the pass was running. */ shouldContinue = pendingRefreshSessionIds.remove(sessionId); |
| | | } while (shouldContinue); |
| | | } finally { |
| | | /* Release the claim on the session. */ refreshingSessionIds.remove(sessionId); |
| | | /* A request may have been flagged between the last check and the release; if so, and the claim can be re-taken, submit a new loop to the executor. */ if (pendingRefreshSessionIds.remove(sessionId) && refreshingSessionIds.add(sessionId)) { |
| | | aiMemoryTaskExecutor.execute(() -> runMemoryProfileRefreshLoop(sessionId, userId, tenantId)); |
| | | } |
| | | } |
| | | } |
| | | |
| | | private AiChatSession findLatestSession(Long userId, Long tenantId, String promptCode) { |
| | | return aiChatSessionMapper.selectOne(new LambdaQueryWrapper<AiChatSession>() |
| | | .eq(AiChatSession::getUserId, userId) |
| | | .eq(AiChatSession::getTenantId, tenantId) |
| | | .eq(AiChatSession::getPromptCode, promptCode) |
| | | .eq(AiChatSession::getDeleted, 0) |
| | | .eq(AiChatSession::getStatus, StatusType.ENABLE.val) |
| | | .orderByDesc(AiChatSession::getLastMessageTime) |
| | | .orderByDesc(AiChatSession::getId) |
| | | .last("limit 1")); |
| | | } |
| | | |
| | | private AiChatSession getSession(Long sessionId, Long userId, Long tenantId, String promptCode) { |
| | | AiChatSession session = aiChatSessionMapper.selectOne(new LambdaQueryWrapper<AiChatSession>() |
| | | .eq(AiChatSession::getId, sessionId) |
| | | .eq(AiChatSession::getUserId, userId) |
| | | .eq(AiChatSession::getTenantId, tenantId) |
| | | .eq(AiChatSession::getPromptCode, promptCode) |
| | | .eq(AiChatSession::getDeleted, 0) |
| | | .eq(AiChatSession::getStatus, StatusType.ENABLE.val) |
| | | .last("limit 1")); |
| | | if (session == null) { |
| | | throw new CoolException("AI 会话不存在或无权访问"); |
| | | } |
| | | return session; |
| | | } |
| | | |
| | | private AiChatSession requireOwnedSession(Long sessionId, Long userId, Long tenantId) { |
| | | if (sessionId == null) { |
| | | throw new CoolException("AI 会话 ID 不能为空"); |
| | | } |
| | | AiChatSession session = aiChatSessionMapper.selectOne(new LambdaQueryWrapper<AiChatSession>() |
| | | .eq(AiChatSession::getId, sessionId) |
| | | .eq(AiChatSession::getUserId, userId) |
| | | .eq(AiChatSession::getTenantId, tenantId) |
| | | .eq(AiChatSession::getDeleted, 0) |
| | | .eq(AiChatSession::getStatus, StatusType.ENABLE.val) |
| | | .last("limit 1")); |
| | | if (session == null) { |
| | | throw new CoolException("AI 会话不存在或无权访问"); |
| | | } |
| | | return session; |
| | | } |
| | | |
| | | private List<AiChatMessageDto> listMessages(Long sessionId) { |
| | | List<AiChatMessage> records = listMessageRecords(sessionId); |
| | | if (Cools.isEmpty(records)) { |
| | | return List.of(); |
| | | } |
| | | List<AiChatMessageDto> messages = new ArrayList<>(); |
| | | for (AiChatMessage record : records) { |
| | | if (!StringUtils.hasText(record.getContent())) { |
| | | continue; |
| | | } |
| | | AiChatMessageDto item = new AiChatMessageDto(); |
| | | item.setRole(record.getRole()); |
| | | item.setContent(record.getContent()); |
| | | messages.add(item); |
| | | } |
| | | return messages; |
| | | } |
| | | |
| | | private List<AiChatMessageDto> normalizeMessages(List<AiChatMessageDto> memoryMessages) { |
| | | /** 清洗前端上传的内存消息,只允许 user/assistant 两类角色落库。 */ |
| | | List<AiChatMessageDto> normalized = new ArrayList<>(); |
| | | if (Cools.isEmpty(memoryMessages)) { |
| | | return normalized; |
| | | } |
| | | for (AiChatMessageDto item : memoryMessages) { |
| | | if (item == null || !StringUtils.hasText(item.getContent())) { |
| | | continue; |
| | | } |
| | | String role = item.getRole() == null ? "user" : item.getRole().toLowerCase(); |
| | | if ("system".equals(role)) { |
| | | continue; |
| | | } |
| | | AiChatMessageDto normalizedItem = new AiChatMessageDto(); |
| | | normalizedItem.setRole("assistant".equals(role) ? "assistant" : "user"); |
| | | normalizedItem.setContent(item.getContent().trim()); |
| | | normalized.add(normalizedItem); |
| | | } |
| | | return normalized; |
| | | } |
| | | |
| | | private List<AiChatMessage> listMessageRecords(Long sessionId) { |
| | | return aiChatMessageMapper.selectList(new LambdaQueryWrapper<AiChatMessage>() |
| | | .eq(AiChatMessage::getSessionId, sessionId) |
| | | .eq(AiChatMessage::getDeleted, 0) |
| | | .orderByAsc(AiChatMessage::getSeqNo) |
| | | .orderByAsc(AiChatMessage::getId)); |
| | | } |
| | | |
| | | private int findNextSeqNo(Long sessionId) { |
| | | AiChatMessage lastMessage = aiChatMessageMapper.selectOne(new LambdaQueryWrapper<AiChatMessage>() |
| | | .eq(AiChatMessage::getSessionId, sessionId) |
| | | .eq(AiChatMessage::getDeleted, 0) |
| | | .orderByDesc(AiChatMessage::getSeqNo) |
| | | .orderByDesc(AiChatMessage::getId) |
| | | .last("limit 1")); |
| | | return lastMessage == null || lastMessage.getSeqNo() == null ? 1 : lastMessage.getSeqNo() + 1; |
| | | } |
| | | |
| | | private AiChatMessage buildMessageEntity(Long sessionId, int seqNo, String role, String content, Long userId, Long tenantId, Date createTime) { |
| | | return new AiChatMessage() |
| | | .setSessionId(sessionId) |
| | | .setSeqNo(seqNo) |
| | | .setRole(role) |
| | | .setContent(content) |
| | | .setContentLength(content == null ? 0 : content.length()) |
| | | .setUserId(userId) |
| | | .setTenantId(tenantId) |
| | | .setDeleted(0) |
| | | .setCreateBy(userId) |
| | | .setCreateTime(createTime); |
| | | } |
| | | |
| | | private String resolveUpdatedTitle(String currentTitle, List<AiChatMessageDto> memoryMessages) { |
| | | if (StringUtils.hasText(currentTitle)) { |
| | | return currentTitle; |
| | | } |
| | | for (AiChatMessageDto item : memoryMessages) { |
| | | if ("user".equals(item.getRole()) && StringUtils.hasText(item.getContent())) { |
| | | return buildSessionTitle(item.getContent()); |
| | | } |
| | | } |
| | | return null; |
| | | } |
| | | |
| | | private String buildSessionTitle(String titleSeed) { |
| | | /** |
| | | * 把首轮用户问题压缩成适合作为会话标题的短摘要。 |
| | | * 这里会去掉换行、连续空白,并优先在自然语义断点处截断。 |
| | | */ |
| | | if (!StringUtils.hasText(titleSeed)) { |
| | | throw new CoolException("AI 会话标题不能为空"); |
| | | } |
| | | String title = titleSeed.trim() |
| | | .replace("\r", " ") |
| | | .replace("\n", " ") |
| | | .replaceAll("\\s+", " "); |
| | | int punctuationIndex = findSummaryBreakIndex(title); |
| | | if (punctuationIndex > 0) { |
| | | title = title.substring(0, punctuationIndex).trim(); |
| | | } |
| | | return title.length() > 48 ? title.substring(0, 48) : title; |
| | | } |
| | | |
| | | private int findSummaryBreakIndex(String title) { |
| | | String[] separators = {"。", "!", "?", ".", "!", "?"}; |
| | | int result = -1; |
| | | for (String separator : separators) { |
| | | int index = title.indexOf(separator); |
| | | if (index > 0 && (result < 0 || index < result)) { |
| | | result = index; |
| | | } |
| | | } |
| | | return result; |
| | | } |
| | | |
| | | private AiChatSessionDto buildSessionDto(AiChatSession session) { |
| | | AiChatMessage lastMessage = aiChatMessageMapper.selectOne(new LambdaQueryWrapper<AiChatMessage>() |
| | | .eq(AiChatMessage::getSessionId, session.getId()) |
| | | .eq(AiChatMessage::getDeleted, 0) |
| | | .orderByDesc(AiChatMessage::getSeqNo) |
| | | .orderByDesc(AiChatMessage::getId) |
| | | .last("limit 1")); |
| | | return AiChatSessionDto.builder() |
| | | .sessionId(session.getId()) |
| | | .title(session.getTitle()) |
| | | .promptCode(session.getPromptCode()) |
| | | .pinned(session.getPinned() != null && session.getPinned() == 1) |
| | | .lastMessagePreview(buildLastMessagePreview(lastMessage)) |
| | | .lastMessageTime(session.getLastMessageTime()) |
| | | .build(); |
| | | } |
| | | |
| | | private String buildLastMessagePreview(AiChatMessage message) { |
| | | if (message == null || !StringUtils.hasText(message.getContent())) { |
| | | return null; |
| | | } |
| | | String preview = message.getContent().trim() |
| | | .replace("\r", " ") |
| | | .replace("\n", " ") |
| | | .replaceAll("\\s+", " "); |
| | | String prefix = "assistant".equalsIgnoreCase(message.getRole()) ? "AI: " : "你: "; |
| | | String normalized = prefix + preview; |
| | | return normalized.length() > 80 ? normalized.substring(0, 80) : normalized; |
| | | } |
| | | |
| | | private void refreshMemoryProfile(Long sessionId, Long userId) { |
| | | /** |
| | | * 重新计算会话的摘要记忆和关键事实。 |
| | | * 这是“持久化消息”和“模型上下文治理”之间的桥梁方法。 |
| | | * 现在它运行在后台线程里,因此允许短时间最终一致,而不是强制本轮同步完成。 |
| | | */ |
| | | List<AiChatMessageDto> messages = listMessages(sessionId); |
| | | List<AiChatMessageDto> shortMemoryMessages = tailMessagesByRounds(messages, AiDefaults.MEMORY_RECENT_ROUNDS); |
| | | List<AiChatMessageDto> historyMessages = messages.size() > shortMemoryMessages.size() |
| | | ? messages.subList(0, messages.size() - shortMemoryMessages.size()) |
| | | : List.of(); |
| | | String memorySummary = historyMessages.size() >= AiDefaults.MEMORY_SUMMARY_TRIGGER_MESSAGES |
| | | ? buildMemorySummary(historyMessages) |
| | | : null; |
| | | String memoryFacts = buildMemoryFacts(messages); |
| | | AiChatMessage lastMessage = aiChatMessageMapper.selectOne(new LambdaQueryWrapper<AiChatMessage>() |
| | | .eq(AiChatMessage::getSessionId, sessionId) |
| | | .eq(AiChatMessage::getDeleted, 0) |
| | | .orderByDesc(AiChatMessage::getSeqNo) |
| | | .orderByDesc(AiChatMessage::getId) |
| | | .last("limit 1")); |
| | | aiChatSessionMapper.updateById(new AiChatSession() |
| | | .setId(sessionId) |
| | | .setMemorySummary(memorySummary) |
| | | .setMemoryFacts(memoryFacts) |
| | | .setLastMessageTime(lastMessage == null ? null : lastMessage.getCreateTime()) |
| | | .setUpdateBy(userId) |
| | | .setUpdateTime(new Date())); |
| | | } |
| | | |
| | | private List<AiChatMessageDto> tailMessagesByRounds(List<AiChatMessageDto> source, int rounds) { |
| | | /** 按“用户发言轮次”裁剪最近消息,而不是简单按条数截断。 */ |
| | | if (Cools.isEmpty(source) || rounds <= 0) { |
| | | return List.of(); |
| | | } |
| | | int userCount = 0; |
| | | int startIndex = source.size(); |
| | | for (int i = source.size() - 1; i >= 0; i--) { |
| | | AiChatMessageDto item = source.get(i); |
| | | startIndex = i; |
| | | if (item != null && "user".equalsIgnoreCase(item.getRole())) { |
| | | userCount++; |
| | | if (userCount >= rounds) { |
| | | break; |
| | | } |
| | | } |
| | | } |
| | | return new ArrayList<>(source.subList(Math.max(0, startIndex), source.size())); |
| | | } |
| | | |
| | | private List<AiChatMessage> tailMessageRecordsByRounds(List<AiChatMessage> source, int rounds) { |
| | | if (Cools.isEmpty(source) || rounds <= 0) { |
| | | return List.of(); |
| | | } |
| | | int userCount = 0; |
| | | int startIndex = source.size(); |
| | | for (int i = source.size() - 1; i >= 0; i--) { |
| | | AiChatMessage item = source.get(i); |
| | | startIndex = i; |
| | | if (item != null && "user".equalsIgnoreCase(item.getRole())) { |
| | | userCount++; |
| | | if (userCount >= rounds) { |
| | | break; |
| | | } |
| | | } |
| | | } |
| | | return new ArrayList<>(source.subList(Math.max(0, startIndex), source.size())); |
| | | } |
| | | |
| | | private String buildMemorySummary(List<AiChatMessageDto> historyMessages) { |
| | | /** 为较早历史生成可直接插入系统消息的文本摘要。 */ |
| | | StringBuilder builder = new StringBuilder("较早对话摘要:\n"); |
| | | for (AiChatMessageDto item : historyMessages) { |
| | | if (item == null || !StringUtils.hasText(item.getContent())) { |
| | | continue; |
| | | } |
| | | String prefix = "assistant".equalsIgnoreCase(item.getRole()) ? "- AI: " : "- 用户: "; |
| | | String content = compactText(item.getContent(), 120); |
| | | if (!StringUtils.hasText(content)) { |
| | | continue; |
| | | } |
| | | builder.append(prefix).append(content).append("\n"); |
| | | if (builder.length() >= AiDefaults.MEMORY_SUMMARY_MAX_LENGTH) { |
| | | break; |
| | | } |
| | | } |
| | | return compactText(builder.toString(), AiDefaults.MEMORY_SUMMARY_MAX_LENGTH); |
| | | } |
| | | |
| | | /** |
| | | * Distils key facts from the most recent user messages as a lightweight persistent memory. |
| | | * The list is scanned from its last element backwards and at most four user messages with text are |
| | | * recorded, each compacted to 100 chars. The final text is compacted to |
| | | * {@code AiDefaults.MEMORY_FACTS_MAX_LENGTH}. |
| | | * |
| | | * @param messages messages to scan |
| | | * @return the facts text, or {@code null} when {@code messages} is empty or holds no usable user message |
| | | */ |
| | | private String buildMemoryFacts(List<AiChatMessageDto> messages) { |
| | | if (Cools.isEmpty(messages)) { |
| | | return null; |
| | | } |
| | | StringBuilder facts = new StringBuilder("关键事实:\n"); |
| | | int recorded = 0; |
| | | int index = messages.size() - 1; |
| | | while (index >= 0 && recorded < 4) { |
| | | AiChatMessageDto candidate = messages.get(index); |
| | | index--; |
| | | boolean isUserWithText = candidate != null |
| | | && "user".equalsIgnoreCase(candidate.getRole()) |
| | | && StringUtils.hasText(candidate.getContent()); |
| | | if (isUserWithText) { |
| | | facts.append("- 用户关注: ").append(compactText(candidate.getContent(), 100)).append("\n"); |
| | | recorded++; |
| | | } |
| | | } |
| | | if (recorded == 0) { |
| | | return null; |
| | | } |
| | | return compactText(facts.toString(), AiDefaults.MEMORY_FACTS_MAX_LENGTH); |
| | | } |
| | | |
| | | /** |
| | | * Collapses {@code content} onto a single line and truncates it to at most {@code maxLength} chars. |
| | | * Leading and trailing whitespace is trimmed and every whitespace run (including CR/LF) becomes one space. |
| | | * Truncation never ends on a high surrogate, so a supplementary character such as an emoji is dropped |
| | | * whole instead of being cut in half. |
| | | * |
| | | * @param content raw text, may be {@code null} or blank |
| | | * @param maxLength maximum length of the result in UTF-16 code units |
| | | * @return the compacted text, or {@code null} when {@code content} has no text |
| | | */ |
| | | private String compactText(String content, int maxLength) { |
| | | if (!StringUtils.hasText(content)) { |
| | | return null; |
| | | } |
| | | // "\\s+" already matches \r and \n, so a single pass covers the former replace() calls. |
| | | String normalized = content.trim().replaceAll("\\s+", " "); |
| | | if (normalized.length() <= maxLength) { |
| | | return normalized; |
| | | } |
| | | int end = maxLength; |
| | | // Cutting right after a high surrogate would leave an unpaired surrogate in the stored text. |
| | | if (end > 0 && Character.isHighSurrogate(normalized.charAt(end - 1))) { |
| | | end--; |
| | | } |
| | | return normalized.substring(0, end); |
| | | } |
| | | |
| | | /** |
| | | * Rejects calls that lack a user id or a tenant id. |
| | | * The user id is checked first; the first missing value decides the error message. |
| | | * |
| | | * @throws CoolException when {@code userId} or {@code tenantId} is {@code null} |
| | | */ |
| | | private void ensureIdentity(Long userId, Long tenantId) { |
| | | String problem = null; |
| | | if (userId == null) { |
| | | problem = "当前登录用户不存在"; |
| | | } else if (tenantId == null) { |
| | | problem = "当前租户不存在"; |
| | | } |
| | | if (problem != null) { |
| | | throw new CoolException(problem); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * Returns {@code promptCode} unchanged after checking that it contains text. |
| | | * |
| | | * @throws CoolException when {@code promptCode} is null, empty or whitespace only |
| | | */ |
| | | private String requirePromptCode(String promptCode) { |
| | | if (!StringUtils.hasText(promptCode)) { |
| | | throw new CoolException("Prompt 编码不能为空"); |
| | | } |
| | | return promptCode; |
| | | // NOTE(review): the statement below follows a return and references userId/tenantId/sessionId, which are |
| | | // not parameters of this method — it looks like a line from another diff hunk; confirm against the real file. |
| | | aiConversationCommandService.retainLatestRound(userId, tenantId, sessionId); |
| | | } |
| | | } |
| | |
| | | package com.vincent.rsf.server.ai.service.impl; |
| | | |
| | | import com.fasterxml.jackson.databind.ObjectMapper; |
| | | import com.vincent.rsf.framework.common.Cools; |
| | | import com.vincent.rsf.framework.exception.CoolException; |
| | | import com.vincent.rsf.server.ai.config.AiDefaults; |
| | | import com.vincent.rsf.server.ai.dto.AiChatDoneDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatErrorDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatMemoryDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatMessageDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatModelOptionDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatRequest; |
| | | import com.vincent.rsf.server.ai.dto.AiChatRuntimeDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatStatusDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatSessionDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest; |
| | | import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest; |
| | | import com.vincent.rsf.server.ai.dto.AiChatThinkingEventDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatToolEventDto; |
| | | import com.vincent.rsf.server.ai.dto.AiResolvedConfig; |
| | | import com.vincent.rsf.server.ai.entity.AiCallLog; |
| | | import com.vincent.rsf.server.ai.entity.AiParam; |
| | | import com.vincent.rsf.server.ai.entity.AiPrompt; |
| | | import com.vincent.rsf.server.ai.entity.AiChatSession; |
| | | import com.vincent.rsf.server.ai.enums.AiErrorCategory; |
| | | import com.vincent.rsf.server.ai.exception.AiChatException; |
| | | import com.vincent.rsf.server.ai.service.AiCallLogService; |
| | | import com.vincent.rsf.server.ai.service.AiChatService; |
| | | import com.vincent.rsf.server.ai.service.AiChatMemoryService; |
| | | import com.vincent.rsf.server.ai.service.AiChatService; |
| | | import com.vincent.rsf.server.ai.service.AiConfigResolverService; |
| | | import com.vincent.rsf.server.ai.service.AiParamService; |
| | | import com.vincent.rsf.server.ai.service.MountedToolCallback; |
| | | import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory; |
| | | import io.micrometer.observation.ObservationRegistry; |
| | | import com.vincent.rsf.server.ai.service.impl.chat.AiChatOrchestrator; |
| | | import com.vincent.rsf.server.ai.service.impl.chat.AiChatRuntimeAssembler; |
| | | import com.vincent.rsf.server.ai.store.AiConversationCacheStore; |
| | | import lombok.RequiredArgsConstructor; |
| | | import lombok.extern.slf4j.Slf4j; |
| | | import org.springframework.ai.chat.messages.AssistantMessage; |
| | | import org.springframework.ai.chat.messages.Message; |
| | | import org.springframework.ai.chat.messages.SystemMessage; |
| | | import org.springframework.ai.chat.messages.UserMessage; |
| | | import org.springframework.ai.chat.metadata.ChatResponseMetadata; |
| | | import org.springframework.ai.chat.metadata.Usage; |
| | | import org.springframework.ai.chat.model.ChatResponse; |
| | | import org.springframework.ai.chat.model.ToolContext; |
| | | import org.springframework.ai.chat.prompt.Prompt; |
| | | import org.springframework.ai.model.tool.DefaultToolCallingManager; |
| | | import org.springframework.ai.model.tool.ToolCallingManager; |
| | | import org.springframework.ai.openai.OpenAiChatModel; |
| | | import org.springframework.ai.openai.OpenAiChatOptions; |
| | | import org.springframework.ai.openai.api.OpenAiApi; |
| | | import org.springframework.ai.tool.ToolCallback; |
| | | import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; |
| | | import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver; |
| | | import org.springframework.ai.util.json.schema.SchemaType; |
| | | import org.springframework.context.support.GenericApplicationContext; |
| | | import org.springframework.http.MediaType; |
| | | import org.springframework.beans.factory.annotation.Qualifier; |
| | | import org.springframework.stereotype.Service; |
| | | import org.springframework.util.StringUtils; |
| | | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; |
| | | import reactor.core.publisher.Flux; |
| | | |
| | | import java.io.IOException; |
| | | import java.time.Instant; |
| | | import java.util.ArrayList; |
| | | import java.util.Arrays; |
| | | import java.util.LinkedHashMap; |
| | | import java.util.List; |
| | | import java.util.Map; |
| | | import java.util.Objects; |
| | | import java.util.concurrent.CompletableFuture; |
| | | import java.util.concurrent.Executor; |
| | | import java.util.concurrent.atomic.AtomicReference; |
| | | import java.util.concurrent.atomic.AtomicLong; |
| | | |
| | | @Slf4j |
| | | @Service |
| | | @RequiredArgsConstructor |
| | | public class AiChatServiceImpl implements AiChatService { |
| | |
| | | // Collaborators of the chat service. All fields are final; the class is annotated with @RequiredArgsConstructor. |
| | | private final AiConfigResolverService aiConfigResolverService; |
| | | private final AiChatMemoryService aiChatMemoryService; |
| | | private final AiParamService aiParamService; |
| | | private final McpMountRuntimeFactory mcpMountRuntimeFactory; |
| | | private final AiCallLogService aiCallLogService; |
| | | // NOTE(review): aiRedisSupport is declared next to aiConversationCacheStore, which is used for the same |
| | | // runtime caching calls in getRuntime — possibly an old/new diff pair; confirm which fields the real file keeps. |
| | | private final AiRedisSupport aiRedisSupport; |
| | | private final GenericApplicationContext applicationContext; |
| | | private final ObservationRegistry observationRegistry; |
| | | private final ObjectMapper objectMapper; |
| | | private final AiConversationCacheStore aiConversationCacheStore; |
| | | private final AiChatRuntimeAssembler aiChatRuntimeAssembler; |
| | | private final AiChatOrchestrator aiChatOrchestrator; |
| | | // NOTE(review): presumably Lombok must be configured to copy @Qualifier onto the generated constructor |
| | | // parameter (lombok.copyableAnnotations) for this qualifier to take effect — TODO confirm. |
| | | @Qualifier("aiChatTaskExecutor") |
| | | private final Executor aiChatTaskExecutor; |
| | | |
| | | /** |
| | | * Returns the runtime data needed to initialise the chat drawer. |
| | | * This method does not trigger a model call; it aggregates the resolved configuration and the session |
| | | * memory into a snapshot the front end can render in one pass. A cached snapshot is returned when present. |
| | | * |
| | | * NOTE(review): this span appears to interleave removed and added diff lines (each aiRedisSupport call is |
| | | * followed by an aiConversationCacheStore twin, "cached" and "runtime" are each declared twice, and a hunk |
| | | * gap hides some snapshot arguments) — confirm against the real file before changing it. |
| | | */ |
| | | @Override |
| | | public AiChatRuntimeDto getRuntime(String promptCode, Long sessionId, Long aiParamId, Long userId, Long tenantId) { |
| | | AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId, aiParamId); |
| | | Long runtimeCacheAiParamId = aiParamId; |
| | | // The runtime is an aggregate view of the config snapshot and the session memory; caching it separately avoids re-assembling it each time the page is opened. |
| | | AiChatRuntimeDto cached = aiRedisSupport.getRuntime(tenantId, userId, config.getPromptCode(), sessionId, runtimeCacheAiParamId); |
| | | AiChatRuntimeDto cached = aiConversationCacheStore.getRuntime(tenantId, userId, config.getPromptCode(), sessionId, aiParamId); |
| | | if (cached != null) { |
| | | return cached; |
| | | } |
| | | AiChatMemoryDto memory = aiChatMemoryService.getMemory(userId, tenantId, config.getPromptCode(), sessionId); |
| | | List<AiChatModelOptionDto> modelOptions = aiParamService.listChatModelOptions(tenantId); |
| | | AiChatRuntimeDto runtime = buildRuntimeSnapshot( |
| | | AiChatRuntimeDto runtime = aiChatRuntimeAssembler.buildRuntimeSnapshot( |
| | | null, |
| | | memory.getSessionId(), |
| | | config, |
| | |
| | | List.of(), |
| | | memory |
| | | ); |
| | | aiRedisSupport.cacheRuntime(tenantId, userId, config.getPromptCode(), sessionId, runtimeCacheAiParamId, runtime); |
| | | aiConversationCacheStore.cacheRuntime(tenantId, userId, config.getPromptCode(), sessionId, aiParamId, runtime); |
| | | // When the memory lookup resolved a different session id than the one requested, cache the snapshot under that id as well. |
| | | if (memory.getSessionId() != null && !Objects.equals(memory.getSessionId(), sessionId)) { |
| | | aiRedisSupport.cacheRuntime(tenantId, userId, config.getPromptCode(), memory.getSessionId(), runtimeCacheAiParamId, runtime); |
| | | aiConversationCacheStore.cacheRuntime(tenantId, userId, config.getPromptCode(), memory.getSessionId(), aiParamId, runtime); |
| | | } |
| | | return runtime; |
| | | } |
| | | |
| | | /** |
| | | * Lists the summaries of historical sessions for the given prompt scenario. |
| | | * The prompt code is first resolved through the tenant configuration, and the resolved code is the one |
| | | * passed to the memory service together with the keyword filter. |
| | | */ |
| | | @Override |
| | | public List<AiChatSessionDto> listSessions(String promptCode, String keyword, Long userId, Long tenantId) { |
| | | AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId); |
| | | return aiChatMemoryService.listSessions(userId, tenantId, config.getPromptCode(), keyword); |
| | | } |
| | | |
| | | /** |
| | | * Opens an SSE stream for one chat request. |
| | | * The emitter is created with {@code AiDefaults.SSE_TIMEOUT_MS}; the work itself is submitted to |
| | | * {@code aiChatTaskExecutor} as a call to {@code aiChatOrchestrator.executeStream}. |
| | | * |
| | | * @return the emitter the asynchronous task writes to |
| | | */ |
| | | @Override |
| | | public SseEmitter stream(AiChatRequest request, Long userId, Long tenantId) { |
| | | final SseEmitter sseEmitter = new SseEmitter(AiDefaults.SSE_TIMEOUT_MS); |
| | | Runnable streamTask = () -> aiChatOrchestrator.executeStream(request, userId, tenantId, sseEmitter); |
| | | CompletableFuture.runAsync(streamTask, aiChatTaskExecutor); |
| | | return sseEmitter; |
| | | } |
| | | |
| | | @Override |
| | |
| | | /** |
| | | * Delegates to {@code aiChatMemoryService.retainLatestRound} for the given session. |
| | | * Note that this method receives (sessionId, userId, tenantId) while the delegate is called with |
| | | * (userId, tenantId, sessionId). |
| | | */ |
| | | @Override |
| | | public void retainLatestRound(Long sessionId, Long userId, Long tenantId) { |
| | | aiChatMemoryService.retainLatestRound(userId, tenantId, sessionId); |
| | | } |
| | | |
| | | /** |
| | | * Starts a new SSE chat stream. |
| | | * The calling thread returns the emitter immediately; the actual model call and tool execution are handed to the dedicated AI thread pool and run asynchronously. |
| | | * |
| | | * NOTE(review): a second stream(...) with the same signature appears earlier in this view and delegates to |
| | | * aiChatOrchestrator — presumably the two are the old and new sides of a diff; confirm which one the real file keeps. |
| | | */ |
| | | @Override |
| | | public SseEmitter stream(AiChatRequest request, Long userId, Long tenantId) { |
| | | SseEmitter emitter = new SseEmitter(AiDefaults.SSE_TIMEOUT_MS); |
| | | CompletableFuture.runAsync(() -> doStream(request, userId, tenantId, emitter), aiChatTaskExecutor); |
| | | return emitter; |
| | | } |
| | | |
| | | private void doStream(AiChatRequest request, Long userId, Long tenantId, SseEmitter emitter) { |
| | | /** |
| | | * Core execution chain of an AI chat: |
| | | * 1. Validate the identity and resolve the tenant configuration |
| | | * 2. Resolve or create the session and load its memory |
| | | * 3. Mount MCP tools dynamically |
| | | * 4. Invoke the model (streaming or non-streaming) |
| | | * 5. Persist this round's messages, emit SSE events and record the audit log |
| | | */ |
| | | String requestId = request.getRequestId(); |
| | | long startedAt = System.currentTimeMillis(); |
| | | AtomicReference<Long> firstTokenAtRef = new AtomicReference<>(); |
| | | AtomicLong toolCallSequence = new AtomicLong(0); |
| | | AtomicLong toolSuccessCount = new AtomicLong(0); |
| | | AtomicLong toolFailureCount = new AtomicLong(0); |
| | | // The variables below are declared outside the try block so the catch handlers can report whatever was resolved before the failure. |
| | | Long sessionId = request.getSessionId(); |
| | | Long callLogId = null; |
| | | String model = null; |
| | | String resolvedPromptCode = request.getPromptCode(); |
| | | ThinkingTraceEmitter thinkingTraceEmitter = null; |
| | | try { |
| | | ensureIdentity(userId, tenantId); |
| | | AiResolvedConfig config = resolveConfig(request, tenantId); |
| | | List<AiChatModelOptionDto> modelOptions = aiParamService.listChatModelOptions(tenantId); |
| | | resolvedPromptCode = config.getPromptCode(); |
| | | if (!aiRedisSupport.allowChatRequest(tenantId, userId, config.getPromptCode())) { |
| | | throw buildAiException("AI_RATE_LIMITED", AiErrorCategory.REQUEST, "RATE_LIMIT", |
| | | "当前提问过于频繁,请稍后再试", null); |
| | | } |
| | | final String resolvedModel = config.getAiParam().getModel(); |
| | | model = resolvedModel; |
| | | AiChatSession session = resolveSession(request, userId, tenantId, config.getPromptCode()); |
| | | sessionId = session.getId(); |
| | | // The stream state is written to Redis to give multi-instance deployments and later operational queries a single entry point; it does not replace the database log. |
| | | aiRedisSupport.markStreamState(requestId, tenantId, userId, sessionId, config.getPromptCode(), "RUNNING", null); |
| | | AiChatMemoryDto memory = loadMemory(userId, tenantId, config.getPromptCode(), session.getId()); |
| | | List<AiChatMessageDto> mergedMessages = mergeMessages(memory.getShortMemoryMessages(), request.getMessages()); |
| | | // NOTE(review): config.getMcpMounts().size() is passed twice in a row below — confirm both startCallLog parameters are meant to receive the configured count. |
| | | AiCallLog callLog = aiCallLogService.startCallLog( |
| | | requestId, |
| | | session.getId(), |
| | | userId, |
| | | tenantId, |
| | | config.getPromptCode(), |
| | | config.getPrompt().getName(), |
| | | config.getAiParam().getModel(), |
| | | config.getMcpMounts().size(), |
| | | config.getMcpMounts().size(), |
| | | config.getMcpMounts().stream().map(item -> item.getName()).toList() |
| | | ); |
| | | callLogId = callLog.getId(); |
| | | // The MCP runtime is opened in try-with-resources, so it is closed on every exit path of this block, including the early return of the non-streaming branch. |
| | | try (McpMountRuntimeFactory.McpMountRuntime runtime = createRuntime(config, userId)) { |
| | | emitStrict(emitter, "start", buildRuntimeSnapshot( |
| | | requestId, |
| | | session.getId(), |
| | | config, |
| | | modelOptions, |
| | | runtime.getMountedCount(), |
| | | runtime.getMountedNames(), |
| | | runtime.getErrors(), |
| | | memory |
| | | )); |
| | | emitSafely(emitter, "status", AiChatStatusDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(session.getId()) |
| | | .status("STARTED") |
| | | .model(resolvedModel) |
| | | .timestamp(Instant.now().toEpochMilli()) |
| | | .elapsedMs(0L) |
| | | .build()); |
| | | log.info("AI chat started, requestId={}, userId={}, tenantId={}, sessionId={}, model={}", |
| | | requestId, userId, tenantId, session.getId(), resolvedModel); |
| | | thinkingTraceEmitter = new ThinkingTraceEmitter(emitter, requestId, session.getId()); |
| | | thinkingTraceEmitter.startAnalyze(); |
| | | |
| | | // Effectively-final alias so the lambda in the streaming branch can capture the trace emitter. |
| | | ThinkingTraceEmitter activeThinkingTraceEmitter = thinkingTraceEmitter; |
| | | ToolCallback[] observableToolCallbacks = wrapToolCallbacks( |
| | | runtime.getToolCallbacks(), emitter, requestId, session.getId(), toolCallSequence, |
| | | toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, activeThinkingTraceEmitter |
| | | ); |
| | | Prompt prompt = new Prompt( |
| | | buildPromptMessages(memory, mergedMessages, config.getPrompt(), request.getMetadata()), |
| | | buildChatOptions(config.getAiParam(), observableToolCallbacks, userId, tenantId, |
| | | requestId, session.getId(), request.getMetadata()) |
| | | ); |
| | | OpenAiChatModel chatModel = createChatModel(config.getAiParam()); |
| | | // Non-streaming branch: one blocking call whose whole content is emitted as a single "delta" event. |
| | | if (Boolean.FALSE.equals(config.getAiParam().getStreamingEnabled())) { |
| | | ChatResponse response = invokeChatCall(chatModel, prompt); |
| | | String content = extractContent(response); |
| | | aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), content); |
| | | if (StringUtils.hasText(content)) { |
| | | markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt, activeThinkingTraceEmitter); |
| | | emitStrict(emitter, "delta", buildMessagePayload("requestId", requestId, "content", content)); |
| | | } |
| | | activeThinkingTraceEmitter.completeCurrentPhase(); |
| | | emitDone(emitter, requestId, response.getMetadata(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get()); |
| | | emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get())); |
| | | aiCallLogService.completeCallLog( |
| | | callLogId, |
| | | "COMPLETED", |
| | | System.currentTimeMillis() - startedAt, |
| | | resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()), |
| | | response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getPromptTokens(), |
| | | response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getCompletionTokens(), |
| | | response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getTotalTokens(), |
| | | toolSuccessCount.get(), |
| | | toolFailureCount.get() |
| | | ); |
| | | aiRedisSupport.markStreamState(requestId, tenantId, userId, session.getId(), config.getPromptCode(), "COMPLETED", null); |
| | | log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}", |
| | | requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get())); |
| | | emitter.complete(); |
| | | return; |
| | | } |
| | | |
| | | // Streaming branch: each chunk with text is forwarded as a "delta" event and appended to the full answer; blockLast() keeps this worker thread waiting until the flux ends. |
| | | Flux<ChatResponse> responseFlux = invokeChatStream(chatModel, prompt); |
| | | AtomicReference<ChatResponseMetadata> lastMetadata = new AtomicReference<>(); |
| | | StringBuilder assistantContent = new StringBuilder(); |
| | | try { |
| | | responseFlux.doOnNext(response -> { |
| | | lastMetadata.set(response.getMetadata()); |
| | | String content = extractContent(response); |
| | | if (StringUtils.hasText(content)) { |
| | | markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt, activeThinkingTraceEmitter); |
| | | assistantContent.append(content); |
| | | emitStrict(emitter, "delta", buildMessagePayload("requestId", requestId, "content", content)); |
| | | } |
| | | }) |
| | | .blockLast(); |
| | | } catch (Exception e) { |
| | | // Any exception raised while consuming the flux, including one thrown by emitStrict inside doOnNext, is reported with stage MODEL_STREAM. |
| | | throw buildAiException("AI_MODEL_STREAM_ERROR", AiErrorCategory.MODEL, "MODEL_STREAM", |
| | | e == null ? "AI 模型流式调用失败" : e.getMessage(), e); |
| | | } |
| | | aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), assistantContent.toString()); |
| | | activeThinkingTraceEmitter.completeCurrentPhase(); |
| | | emitDone(emitter, requestId, lastMetadata.get(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get()); |
| | | emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get())); |
| | | aiCallLogService.completeCallLog( |
| | | callLogId, |
| | | "COMPLETED", |
| | | System.currentTimeMillis() - startedAt, |
| | | resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()), |
| | | lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getPromptTokens(), |
| | | lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getCompletionTokens(), |
| | | lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getTotalTokens(), |
| | | toolSuccessCount.get(), |
| | | toolFailureCount.get() |
| | | ); |
| | | aiRedisSupport.markStreamState(requestId, tenantId, userId, session.getId(), config.getPromptCode(), "COMPLETED", null); |
| | | log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}", |
| | | requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get())); |
| | | emitter.complete(); |
| | | } |
| | | } catch (AiChatException e) { |
| | | // Already-classified failures are reported as they are. |
| | | handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e, |
| | | callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter, |
| | | tenantId, userId, resolvedPromptCode); |
| | | } catch (Exception e) { |
| | | // Anything else is wrapped as an AI_INTERNAL_ERROR before it is reported. |
| | | handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), |
| | | buildAiException("AI_INTERNAL_ERROR", AiErrorCategory.INTERNAL, "INTERNAL", |
| | | e == null ? "AI 对话失败" : e.getMessage(), e), |
| | | callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter, |
| | | tenantId, userId, resolvedPromptCode); |
| | | } finally { |
| | | log.debug("AI chat stream finished, requestId={}", requestId); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * Verifies that both the user id and the tenant id are present before a chat runs. |
| | | * The user id is checked first; a missing value is reported through {@code buildAiException} |
| | | * with category AUTH and stage "AUTH_VALIDATE". |
| | | */ |
| | | private void ensureIdentity(Long userId, Long tenantId) { |
| | | if (userId != null && tenantId != null) { |
| | | return; |
| | | } |
| | | boolean userMissing = userId == null; |
| | | String code = userMissing ? "AI_AUTH_USER_MISSING" : "AI_AUTH_TENANT_MISSING"; |
| | | String message = userMissing ? "当前登录用户不存在" : "当前租户不存在"; |
| | | throw buildAiException(code, AiErrorCategory.AUTH, "AUTH_VALIDATE", message, null); |
| | | } |
| | | |
| | | /** |
| | | * Resolves the prompt scenario named in the request into an AI configuration that can be executed directly. |
| | | * Any exception from the resolver is rethrown through {@code buildAiException} with code |
| | | * "AI_CONFIG_RESOLVE_ERROR" and the original exception as cause. |
| | | */ |
| | | private AiResolvedConfig resolveConfig(AiChatRequest request, Long tenantId) { |
| | | try { |
| | | return aiConfigResolverService.resolve(request.getPromptCode(), tenantId, request.getAiParamId()); |
| | | } catch (Exception cause) { |
| | | String message; |
| | | if (cause == null) { |
| | | message = "AI 配置解析失败"; |
| | | } else { |
| | | message = cause.getMessage(); |
| | | } |
| | | throw buildAiException("AI_CONFIG_RESOLVE_ERROR", AiErrorCategory.CONFIG, "CONFIG_RESOLVE", message, cause); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * Assembles the runtime DTO sent to the front end from the resolved configuration, the MCP mount |
| | | * result and the session memory. Values are copied as given; nothing is computed here apart from the |
| | | * size of the configured MCP mount list. |
| | | */ |
| | | private AiChatRuntimeDto buildRuntimeSnapshot(String requestId, Long sessionId, AiResolvedConfig config, |
| | | List<AiChatModelOptionDto> modelOptions, Integer mountedMcpCount, |
| | | List<String> mountedMcpNames, List<String> mountErrors, |
| | | AiChatMemoryDto memory) { |
| | | AiParam aiParam = config.getAiParam(); |
| | | AiPrompt prompt = config.getPrompt(); |
| | | int configuredMounts = config.getMcpMounts().size(); |
| | | return AiChatRuntimeDto.builder() |
| | | // request / session identity |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | // resolved configuration |
| | | .aiParamId(aiParam.getId()) |
| | | .promptCode(config.getPromptCode()) |
| | | .promptName(prompt.getName()) |
| | | .model(aiParam.getModel()) |
| | | .modelOptions(modelOptions) |
| | | // MCP mount outcome |
| | | .configuredMcpCount(configuredMounts) |
| | | .mountedMcpCount(mountedMcpCount) |
| | | .mountedMcpNames(mountedMcpNames) |
| | | .mountErrors(mountErrors) |
| | | // session memory |
| | | .memorySummary(memory.getMemorySummary()) |
| | | .memoryFacts(memory.getMemoryFacts()) |
| | | .recentMessageCount(memory.getRecentMessageCount()) |
| | | .persistedMessages(memory.getPersistedMessages()) |
| | | .build(); |
| | | } |
| | | |
| | | /** |
| | | * Reuses the historical session identified by the request's sessionId, or lets the memory service |
| | | * create a new one on the first question. Any exception is rethrown through {@code buildAiException} |
| | | * with code "AI_SESSION_RESOLVE_ERROR" and the original exception as cause. |
| | | */ |
| | | private AiChatSession resolveSession(AiChatRequest request, Long userId, Long tenantId, String promptCode) { |
| | | try { |
| | | return aiChatMemoryService.resolveSession( |
| | | userId, |
| | | tenantId, |
| | | promptCode, |
| | | request.getSessionId(), |
| | | resolveTitleSeed(request.getMessages())); |
| | | } catch (Exception cause) { |
| | | String message; |
| | | if (cause == null) { |
| | | message = "AI 会话解析失败"; |
| | | } else { |
| | | message = cause.getMessage(); |
| | | } |
| | | throw buildAiException("AI_SESSION_RESOLVE_ERROR", AiErrorCategory.REQUEST, "SESSION_RESOLVE", message, cause); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * Reads the session's short-term memory, summary memory and fact memory so the model context can be assembled. |
| | | * Any exception is rethrown through {@code buildAiException} with code "AI_MEMORY_LOAD_ERROR" |
| | | * and the original exception as cause. |
| | | */ |
| | | private AiChatMemoryDto loadMemory(Long userId, Long tenantId, String promptCode, Long sessionId) { |
| | | try { |
| | | AiChatMemoryDto loaded = aiChatMemoryService.getMemory(userId, tenantId, promptCode, sessionId); |
| | | return loaded; |
| | | } catch (Exception cause) { |
| | | String message; |
| | | if (cause == null) { |
| | | message = "AI 会话记忆加载失败"; |
| | | } else { |
| | | message = cause.getMessage(); |
| | | } |
| | | throw buildAiException("AI_MEMORY_LOAD_ERROR", AiErrorCategory.REQUEST, "MEMORY_LOAD", message, cause); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * Builds the tool runtime dedicated to this chat round from the MCP mount records of the configuration. |
| | | * Any exception is rethrown through {@code buildAiException} with code "AI_MCP_MOUNT_ERROR" |
| | | * and the original exception as cause. |
| | | */ |
| | | private McpMountRuntimeFactory.McpMountRuntime createRuntime(AiResolvedConfig config, Long userId) { |
| | | try { |
| | | McpMountRuntimeFactory.McpMountRuntime mounted = mcpMountRuntimeFactory.create(config.getMcpMounts(), userId); |
| | | return mounted; |
| | | } catch (Exception cause) { |
| | | String message; |
| | | if (cause == null) { |
| | | message = "MCP 挂载失败"; |
| | | } else { |
| | | message = cause.getMessage(); |
| | | } |
| | | throw buildAiException("AI_MCP_MOUNT_ERROR", AiErrorCategory.MCP, "MCP_MOUNT", message, cause); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * Performs one blocking model call. |
| | | * Any exception is rethrown through {@code buildAiException} with code "AI_MODEL_CALL_ERROR" |
| | | * and the original exception as cause. |
| | | */ |
| | | private ChatResponse invokeChatCall(OpenAiChatModel chatModel, Prompt prompt) { |
| | | try { |
| | | ChatResponse reply = chatModel.call(prompt); |
| | | return reply; |
| | | } catch (Exception cause) { |
| | | String message; |
| | | if (cause == null) { |
| | | message = "AI 模型调用失败"; |
| | | } else { |
| | | message = cause.getMessage(); |
| | | } |
| | | throw buildAiException("AI_MODEL_CALL_ERROR", AiErrorCategory.MODEL, "MODEL_CALL", message, cause); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * Obtains the response flux for a streaming model call. |
| | | * Only exceptions thrown while creating the flux are handled here; they are rethrown through |
| | | * {@code buildAiException} with code "AI_MODEL_STREAM_ERROR" and stage "MODEL_STREAM_INIT". |
| | | */ |
| | | private Flux<ChatResponse> invokeChatStream(OpenAiChatModel chatModel, Prompt prompt) { |
| | | try { |
| | | Flux<ChatResponse> chunks = chatModel.stream(prompt); |
| | | return chunks; |
| | | } catch (Exception cause) { |
| | | String message; |
| | | if (cause == null) { |
| | | message = "AI 模型流式调用失败"; |
| | | } else { |
| | | message = cause.getMessage(); |
| | | } |
| | | throw buildAiException("AI_MODEL_STREAM_ERROR", AiErrorCategory.MODEL, "MODEL_STREAM_INIT", message, cause); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * Records the arrival time of the first answer token, once per request. |
| | | * Only the call that wins the compare-and-set on {@code firstTokenAtRef} switches the thinking trace to |
| | | * the answer phase (when a trace emitter is given) and emits a "FIRST_TOKEN" status event; later calls do nothing. |
| | | */ |
| | | private void markFirstToken(AtomicReference<Long> firstTokenAtRef, SseEmitter emitter, String requestId, |
| | | Long sessionId, String model, long startedAt, ThinkingTraceEmitter thinkingTraceEmitter) { |
| | | boolean isFirstToken = firstTokenAtRef.compareAndSet(null, System.currentTimeMillis()); |
| | | if (isFirstToken) { |
| | | if (thinkingTraceEmitter != null) { |
| | | thinkingTraceEmitter.startAnswer(); |
| | | } |
| | | AiChatStatusDto firstTokenStatus = AiChatStatusDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .status("FIRST_TOKEN") |
| | | .model(model) |
| | | .timestamp(Instant.now().toEpochMilli()) |
| | | .elapsedMs(System.currentTimeMillis() - startedAt) |
| | | .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAtRef.get())) |
| | | .build(); |
| | | emitSafely(emitter, "status", firstTokenStatus); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * Builds the status DTO for a terminal state of the stream. |
| | | * The timestamp is the current epoch millisecond, the elapsed time is measured from {@code startedAt}, |
| | | * and the first-token latency is {@code null} when no first token was recorded. |
| | | */ |
| | | private AiChatStatusDto buildTerminalStatus(String requestId, Long sessionId, String status, String model, long startedAt, Long firstTokenAt) { |
| | | AiChatStatusDto terminal = AiChatStatusDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .status(status) |
| | | .model(model) |
| | | .timestamp(Instant.now().toEpochMilli()) |
| | | .elapsedMs(System.currentTimeMillis() - startedAt) |
| | | .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAt)) |
| | | .build(); |
| | | return terminal; |
| | | } |
| | | |
| | | /** |
| | | * Computes the first-token latency in milliseconds. |
| | | * |
| | | * @param startedAt request start time in epoch milliseconds |
| | | * @param firstTokenAt first-token time in epoch milliseconds, or {@code null} when none was recorded |
| | | * @return {@code null} when {@code firstTokenAt} is {@code null}; otherwise the difference, never below zero |
| | | */ |
| | | private Long resolveFirstTokenLatency(long startedAt, Long firstTokenAt) { |
| | | if (firstTokenAt == null) { |
| | | return null; |
| | | } |
| | | long latency = firstTokenAt - startedAt; |
| | | return Math.max(0L, latency); |
| | | } |
| | | |
| | | /** |
| | | * Terminates a chat stream that failed or was aborted by the client. |
| | | * A client abort is logged at WARN and reported with terminal status "ABORTED"; every other failure is |
| | | * logged at ERROR, reported with "FAILED" and additionally sent to the client as an "error" event. |
| | | * In both cases the thinking trace is marked terminated, the call log is failed and the stream state is updated. |
| | | * The emitter is completed in a finally block, so the SSE connection is released even when the call log |
| | | * or the stream-state update throws; such an exception still propagates to the caller. |
| | | * |
| | | * @param exception the classified failure; also passed to {@code emitter.completeWithError} |
| | | * @param callLogId id of the call log to fail; {@code null} when the failure happened before the log was started |
| | | * @param thinkingTraceEmitter trace emitter of this round, or {@code null} when it was not created yet |
| | | */ |
| | | private void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt, |
| | | Long firstTokenAt, AiChatException exception, Long callLogId, |
| | | long toolSuccessCount, long toolFailureCount, |
| | | ThinkingTraceEmitter thinkingTraceEmitter, |
| | | Long tenantId, Long userId, String promptCode) { |
| | | boolean aborted = isClientAbortException(exception); |
| | | String terminalStatus = aborted ? "ABORTED" : "FAILED"; |
| | | try { |
| | | if (aborted) { |
| | | log.warn("AI chat aborted by client, requestId={}, sessionId={}, stage={}, message={}", |
| | | requestId, sessionId, exception.getStage(), exception.getMessage()); |
| | | } else { |
| | | log.error("AI chat failed, requestId={}, sessionId={}, category={}, stage={}, message={}", |
| | | requestId, sessionId, exception.getCategory(), exception.getStage(), exception.getMessage(), exception); |
| | | } |
| | | if (thinkingTraceEmitter != null) { |
| | | thinkingTraceEmitter.markTerminated(terminalStatus); |
| | | } |
| | | emitSafely(emitter, "status", buildTerminalStatus(requestId, sessionId, terminalStatus, model, startedAt, firstTokenAt)); |
| | | // An aborted client gets no "error" event; only real failures carry the error details. |
| | | if (!aborted) { |
| | | emitSafely(emitter, "error", AiChatErrorDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .code(exception.getCode()) |
| | | .category(exception.getCategory().name()) |
| | | .stage(exception.getStage()) |
| | | .message(exception.getMessage()) |
| | | .timestamp(Instant.now().toEpochMilli()) |
| | | .build()); |
| | | } |
| | | aiCallLogService.failCallLog( |
| | | callLogId, |
| | | terminalStatus, |
| | | exception.getCategory().name(), |
| | | exception.getStage(), |
| | | exception.getMessage(), |
| | | System.currentTimeMillis() - startedAt, |
| | | resolveFirstTokenLatency(startedAt, firstTokenAt), |
| | | toolSuccessCount, |
| | | toolFailureCount |
| | | ); |
| | | aiRedisSupport.markStreamState(requestId, tenantId, userId, sessionId, promptCode, terminalStatus, exception.getMessage()); |
| | | } finally { |
| | | // Always release the SSE connection, even if the bookkeeping above threw. |
| | | emitter.completeWithError(exception); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * Creates the OpenAI chat model for one AI parameter set. |
| | | * The model gets its own tool-calling manager (tool callbacks resolved from the application context), |
| | | * default options taken from {@code aiParam} with stream usage enabled, and a retry template limited |
| | | * to a single attempt. |
| | | */ |
| | | private OpenAiChatModel createChatModel(AiParam aiParam) { |
| | | OpenAiApi api = buildOpenAiApi(aiParam); |
| | | ToolCallingManager callingManager = DefaultToolCallingManager.builder() |
| | | .observationRegistry(observationRegistry) |
| | | .toolCallbackResolver(new SpringBeanToolCallbackResolver(applicationContext, SchemaType.OPEN_API_SCHEMA)) |
| | | .toolExecutionExceptionProcessor(new DefaultToolExecutionExceptionProcessor(false)) |
| | | .build(); |
| | | OpenAiChatOptions defaultOptions = OpenAiChatOptions.builder() |
| | | .model(aiParam.getModel()) |
| | | .temperature(aiParam.getTemperature()) |
| | | .topP(aiParam.getTopP()) |
| | | .maxTokens(aiParam.getMaxTokens()) |
| | | .streamUsage(true) |
| | | .build(); |
| | | // maxAttempts(1) means the call is tried once and not retried. |
| | | org.springframework.retry.support.RetryTemplate singleAttempt = |
| | | org.springframework.retry.support.RetryTemplate.builder().maxAttempts(1).build(); |
| | | return new OpenAiChatModel(api, defaultOptions, callingManager, singleAttempt, observationRegistry); |
| | | } |
| | | |
| | | /** Delegates to {@code AiOpenAiApiSupport.buildOpenAiApi} to create the API client for the given parameter set. */ |
| | | private OpenAiApi buildOpenAiApi(AiParam aiParam) { |
| | | return AiOpenAiApiSupport.buildOpenAiApi(aiParam); |
| | | } |
| | | |
| | | /** |
| | | * Assembles all model options and the tool context for one chat call. |
| | | * The tool context is passed through to built-in tools and external MCP tools so they run within the |
| | | * tenant and session scope. Request metadata is copied into the tool context as strings, but it can |
| | | * never replace the identity entries (userId, tenantId, requestId, sessionId) set by the server. |
| | | * |
| | | * @param aiParam model parameters (model name, temperature, topP, max tokens) |
| | | * @param toolCallbacks tool callbacks to register; may be null or empty |
| | | * @param metadata request metadata; may be null. Null values are stored as empty strings |
| | | * @return the options for this call |
| | | */ |
| | | private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Long tenantId, |
| | | String requestId, Long sessionId, Map<String, Object> metadata) { |
| | | if (userId == null) { |
| | | throw buildAiException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "OPTIONS_BUILD", "当前登录用户不存在", null); |
| | | } |
| | | OpenAiChatOptions.Builder builder = OpenAiChatOptions.builder() |
| | | .model(aiParam.getModel()) |
| | | .temperature(aiParam.getTemperature()) |
| | | .topP(aiParam.getTopP()) |
| | | .maxTokens(aiParam.getMaxTokens()) |
| | | .streamUsage(true) |
| | | .user(String.valueOf(userId)); |
| | | if (!Cools.isEmpty(toolCallbacks)) { |
| | | builder.toolCallbacks(Arrays.stream(toolCallbacks).toList()); |
| | | } |
| | | Map<String, Object> toolContext = new LinkedHashMap<>(); |
| | | toolContext.put("userId", userId); |
| | | toolContext.put("tenantId", tenantId); |
| | | toolContext.put("requestId", requestId); |
| | | toolContext.put("sessionId", sessionId); |
| | | Map<String, String> metadataMap = new LinkedHashMap<>(); |
| | | if (metadata != null) { |
| | | metadata.forEach((key, value) -> { |
| | | String normalized = value == null ? "" : String.valueOf(value); |
| | | metadataMap.put(key, normalized); |
| | | // Metadata comes from the request, so it must not overwrite the identity entries above; |
| | | // otherwise a "tenantId" metadata key would change the scope tools run in. |
| | | if (!toolContext.containsKey(key)) { |
| | | toolContext.put(key, normalized); |
| | | } |
| | | }); |
| | | } |
| | | builder.toolContext(toolContext); |
| | | if (!metadataMap.isEmpty()) { |
| | | builder.metadata(metadataMap); |
| | | } |
| | | return builder.build(); |
| | | } |
| | | |
| | | /** |
| | | * Wraps every tool callback in an observable decorator used for the live SSE trace and the audit log. |
| | | * Null entries are dropped. An empty or null input array is returned as it is. |
| | | */ |
| | | private ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, SseEmitter emitter, String requestId, |
| | | Long sessionId, AtomicLong toolCallSequence, |
| | | AtomicLong toolSuccessCount, AtomicLong toolFailureCount, |
| | | Long callLogId, Long userId, Long tenantId, |
| | | ThinkingTraceEmitter thinkingTraceEmitter) { |
| | | if (Cools.isEmpty(toolCallbacks)) { |
| | | return toolCallbacks; |
| | | } |
| | | return Arrays.stream(toolCallbacks) |
| | | .filter(Objects::nonNull) |
| | | .<ToolCallback>map(delegate -> new ObservableToolCallback(delegate, emitter, requestId, sessionId, toolCallSequence, |
| | | toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, thinkingTraceEmitter)) |
| | | .toArray(ToolCallback[]::new); |
| | | } |
| | | |
| | | /** |
| | | * Assembles the final message list submitted to the model. |
| | | * The order is always: system prompt -> history summary -> key facts -> conversation messages, where the |
| | | * last user message is rendered through the prompt's user template. |
| | | * Messages that are null, have no text, or carry the role "system" are skipped. A missing role counts as |
| | | * "user"; any role other than "assistant" is sent as a user message. The last user message is chosen |
| | | * with the same rules that decide what is sent, so a blank trailing entry or a message without a role |
| | | * no longer prevents the template from being applied. |
| | | * |
| | | * @param memory session memory; may be null |
| | | * @param sourceMessages conversation messages, must not be empty |
| | | * @param aiPrompt prompt definition providing the system prompt and the user template |
| | | * @param metadata request metadata handed to the template renderer |
| | | * @return the messages to send |
| | | * @throws CoolException when {@code sourceMessages} is empty or yields no user message |
| | | */ |
| | | private List<Message> buildPromptMessages(AiChatMemoryDto memory, List<AiChatMessageDto> sourceMessages, AiPrompt aiPrompt, Map<String, Object> metadata) { |
| | | if (Cools.isEmpty(sourceMessages)) { |
| | | throw new CoolException("对话消息不能为空"); |
| | | } |
| | | List<Message> messages = new ArrayList<>(); |
| | | if (StringUtils.hasText(aiPrompt.getSystemPrompt())) { |
| | | messages.add(new SystemMessage(aiPrompt.getSystemPrompt())); |
| | | } |
| | | if (memory != null && StringUtils.hasText(memory.getMemorySummary())) { |
| | | messages.add(new SystemMessage("历史摘要:\n" + memory.getMemorySummary())); |
| | | } |
| | | if (memory != null && StringUtils.hasText(memory.getMemoryFacts())) { |
| | | messages.add(new SystemMessage("关键事实:\n" + memory.getMemoryFacts())); |
| | | } |
| | | // First pass: normalise the role of every message that will be sent (null marks a skipped entry) |
| | | // and remember the last one that is a user message. |
| | | String[] roles = new String[sourceMessages.size()]; |
| | | int lastUserIndex = -1; |
| | | for (int i = 0; i < sourceMessages.size(); i++) { |
| | | AiChatMessageDto item = sourceMessages.get(i); |
| | | if (item == null || !StringUtils.hasText(item.getContent())) { |
| | | continue; |
| | | } |
| | | // Locale.ROOT keeps the comparison independent of the JVM default locale. |
| | | String role = item.getRole() == null ? "user" : item.getRole().toLowerCase(java.util.Locale.ROOT); |
| | | if ("system".equals(role)) { |
| | | continue; |
| | | } |
| | | roles[i] = role; |
| | | if ("user".equals(role)) { |
| | | lastUserIndex = i; |
| | | } |
| | | } |
| | | // Second pass: emit the messages in their original order. |
| | | for (int i = 0; i < sourceMessages.size(); i++) { |
| | | String role = roles[i]; |
| | | if (role == null) { |
| | | continue; |
| | | } |
| | | String content = sourceMessages.get(i).getContent(); |
| | | if (i == lastUserIndex) { |
| | | content = renderUserPrompt(aiPrompt.getUserPromptTemplate(), content, metadata); |
| | | } |
| | | if ("assistant".equals(role)) { |
| | | messages.add(new AssistantMessage(content)); |
| | | } else { |
| | | messages.add(new UserMessage(content)); |
| | | } |
| | | } |
| | | if (messages.stream().noneMatch(item -> item instanceof UserMessage)) { |
| | | throw new CoolException("至少需要一条用户消息"); |
| | | } |
| | | return messages; |
| | | } |
| | | |
| | | private List<AiChatMessageDto> mergeMessages(List<AiChatMessageDto> persistedMessages, List<AiChatMessageDto> memoryMessages) { |
| | |     /** |
| | |      * Concatenates the persisted messages and the in-memory messages, in that order, into a new list. |
| | |      * Throws CoolException when the combined list is empty. |
| | |      */ |
| | |     List<AiChatMessageDto> combined = new ArrayList<>(); |
| | |     for (List<AiChatMessageDto> part : Arrays.asList(persistedMessages, memoryMessages)) { |
| | |         if (!Cools.isEmpty(part)) { |
| | |             combined.addAll(part); |
| | |         } |
| | |     } |
| | |     if (!combined.isEmpty()) { |
| | |         return combined; |
| | |     } |
| | |     throw new CoolException("对话消息不能为空"); |
| | | } |
| | | |
| | | private String resolveTitleSeed(List<AiChatMessageDto> messages) { |
| | |     /** |
| | |      * Returns the content of the most recent user-role message that has text. |
| | |      * Throws CoolException when the list is empty or contains no such message. |
| | |      */ |
| | |     if (Cools.isEmpty(messages)) { |
| | |         throw new CoolException("对话消息不能为空"); |
| | |     } |
| | |     int cursor = messages.size(); |
| | |     while (--cursor >= 0) { |
| | |         AiChatMessageDto candidate = messages.get(cursor); |
| | |         if (candidate == null || !"user".equalsIgnoreCase(candidate.getRole())) { |
| | |             continue; |
| | |         } |
| | |         if (StringUtils.hasText(candidate.getContent())) { |
| | |             return candidate.getContent(); |
| | |         } |
| | |     } |
| | |     throw new CoolException("至少需要一条用户消息"); |
| | | } |
| | | |
| | | private String renderUserPrompt(String userPromptTemplate, String content, Map<String, Object> metadata) { |
| | | if (!StringUtils.hasText(userPromptTemplate)) { |
| | | return content; |
| | | } |
| | | String rendered = userPromptTemplate |
| | | .replace("{{input}}", content) |
| | | .replace("{input}", content); |
| | | if (metadata != null) { |
| | | for (Map.Entry<String, Object> entry : metadata.entrySet()) { |
| | | String value = entry.getValue() == null ? "" : String.valueOf(entry.getValue()); |
| | | rendered = rendered.replace("{{" + entry.getKey() + "}}", value); |
| | | rendered = rendered.replace("{" + entry.getKey() + "}", value); |
| | | } |
| | | } |
| | | if (Objects.equals(rendered, userPromptTemplate)) { |
| | | return userPromptTemplate + "\n\n" + content; |
| | | } |
| | | return rendered; |
| | | } |
| | | |
| | | private String extractContent(ChatResponse response) { |
| | |     /** Returns the text of the response's result output, or null when the response, its result or its output is null. */ |
| | |     if (response == null) { |
| | |         return null; |
| | |     } |
| | |     var result = response.getResult(); |
| | |     if (result == null) { |
| | |         return null; |
| | |     } |
| | |     var output = result.getOutput(); |
| | |     return output == null ? null : output.getText(); |
| | | } |
| | | |
| | | private String summarizeToolPayload(String content, int maxLength) { |
| | |     /** |
| | |      * Produces a single-line summary of a tool payload: trims it, collapses every whitespace run |
| | |      * (including line breaks) to one space, and cuts it to at most maxLength chars. |
| | |      * Returns null for blank content. A negative maxLength is treated as 0. |
| | |      * The cut never lands between the two halves of a surrogate pair, so the summary cannot end in |
| | |      * a lone high surrogate. |
| | |      */ |
| | |     if (!StringUtils.hasText(content)) { |
| | |         return null; |
| | |     } |
| | |     // "\\s+" already covers \r and \n, so no separate replace calls are needed. |
| | |     String normalized = content.trim().replaceAll("\\s+", " "); |
| | |     int limit = Math.max(maxLength, 0); |
| | |     if (normalized.length() <= limit) { |
| | |         return normalized; |
| | |     } |
| | |     int end = limit; |
| | |     if (end > 0 && Character.isHighSurrogate(normalized.charAt(end - 1))) { |
| | |         end--; |
| | |     } |
| | |     return normalized.substring(0, end); |
| | | } |
| | | |
| | | private void emitDone(SseEmitter emitter, String requestId, ChatResponseMetadata metadata, String fallbackModel, Long sessionId, long startedAt, Long firstTokenAt) { |
| | |     /** |
| | |      * Sends the "done" event carrying elapsed time, first-token latency and token usage. |
| | |      * The model name comes from the response metadata when it has text, otherwise fallbackModel; |
| | |      * token counts are null when no usage is available. |
| | |      */ |
| | |     Usage usage = null; |
| | |     String model = fallbackModel; |
| | |     if (metadata != null) { |
| | |         usage = metadata.getUsage(); |
| | |         if (StringUtils.hasText(metadata.getModel())) { |
| | |             model = metadata.getModel(); |
| | |         } |
| | |     } |
| | |     boolean usageMissing = usage == null; |
| | |     var donePayload = AiChatDoneDto.builder() |
| | |             .requestId(requestId) |
| | |             .sessionId(sessionId) |
| | |             .model(model) |
| | |             .elapsedMs(System.currentTimeMillis() - startedAt) |
| | |             .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAt)) |
| | |             .promptTokens(usageMissing ? null : usage.getPromptTokens()) |
| | |             .completionTokens(usageMissing ? null : usage.getCompletionTokens()) |
| | |             .totalTokens(usageMissing ? null : usage.getTotalTokens()) |
| | |             .build(); |
| | |     emitStrict(emitter, "done", donePayload); |
| | | } |
| | | |
| | | private Map<String, String> buildMessagePayload(String... keyValues) { |
| | |     /** |
| | |      * Builds an insertion-ordered map from alternating key/value arguments; null values become "". |
| | |      * Returns an empty map for no arguments and throws CoolException for an odd argument count. |
| | |      */ |
| | |     Map<String, String> result = new LinkedHashMap<>(); |
| | |     int count = keyValues == null ? 0 : keyValues.length; |
| | |     if (count == 0) { |
| | |         return result; |
| | |     } |
| | |     if ((count & 1) == 1) { |
| | |         throw new CoolException("消息载荷参数必须成对出现"); |
| | |     } |
| | |     int cursor = 0; |
| | |     while (cursor < count) { |
| | |         String key = keyValues[cursor++]; |
| | |         String value = keyValues[cursor++]; |
| | |         result.put(key, value == null ? "" : value); |
| | |     } |
| | |     return result; |
| | | } |
| | | |
| | | private void emitStrict(SseEmitter emitter, String eventName, Object payload) { |
| | |     /** |
| | |      * Serializes the payload to JSON and sends it as a named SSE event. |
| | |      * Any IOException from serialization or sending is rethrown as an AI_SSE_EMIT_ERROR stream exception. |
| | |      */ |
| | |     try { |
| | |         String json = objectMapper.writeValueAsString(payload); |
| | |         var event = SseEmitter.event().name(eventName).data(json, MediaType.APPLICATION_JSON); |
| | |         emitter.send(event); |
| | |     } catch (IOException ioFailure) { |
| | |         String detail = "SSE 输出失败: " + ioFailure.getMessage(); |
| | |         throw buildAiException("AI_SSE_EMIT_ERROR", AiErrorCategory.STREAM, "SSE_EMIT", detail, ioFailure); |
| | |     } |
| | | } |
| | | |
| | | private void emitSafely(SseEmitter emitter, String eventName, Object payload) { |
| | |     /** Best-effort send: delegates to emitStrict and, on any exception, logs a warning instead of propagating it. */ |
| | |     try { |
| | |         emitStrict(emitter, eventName, payload); |
| | |     } catch (Exception emitFailure) { |
| | |         String reason = emitFailure.getMessage(); |
| | |         log.warn("AI SSE event emit skipped, eventName={}, message={}", eventName, reason); |
| | |     } |
| | | } |
| | | |
| | | private AiChatException buildAiException(String code, AiErrorCategory category, String stage, String message, Throwable cause) { |
| | | /** Creates an AiChatException from the given code, category, stage, message and cause; the caller decides whether to throw it. */ |
| | | return new AiChatException(code, category, stage, message, cause); |
| | | } |
| | | |
| | | private boolean isClientAbortException(Throwable throwable) { |
| | |     /** |
| | |      * Reports whether the throwable or any of its causes has a message containing one of the |
| | |      * markers "broken pipe", "connection reset", "forcibly closed" or "abort" (case-insensitive). |
| | |      * Returns false for null. Messages are lower-cased with Locale.ROOT so matching does not vary |
| | |      * with the JVM default locale, and each throwable is visited once so a cyclic cause chain |
| | |      * cannot loop forever. |
| | |      */ |
| | |     java.util.Set<Throwable> visited = java.util.Collections.newSetFromMap(new java.util.IdentityHashMap<>()); |
| | |     for (Throwable current = throwable; current != null && visited.add(current); current = current.getCause()) { |
| | |         String message = current.getMessage(); |
| | |         if (message == null) { |
| | |             continue; |
| | |         } |
| | |         String normalized = message.toLowerCase(java.util.Locale.ROOT); |
| | |         if (normalized.contains("broken pipe") |
| | |                 || normalized.contains("connection reset") |
| | |                 || normalized.contains("forcibly closed") |
| | |                 || normalized.contains("abort")) { |
| | |             return true; |
| | |         } |
| | |     } |
| | |     return false; |
| | | } |
| | | |
| | | private class ThinkingTraceEmitter { |
| | |     /** |
| | |      * Tracks the current "thinking" phase and status of one chat request and publishes each |
| | |      * transition as a best-effort "thinking" SSE event. |
| | |      */ |
| | | |
| | |     private static final String PHASE_ANALYZE = "ANALYZE"; |
| | |     private static final String PHASE_TOOL_CALL = "TOOL_CALL"; |
| | |     private static final String PHASE_ANSWER = "ANSWER"; |
| | |     private static final String STATUS_STARTED = "STARTED"; |
| | |     private static final String STATUS_UPDATED = "UPDATED"; |
| | |     private static final String STATUS_COMPLETED = "COMPLETED"; |
| | |     private static final String STATUS_FAILED = "FAILED"; |
| | |     private static final String STATUS_ABORTED = "ABORTED"; |
| | | |
| | |     private final SseEmitter emitter; |
| | |     private final String requestId; |
| | |     private final Long sessionId; |
| | |     private String currentPhase; |
| | |     private String currentStatus; |
| | | |
| | |     private ThinkingTraceEmitter(SseEmitter emitter, String requestId, Long sessionId) { |
| | |         this.emitter = emitter; |
| | |         this.requestId = requestId; |
| | |         this.sessionId = sessionId; |
| | |     } |
| | | |
| | |     private void startAnalyze() { |
| | |         /** Opens the ANALYZE phase; does nothing once any phase has been entered. */ |
| | |         if (currentPhase == null) { |
| | |             currentPhase = PHASE_ANALYZE; |
| | |             currentStatus = STATUS_STARTED; |
| | |             emitThinkingEvent(PHASE_ANALYZE, STATUS_STARTED, "正在分析问题", |
| | |                     "已接收你的问题,正在理解意图并判断是否需要调用工具。", null); |
| | |         } |
| | |     } |
| | | |
| | |     private void onToolStart(String toolName, String toolCallId) { |
| | |         /** Switches to the TOOL_CALL phase, then publishes an UPDATED event naming the tool being called. */ |
| | |         switchPhase(PHASE_TOOL_CALL, STATUS_STARTED, "正在调用工具", "已判断需要调用工具,正在查询相关信息。", null); |
| | |         currentStatus = STATUS_UPDATED; |
| | |         String detail = "正在调用工具 " + safeLabel(toolName, "未知工具") + " 获取所需信息。"; |
| | |         emitThinkingEvent(PHASE_TOOL_CALL, STATUS_UPDATED, "正在调用工具", detail, toolCallId); |
| | |     } |
| | | |
| | |     private void onToolResult(String toolName, String toolCallId, boolean failed) { |
| | |         /** Records a tool outcome in the TOOL_CALL phase: FAILED when the call failed, UPDATED otherwise. */ |
| | |         String label = safeLabel(toolName, "未知工具"); |
| | |         String status; |
| | |         String title; |
| | |         String detail; |
| | |         if (failed) { |
| | |             status = STATUS_FAILED; |
| | |             title = "工具调用失败"; |
| | |             detail = "工具 " + label + " 调用失败,正在评估失败影响并整理可用信息。"; |
| | |         } else { |
| | |             status = STATUS_UPDATED; |
| | |             title = "工具调用完成"; |
| | |             detail = "工具 " + label + " 已返回结果,正在继续分析并提炼关键信息。"; |
| | |         } |
| | |         currentPhase = PHASE_TOOL_CALL; |
| | |         currentStatus = status; |
| | |         emitThinkingEvent(PHASE_TOOL_CALL, status, title, detail, toolCallId); |
| | |     } |
| | | |
| | |     private void startAnswer() { |
| | |         /** Switches to the ANSWER phase. */ |
| | |         switchPhase(PHASE_ANSWER, STATUS_STARTED, "正在整理答案", "已完成分析,正在组织最终回复内容。", null); |
| | |     } |
| | | |
| | |     private void completeCurrentPhase() { |
| | |         /** Marks the active phase COMPLETED unless no phase is active or it already reached a terminal status. */ |
| | |         boolean phaseActive = StringUtils.hasText(currentPhase) && !isTerminalStatus(currentStatus); |
| | |         if (phaseActive) { |
| | |             currentStatus = STATUS_COMPLETED; |
| | |             emitThinkingEvent(currentPhase, STATUS_COMPLETED, resolveCompleteTitle(currentPhase), |
| | |                     resolveCompleteContent(currentPhase), null); |
| | |         } |
| | |     } |
| | | |
| | |     private void markTerminated(String terminalStatus) { |
| | |         /** Ends the active phase with the given status; "ABORTED" gets the aborted wording, anything else the failure wording. */ |
| | |         boolean phaseActive = StringUtils.hasText(currentPhase) && !isTerminalStatus(currentStatus); |
| | |         if (!phaseActive) { |
| | |             return; |
| | |         } |
| | |         boolean aborted = STATUS_ABORTED.equals(terminalStatus); |
| | |         String title = aborted ? "思考已中止" : "思考失败"; |
| | |         String detail = aborted |
| | |                 ? "本轮对话已被中止,思考过程提前结束。" |
| | |                 : "本轮对话在生成答案前失败,当前思考过程已停止。"; |
| | |         currentStatus = terminalStatus; |
| | |         emitThinkingEvent(currentPhase, terminalStatus, title, detail, null); |
| | |     } |
| | | |
| | |     private void switchPhase(String nextPhase, String nextStatus, String title, String content, String toolCallId) { |
| | |         /** Completes the previous phase when the phase changes, then records and publishes the new phase and status. */ |
| | |         boolean phaseChanged = !Objects.equals(currentPhase, nextPhase); |
| | |         if (phaseChanged) { |
| | |             completeCurrentPhase(); |
| | |         } |
| | |         currentPhase = nextPhase; |
| | |         currentStatus = nextStatus; |
| | |         emitThinkingEvent(nextPhase, nextStatus, title, content, toolCallId); |
| | |     } |
| | | |
| | |     private void emitThinkingEvent(String phase, String status, String title, String content, String toolCallId) { |
| | |         /** Publishes one "thinking" event through emitSafely, stamped with the current epoch millis. */ |
| | |         var event = AiChatThinkingEventDto.builder() |
| | |                 .requestId(requestId) |
| | |                 .sessionId(sessionId) |
| | |                 .phase(phase) |
| | |                 .status(status) |
| | |                 .title(title) |
| | |                 .content(content) |
| | |                 .toolCallId(toolCallId) |
| | |                 .timestamp(Instant.now().toEpochMilli()) |
| | |                 .build(); |
| | |         emitSafely(emitter, "thinking", event); |
| | |     } |
| | | |
| | |     private boolean isTerminalStatus(String status) { |
| | |         return STATUS_COMPLETED.equals(status) || STATUS_FAILED.equals(status) || STATUS_ABORTED.equals(status); |
| | |     } |
| | | |
| | |     private String resolveCompleteTitle(String phase) { |
| | |         return PHASE_ANSWER.equals(phase) ? "答案整理完成" |
| | |                 : PHASE_TOOL_CALL.equals(phase) ? "工具分析完成" |
| | |                 : "问题分析完成"; |
| | |     } |
| | | |
| | |     private String resolveCompleteContent(String phase) { |
| | |         return PHASE_ANSWER.equals(phase) ? "最终答复已生成完成。" |
| | |                 : PHASE_TOOL_CALL.equals(phase) ? "工具调用阶段已结束,相关信息已整理完毕。" |
| | |                 : "问题意图和处理方向已分析完成。"; |
| | |     } |
| | | |
| | |     private String safeLabel(String value, String fallback) { |
| | |         if (StringUtils.hasText(value)) { |
| | |             return value; |
| | |         } |
| | |         return fallback; |
| | |     } |
| | | } |
| | | |
| | | private class ObservableToolCallback implements ToolCallback { |
| | |     /** |
| | |      * Observability wrapper around a tool callback. |
| | |      * Around the real call it publishes tool_start / tool_result / tool_error SSE events, updates |
| | |      * the thinking trace and the shared success/failure counters, caches the outcome per request, |
| | |      * and writes a summary to the MCP call log. |
| | |      */ |
| | | |
| | |     /** Maximum characters of tool input kept in events and call logs. */ |
| | |     private static final int INPUT_SUMMARY_LIMIT = 400; |
| | |     /** Maximum characters of tool output kept in events and call logs. */ |
| | |     private static final int OUTPUT_SUMMARY_LIMIT = 600; |
| | | |
| | |     private final ToolCallback delegate; |
| | |     private final SseEmitter emitter; |
| | |     private final String requestId; |
| | |     private final Long sessionId; |
| | |     private final AtomicLong toolCallSequence; |
| | |     private final AtomicLong toolSuccessCount; |
| | |     private final AtomicLong toolFailureCount; |
| | |     private final Long callLogId; |
| | |     private final Long userId; |
| | |     private final Long tenantId; |
| | |     private final ThinkingTraceEmitter thinkingTraceEmitter; |
| | | |
| | |     private ObservableToolCallback(ToolCallback delegate, SseEmitter emitter, String requestId, |
| | |                                    Long sessionId, AtomicLong toolCallSequence, |
| | |                                    AtomicLong toolSuccessCount, AtomicLong toolFailureCount, |
| | |                                    Long callLogId, Long userId, Long tenantId, |
| | |                                    ThinkingTraceEmitter thinkingTraceEmitter) { |
| | |         this.delegate = delegate; |
| | |         this.emitter = emitter; |
| | |         this.requestId = requestId; |
| | |         this.sessionId = sessionId; |
| | |         this.toolCallSequence = toolCallSequence; |
| | |         this.toolSuccessCount = toolSuccessCount; |
| | |         this.toolFailureCount = toolFailureCount; |
| | |         this.callLogId = callLogId; |
| | |         this.userId = userId; |
| | |         this.tenantId = tenantId; |
| | |         this.thinkingTraceEmitter = thinkingTraceEmitter; |
| | |     } |
| | | |
| | |     @Override |
| | |     public org.springframework.ai.tool.definition.ToolDefinition getToolDefinition() { |
| | |         return delegate.getToolDefinition(); |
| | |     } |
| | | |
| | |     @Override |
| | |     public org.springframework.ai.tool.metadata.ToolMetadata getToolMetadata() { |
| | |         return delegate.getToolMetadata(); |
| | |     } |
| | | |
| | |     @Override |
| | |     public String call(String toolInput) { |
| | |         return call(toolInput, null); |
| | |     } |
| | | |
| | |     @Override |
| | |     public String call(String toolInput, ToolContext toolContext) { |
| | |         /** |
| | |          * Runs the wrapped tool with observation. |
| | |          * A result cached for the same tenant, request, tool and input is replayed without calling |
| | |          * the delegate. Otherwise the delegate is invoked, and only a RuntimeException thrown by |
| | |          * the delegate itself is recorded as a tool failure (and rethrown). An exception raised by |
| | |          * the success bookkeeping (cache write, call log) propagates as-is and is no longer |
| | |          * re-recorded as a failed tool call on top of the success already counted. |
| | |          */ |
| | |         org.springframework.ai.tool.definition.ToolDefinition definition = delegate.getToolDefinition(); |
| | |         String toolName = definition == null ? "unknown" : definition.name(); |
| | |         String mountName = delegate instanceof MountedToolCallback ? ((MountedToolCallback) delegate).getMountName() : null; |
| | |         String toolCallId = requestId + "-tool-" + toolCallSequence.incrementAndGet(); |
| | |         long startedAt = System.currentTimeMillis(); |
| | |         String inputSummary = summarizeToolPayload(toolInput, INPUT_SUMMARY_LIMIT); |
| | |         // Reuse is scoped to repeated calls inside the same request so results are never shared across requests. |
| | |         AiRedisSupport.CachedToolResult cachedToolResult = aiRedisSupport.getToolResult(tenantId, requestId, toolName, toolInput); |
| | |         if (cachedToolResult != null) { |
| | |             return replayCachedResult(cachedToolResult, toolName, mountName, toolCallId, inputSummary); |
| | |         } |
| | |         if (thinkingTraceEmitter != null) { |
| | |             thinkingTraceEmitter.onToolStart(toolName, toolCallId); |
| | |         } |
| | |         emitSafely(emitter, "tool_start", AiChatToolEventDto.builder() |
| | |                 .requestId(requestId) |
| | |                 .sessionId(sessionId) |
| | |                 .toolCallId(toolCallId) |
| | |                 .toolName(toolName) |
| | |                 .mountName(mountName) |
| | |                 .status("STARTED") |
| | |                 .inputSummary(inputSummary) |
| | |                 .timestamp(startedAt) |
| | |                 .build()); |
| | |         String output; |
| | |         try { |
| | |             // Keep the try scoped to the delegate so bookkeeping errors are not mistaken for tool errors. |
| | |             output = toolContext == null ? delegate.call(toolInput) : delegate.call(toolInput, toolContext); |
| | |         } catch (RuntimeException e) { |
| | |             recordFailure(toolInput, toolName, mountName, toolCallId, inputSummary, |
| | |                     System.currentTimeMillis() - startedAt, e); |
| | |             throw e; |
| | |         } |
| | |         recordSuccess(toolInput, toolName, mountName, toolCallId, inputSummary, output, |
| | |                 System.currentTimeMillis() - startedAt); |
| | |         return output; |
| | |     } |
| | | |
| | |     /** Replays a cached outcome: publishes it with zero duration, counts and logs it, then returns the output or throws the cached error. */ |
| | |     private String replayCachedResult(AiRedisSupport.CachedToolResult cachedToolResult, String toolName, |
| | |                                       String mountName, String toolCallId, String inputSummary) { |
| | |         boolean success = cachedToolResult.isSuccess(); |
| | |         String outputSummary = summarizeToolPayload(cachedToolResult.getOutput(), OUTPUT_SUMMARY_LIMIT); |
| | |         // NOTE(review): a cached failure is published as "tool_result" with status FAILED, whereas a live failure uses "tool_error" - confirm which the client expects. |
| | |         emitSafely(emitter, "tool_result", AiChatToolEventDto.builder() |
| | |                 .requestId(requestId) |
| | |                 .sessionId(sessionId) |
| | |                 .toolCallId(toolCallId) |
| | |                 .toolName(toolName) |
| | |                 .mountName(mountName) |
| | |                 .status(success ? "COMPLETED" : "FAILED") |
| | |                 .inputSummary(inputSummary) |
| | |                 .outputSummary(outputSummary) |
| | |                 .errorMessage(cachedToolResult.getErrorMessage()) |
| | |                 .durationMs(0L) |
| | |                 .timestamp(System.currentTimeMillis()) |
| | |                 .build()); |
| | |         if (thinkingTraceEmitter != null) { |
| | |             thinkingTraceEmitter.onToolResult(toolName, toolCallId, !success); |
| | |         } |
| | |         if (success) { |
| | |             toolSuccessCount.incrementAndGet(); |
| | |             aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, |
| | |                     "COMPLETED", inputSummary, outputSummary, |
| | |                     null, 0L, userId, tenantId); |
| | |             return cachedToolResult.getOutput(); |
| | |         } |
| | |         toolFailureCount.incrementAndGet(); |
| | |         aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, |
| | |                 "FAILED", inputSummary, null, cachedToolResult.getErrorMessage(), |
| | |                 0L, userId, tenantId); |
| | |         throw new CoolException(cachedToolResult.getErrorMessage()); |
| | |     } |
| | | |
| | |     /** Publishes, caches, counts and logs a successful delegate call. */ |
| | |     private void recordSuccess(String toolInput, String toolName, String mountName, String toolCallId, |
| | |                                String inputSummary, String output, long durationMs) { |
| | |         String outputSummary = summarizeToolPayload(output, OUTPUT_SUMMARY_LIMIT); |
| | |         emitSafely(emitter, "tool_result", AiChatToolEventDto.builder() |
| | |                 .requestId(requestId) |
| | |                 .sessionId(sessionId) |
| | |                 .toolCallId(toolCallId) |
| | |                 .toolName(toolName) |
| | |                 .mountName(mountName) |
| | |                 .status("COMPLETED") |
| | |                 .inputSummary(inputSummary) |
| | |                 .outputSummary(outputSummary) |
| | |                 .durationMs(durationMs) |
| | |                 .timestamp(System.currentTimeMillis()) |
| | |                 .build()); |
| | |         if (thinkingTraceEmitter != null) { |
| | |             thinkingTraceEmitter.onToolResult(toolName, toolCallId, false); |
| | |         } |
| | |         aiRedisSupport.cacheToolResult(tenantId, requestId, toolName, toolInput, true, output, null); |
| | |         toolSuccessCount.incrementAndGet(); |
| | |         aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, |
| | |                 "COMPLETED", inputSummary, outputSummary, |
| | |                 null, durationMs, userId, tenantId); |
| | |     } |
| | | |
| | |     /** Publishes, caches, counts and logs a delegate call that threw; the caller rethrows the exception. */ |
| | |     private void recordFailure(String toolInput, String toolName, String mountName, String toolCallId, |
| | |                                String inputSummary, long durationMs, RuntimeException failure) { |
| | |         emitSafely(emitter, "tool_error", AiChatToolEventDto.builder() |
| | |                 .requestId(requestId) |
| | |                 .sessionId(sessionId) |
| | |                 .toolCallId(toolCallId) |
| | |                 .toolName(toolName) |
| | |                 .mountName(mountName) |
| | |                 .status("FAILED") |
| | |                 .inputSummary(inputSummary) |
| | |                 .errorMessage(failure.getMessage()) |
| | |                 .durationMs(durationMs) |
| | |                 .timestamp(System.currentTimeMillis()) |
| | |                 .build()); |
| | |         if (thinkingTraceEmitter != null) { |
| | |             thinkingTraceEmitter.onToolResult(toolName, toolCallId, true); |
| | |         } |
| | |         aiRedisSupport.cacheToolResult(tenantId, requestId, toolName, toolInput, false, null, failure.getMessage()); |
| | |         toolFailureCount.incrementAndGet(); |
| | |         aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, |
| | |                 "FAILED", inputSummary, null, failure.getMessage(), |
| | |                 durationMs, userId, tenantId); |
| | |     } |
| | | } |
| | | } |
| | |
| | | |
| | | import com.vincent.rsf.server.ai.config.AiDefaults; |
| | | import com.vincent.rsf.server.ai.dto.AiResolvedConfig; |
| | | import com.vincent.rsf.server.ai.store.AiConfigCacheStore; |
| | | import com.vincent.rsf.server.ai.service.AiConfigResolverService; |
| | | import com.vincent.rsf.server.ai.service.AiMcpMountService; |
| | | import com.vincent.rsf.server.ai.service.AiParamService; |
| | |
| | | private final AiParamService aiParamService; |
| | | private final AiPromptService aiPromptService; |
| | | private final AiMcpMountService aiMcpMountService; |
| | | private final AiRedisSupport aiRedisSupport; |
| | | private final AiConfigCacheStore aiConfigCacheStore; |
| | | |
| | | /** |
| | | * 按租户解析一次完整的 AI 运行配置。 |
| | |
| | | } |
| | | String finalPromptCode = StringUtils.hasText(promptCode) ? promptCode : AiDefaults.DEFAULT_PROMPT_CODE; |
| | | // 配置解析是多个入口共享的热点路径,命中缓存时可以避免三张配置表的重复查询。 |
| | | AiResolvedConfig cached = aiRedisSupport.getResolvedConfig(tenantId, finalPromptCode, aiParamId); |
| | | AiResolvedConfig cached = aiConfigCacheStore.getResolvedConfig(tenantId, finalPromptCode, aiParamId); |
| | | if (cached != null) { |
| | | return cached; |
| | | } |
| | |
| | | .prompt(aiPromptService.getActivePrompt(finalPromptCode, tenantId)) |
| | | .mcpMounts(aiMcpMountService.listActiveMounts(tenantId)) |
| | | .build(); |
| | | aiRedisSupport.cacheResolvedConfig(tenantId, finalPromptCode, aiParamId, resolvedConfig); |
| | | aiConfigCacheStore.cacheResolvedConfig(tenantId, finalPromptCode, aiParamId, resolvedConfig); |
| | | return resolvedConfig; |
| | | } |
| | | } |
| | |
| | | package com.vincent.rsf.server.ai.service.impl; |
| | | |
| | | import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; |
| | | import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper; |
| | | import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; |
| | | import com.fasterxml.jackson.databind.ObjectMapper; |
| | | import com.vincent.rsf.framework.exception.CoolException; |
| | | import com.vincent.rsf.server.ai.config.AiDefaults; |
| | | import com.vincent.rsf.server.ai.dto.AiMcpConnectivityTestDto; |
| | |
| | | import com.vincent.rsf.server.ai.dto.AiMcpToolTestRequest; |
| | | import com.vincent.rsf.server.ai.entity.AiMcpMount; |
| | | import com.vincent.rsf.server.ai.mapper.AiMcpMountMapper; |
| | | import com.vincent.rsf.server.ai.service.impl.mcp.AiMcpAdminService; |
| | | import com.vincent.rsf.server.ai.store.AiConfigCacheStore; |
| | | import com.vincent.rsf.server.ai.store.AiConversationCacheStore; |
| | | import com.vincent.rsf.server.ai.store.AiMcpCacheStore; |
| | | import com.vincent.rsf.server.ai.service.AiMcpMountService; |
| | | import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry; |
| | | import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory; |
| | | import com.vincent.rsf.server.system.enums.StatusType; |
| | | import lombok.RequiredArgsConstructor; |
| | | import org.springframework.ai.chat.model.ToolContext; |
| | | import org.springframework.ai.tool.ToolCallback; |
| | | import org.springframework.stereotype.Service; |
| | | import org.springframework.util.StringUtils; |
| | | |
| | | import java.util.ArrayList; |
| | | import java.util.Arrays; |
| | | import java.util.Date; |
| | | import java.util.List; |
| | | import java.util.Map; |
| | | |
| | | @Service("aiMcpMountService") |
| | | @RequiredArgsConstructor |
| | | public class AiMcpMountServiceImpl extends ServiceImpl<AiMcpMountMapper, AiMcpMount> implements AiMcpMountService { |
| | | |
| | | private final BuiltinMcpToolRegistry builtinMcpToolRegistry; |
| | | private final McpMountRuntimeFactory mcpMountRuntimeFactory; |
| | | private final ObjectMapper objectMapper; |
| | | private final AiRedisSupport aiRedisSupport; |
| | | private final AiMcpAdminService aiMcpAdminService; |
| | | private final AiMcpCacheStore aiMcpCacheStore; |
| | | private final AiConfigCacheStore aiConfigCacheStore; |
| | | private final AiConversationCacheStore aiConversationCacheStore; |
| | | |
| | | /** 查询某个租户下当前启用的 MCP 挂载列表。 */ |
| | | @Override |
| | |
| | | if (aiMcpMount.getId() == null) { |
| | | throw new CoolException("MCP 挂载 ID 不能为空"); |
| | | } |
| | | AiMcpMount current = requireMount(aiMcpMount.getId(), tenantId); |
| | | AiMcpMount current = aiMcpAdminService.requireMount(aiMcpMount.getId(), tenantId); |
| | | aiMcpMount.setTenantId(current.getTenantId()); |
| | | ensureRequiredFields(aiMcpMount, tenantId); |
| | | } |
| | |
| | | */ |
| | | @Override |
| | | public List<AiMcpToolPreviewDto> previewTools(Long mountId, Long userId, Long tenantId) { |
| | | AiMcpMount mount = requireMount(mountId, tenantId); |
| | | // 工具目录预览初始化成本高,但变化频率低,适合做管理端短缓存。 |
| | | List<AiMcpToolPreviewDto> cached = aiRedisSupport.getToolPreview(tenantId, mountId); |
| | | AiMcpMount mount = aiMcpAdminService.requireMount(mountId, tenantId); |
| | | List<AiMcpToolPreviewDto> cached = aiMcpCacheStore.getToolPreview(tenantId, mountId); |
| | | if (cached != null) { |
| | | return cached; |
| | | } |
| | | long startedAt = System.currentTimeMillis(); |
| | | try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(List.of(mount), userId)) { |
| | | List<AiMcpToolPreviewDto> tools = buildToolPreviewDtos(runtime.getToolCallbacks(), |
| | | AiDefaults.MCP_TRANSPORT_BUILTIN.equals(mount.getTransportType()) |
| | | ? builtinMcpToolRegistry.listBuiltinToolCatalog(mount.getBuiltinCode()) |
| | | : List.of()); |
| | | if (!runtime.getErrors().isEmpty()) { |
| | | String message = String.join(";", runtime.getErrors()); |
| | | updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY, message, System.currentTimeMillis() - startedAt); |
| | | throw new CoolException(message); |
| | | } |
| | | updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_HEALTHY, |
| | | "工具解析成功,共 " + tools.size() + " 个工具", System.currentTimeMillis() - startedAt); |
| | | aiRedisSupport.cacheToolPreview(tenantId, mountId, tools); |
| | | List<AiMcpToolPreviewDto> tools = aiMcpAdminService.previewTools(mount, userId); |
| | | aiMcpCacheStore.cacheToolPreview(tenantId, mountId, tools); |
| | | return tools; |
| | | } catch (CoolException e) { |
| | | throw e; |
| | | } catch (Exception e) { |
| | | updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY, |
| | | "工具解析失败: " + e.getMessage(), System.currentTimeMillis() - startedAt); |
| | | throw new CoolException("获取工具列表失败: " + e.getMessage()); |
| | | } |
| | | } |
| | | |
| | | /** 对已保存的挂载做真实连通性测试,并把结果回写到运行态字段。 */ |
| | | @Override |
| | | public AiMcpConnectivityTestDto testConnectivity(Long mountId, Long userId, Long tenantId) { |
| | | AiMcpMount mount = requireMount(mountId, tenantId); |
| | | long startedAt = System.currentTimeMillis(); |
| | | try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(List.of(mount), userId)) { |
| | | long elapsedMs = System.currentTimeMillis() - startedAt; |
| | | if (!runtime.getErrors().isEmpty()) { |
| | | String message = String.join(";", runtime.getErrors()); |
| | | updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY, message, elapsedMs); |
| | | AiMcpMount latest = requireMount(mount.getId(), tenantId); |
| | | AiMcpConnectivityTestDto connectivity = buildConnectivityDto(latest, message, elapsedMs, runtime.getToolCallbacks().length); |
| | | aiRedisSupport.cacheConnectivity(tenantId, mountId, connectivity); |
| | | AiMcpMount mount = aiMcpAdminService.requireMount(mountId, tenantId); |
| | | AiMcpConnectivityTestDto connectivity = aiMcpAdminService.testConnectivity(mount, userId, true); |
| | | aiMcpCacheStore.cacheConnectivity(tenantId, mountId, connectivity); |
| | | return connectivity; |
| | | } |
| | | String message = "连通性测试成功,解析出 " + runtime.getToolCallbacks().length + " 个工具"; |
| | | updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_HEALTHY, message, elapsedMs); |
| | | AiMcpMount latest = requireMount(mount.getId(), tenantId); |
| | | AiMcpConnectivityTestDto connectivity = buildConnectivityDto(latest, message, elapsedMs, runtime.getToolCallbacks().length); |
| | | aiRedisSupport.cacheConnectivity(tenantId, mountId, connectivity); |
| | | return connectivity; |
| | | } catch (CoolException e) { |
| | | throw e; |
| | | } catch (Exception e) { |
| | | long elapsedMs = System.currentTimeMillis() - startedAt; |
| | | String message = "连通性测试失败: " + e.getMessage(); |
| | | updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY, message, elapsedMs); |
| | | AiMcpMount latest = requireMount(mount.getId(), tenantId); |
| | | AiMcpConnectivityTestDto connectivity = buildConnectivityDto(latest, message, elapsedMs, 0); |
| | | aiRedisSupport.cacheConnectivity(tenantId, mountId, connectivity); |
| | | return connectivity; |
| | | } |
| | | } |
| | | |
| | | /** 对表单里的草稿配置做临时连通性测试,不落库。 */ |
| | |
| | | mount.setTenantId(tenantId); |
| | | fillDefaults(mount); |
| | | ensureRequiredFields(mount, tenantId); |
| | | long startedAt = System.currentTimeMillis(); |
| | | try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(List.of(mount), userId)) { |
| | | long elapsedMs = System.currentTimeMillis() - startedAt; |
| | | if (!runtime.getErrors().isEmpty()) { |
| | | return AiMcpConnectivityTestDto.builder() |
| | | .mountId(mount.getId()) |
| | | .mountName(mount.getName()) |
| | | .healthStatus(AiDefaults.MCP_HEALTH_UNHEALTHY) |
| | | .message(String.join(";", runtime.getErrors())) |
| | | .initElapsedMs(elapsedMs) |
| | | .toolCount(runtime.getToolCallbacks().length) |
| | | .testedAt(new java.text.SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date())) |
| | | .build(); |
| | | } |
| | | return AiMcpConnectivityTestDto.builder() |
| | | .mountId(mount.getId()) |
| | | .mountName(mount.getName()) |
| | | .healthStatus(AiDefaults.MCP_HEALTH_HEALTHY) |
| | | .message("草稿连通性测试成功,解析出 " + runtime.getToolCallbacks().length + " 个工具") |
| | | .initElapsedMs(elapsedMs) |
| | | .toolCount(runtime.getToolCallbacks().length) |
| | | .testedAt(new java.text.SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date())) |
| | | .build(); |
| | | } catch (CoolException e) { |
| | | throw e; |
| | | } catch (Exception e) { |
| | | throw new CoolException("草稿连通性测试失败: " + e.getMessage()); |
| | | } |
| | | return aiMcpAdminService.testConnectivity(mount, userId, false); |
| | | } |
| | | |
| | | /** |
| | |
| | | */ |
| | | @Override |
| | | public AiMcpToolTestDto testTool(Long mountId, Long userId, Long tenantId, AiMcpToolTestRequest request) { |
| | | if (userId == null) { |
| | | throw new CoolException("当前登录用户不存在"); |
| | | } |
| | | if (tenantId == null) { |
| | | throw new CoolException("当前租户不存在"); |
| | | } |
| | | if (request == null) { |
| | | throw new CoolException("工具测试参数不能为空"); |
| | | } |
| | | if (!StringUtils.hasText(request.getToolName())) { |
| | | throw new CoolException("工具名称不能为空"); |
| | | } |
| | | if (!StringUtils.hasText(request.getInputJson())) { |
| | | throw new CoolException("工具输入 JSON 不能为空"); |
| | | } |
| | | try { |
| | | objectMapper.readTree(request.getInputJson()); |
| | | } catch (Exception e) { |
| | | throw new CoolException("工具输入 JSON 格式错误: " + e.getMessage()); |
| | | } |
| | | AiMcpMount mount = requireMount(mountId, tenantId); |
| | | long startedAt = System.currentTimeMillis(); |
| | | try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(List.of(mount), userId)) { |
| | | ToolCallback callback = Arrays.stream(runtime.getToolCallbacks()) |
| | | .filter(item -> item != null && item.getToolDefinition() != null) |
| | | .filter(item -> request.getToolName().equals(item.getToolDefinition().name())) |
| | | .findFirst() |
| | | .orElseThrow(() -> new CoolException("未找到要测试的工具: " + request.getToolName())); |
| | | String output = callback.call( |
| | | request.getInputJson(), |
| | | new ToolContext(Map.of("userId", userId, "tenantId", tenantId, "mountId", mountId)) |
| | | ); |
| | | updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_HEALTHY, |
| | | "工具测试成功: " + request.getToolName(), System.currentTimeMillis() - startedAt); |
| | | return AiMcpToolTestDto.builder() |
| | | .toolName(request.getToolName()) |
| | | .inputJson(request.getInputJson()) |
| | | .output(output) |
| | | .build(); |
| | | } catch (CoolException e) { |
| | | updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY, |
| | | "工具测试失败: " + e.getMessage(), System.currentTimeMillis() - startedAt); |
| | | throw e; |
| | | } catch (Exception e) { |
| | | updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY, |
| | | "工具测试失败: " + e.getMessage(), System.currentTimeMillis() - startedAt); |
| | | throw new CoolException("工具测试失败: " + e.getMessage()); |
| | | } |
| | | AiMcpMount mount = aiMcpAdminService.requireMount(mountId, tenantId); |
| | | return aiMcpAdminService.testTool(mount, userId, tenantId, request); |
| | | } |
| | | |
| | | @Override |
| | | public boolean save(AiMcpMount entity) { |
| | | boolean saved = super.save(entity); |
| | | if (saved && entity != null && entity.getTenantId() != null) { |
| | | aiRedisSupport.evictMcpMountCaches(entity.getTenantId(), entity.getId()); |
| | | evictMountRelatedCaches(entity.getTenantId(), entity.getId()); |
| | | } |
| | | return saved; |
| | | } |
| | |
| | | public boolean updateById(AiMcpMount entity) { |
| | | boolean updated = super.updateById(entity); |
| | | if (updated && entity != null && entity.getTenantId() != null) { |
| | | aiRedisSupport.evictMcpMountCaches(entity.getTenantId(), entity.getId()); |
| | | evictMountRelatedCaches(entity.getTenantId(), entity.getId()); |
| | | } |
| | | return updated; |
| | | } |
| | |
| | | if (removed) { |
| | | records.stream() |
| | | .filter(java.util.Objects::nonNull) |
| | | .forEach(item -> aiRedisSupport.evictMcpMountCaches(item.getTenantId(), item.getId())); |
| | | .forEach(item -> evictMountRelatedCaches(item.getTenantId(), item.getId())); |
| | | } |
| | | return removed; |
| | | } |
| | |
| | | throw new CoolException("不支持的 MCP 传输类型: " + aiMcpMount.getTransportType()); |
| | | } |
| | | |
| | | /** |
| | |  * Loads the MCP mount with the given ID for the given tenant and fails when it is absent. |
| | |  * Only rows whose deleted flag equals 0 are matched. |
| | |  * |
| | |  * @param mountId  ID of the mount to load; must not be null |
| | |  * @param tenantId tenant the mount must belong to; checked via ensureTenantId before the lookup |
| | |  * @return the matching mount, never null |
| | |  * @throws CoolException when mountId is null or no matching mount exists |
| | |  */ |
| | | private AiMcpMount requireMount(Long mountId, Long tenantId) { |
| | |     ensureTenantId(tenantId); |
| | |     if (mountId == null) { |
| | |         throw new CoolException("MCP 挂载 ID 不能为空"); |
| | |     } |
| | |     LambdaQueryWrapper<AiMcpMount> lookup = new LambdaQueryWrapper<>(); |
| | |     lookup.eq(AiMcpMount::getId, mountId) |
| | |             .eq(AiMcpMount::getTenantId, tenantId) |
| | |             .eq(AiMcpMount::getDeleted, 0) |
| | |             .last("limit 1"); |
| | |     AiMcpMount found = this.getOne(lookup); |
| | |     if (found != null) { |
| | |         return found; |
| | |     } |
| | |     throw new CoolException("MCP 挂载不存在"); |
| | | } |
| | | |
| | | private void ensureBuiltinConflictFree(AiMcpMount aiMcpMount, Long tenantId) { |
| | | /** 校验同租户下是否存在与当前内置编码互斥的启用挂载。 */ |
| | | if (aiMcpMount.getStatus() == null || aiMcpMount.getStatus() != StatusType.ENABLE.val) { |
| | |
| | | } |
| | | } |
| | | |
| | | /** |
| | |  * Combines the raw tool callbacks with the governed catalog into the structured tool cards |
| | |  * consumed by the front end. |
| | |  * |
| | |  * @param callbacks       tool callbacks to describe; null entries and entries without a definition are skipped |
| | |  * @param governedCatalog governance metadata matched to callbacks by tool name; null or unnamed entries are ignored |
| | |  * @return one preview per usable callback, in callback order |
| | |  */ |
| | | private List<AiMcpToolPreviewDto> buildToolPreviewDtos(ToolCallback[] callbacks, List<AiMcpToolPreviewDto> governedCatalog) { |
| | |     // Index the governed catalog by tool name first; a later duplicate name overwrites the earlier one. |
| | |     Map<String, AiMcpToolPreviewDto> governedByName = new java.util.LinkedHashMap<>(); |
| | |     for (AiMcpToolPreviewDto entry : governedCatalog) { |
| | |         if (entry != null && StringUtils.hasText(entry.getName())) { |
| | |             governedByName.put(entry.getName(), entry); |
| | |         } |
| | |     } |
| | |     List<AiMcpToolPreviewDto> previews = new ArrayList<>(); |
| | |     for (ToolCallback callback : callbacks) { |
| | |         var definition = callback == null ? null : callback.getToolDefinition(); |
| | |         if (definition == null) { |
| | |             continue; |
| | |         } |
| | |         var metadata = callback.getToolMetadata(); |
| | |         AiMcpToolPreviewDto governed = governedByName.get(definition.name()); |
| | |         boolean hasGoverned = governed != null; |
| | |         // Governance fields stay null (or an empty example list) when the tool is not in the catalog. |
| | |         previews.add(AiMcpToolPreviewDto.builder() |
| | |                 .name(definition.name()) |
| | |                 .description(definition.description()) |
| | |                 .inputSchema(definition.inputSchema()) |
| | |                 .returnDirect(metadata == null ? null : metadata.returnDirect()) |
| | |                 .toolGroup(hasGoverned ? governed.getToolGroup() : null) |
| | |                 .toolPurpose(hasGoverned ? governed.getToolPurpose() : null) |
| | |                 .queryBoundary(hasGoverned ? governed.getQueryBoundary() : null) |
| | |                 .exampleQuestions(hasGoverned ? governed.getExampleQuestions() : List.of()) |
| | |                 .build()); |
| | |     } |
| | |     return previews; |
| | | } |
| | | |
| | | /** |
| | |  * Writes the outcome of a test run onto the mount row identified by mountId. |
| | |  * The last-test time is set to the moment this method runs. |
| | |  * |
| | |  * @param mountId       ID of the mount row to update |
| | |  * @param healthStatus  health status value to store |
| | |  * @param message       message stored as the last test message |
| | |  * @param initElapsedMs elapsed milliseconds stored as the last init duration |
| | |  */ |
| | | private void updateHealthStatus(Long mountId, String healthStatus, String message, Long initElapsedMs) { |
| | |     LambdaUpdateWrapper<AiMcpMount> healthUpdate = new LambdaUpdateWrapper<>(); |
| | |     healthUpdate.eq(AiMcpMount::getId, mountId) |
| | |             .set(AiMcpMount::getHealthStatus, healthStatus) |
| | |             .set(AiMcpMount::getLastTestTime, new Date()) |
| | |             .set(AiMcpMount::getLastTestMessage, message) |
| | |             .set(AiMcpMount::getLastInitElapsedMs, initElapsedMs); |
| | |     this.update(healthUpdate); |
| | | } |
| | | |
| | | private AiMcpConnectivityTestDto buildConnectivityDto(AiMcpMount mount, String message, Long initElapsedMs, Integer toolCount) { |
| | | return AiMcpConnectivityTestDto.builder() |
| | | .mountId(mount.getId()) |
| | | .mountName(mount.getName()) |
| | | .healthStatus(mount.getHealthStatus()) |
| | | .message(message) |
| | | .initElapsedMs(initElapsedMs) |
| | | .toolCount(toolCount) |
| | | .testedAt(mount.getLastTestTime$()) |
| | | .build(); |
| | | private void evictMountRelatedCaches(Long tenantId, Long mountId) { /* Evicts the caches for one MCP mount, then the tenant's config caches and runtime caches, in that order. */ |
| | | aiMcpCacheStore.evictMcpMountCaches(tenantId, mountId); |
| | | aiConfigCacheStore.evictTenantConfigCaches(tenantId); |
| | | aiConversationCacheStore.evictTenantRuntimeCaches(tenantId); |
| | | } |
| | | } |
| | |
| | | |
| | | import java.util.Locale; |
| | | |
| | | final class AiOpenAiApiSupport { |
| | | public final class AiOpenAiApiSupport { |
| | | |
| | | private static final String DEFAULT_COMPLETIONS_PATH = "/v1/chat/completions"; |
| | | private static final String DEFAULT_EMBEDDINGS_PATH = "/v1/embeddings"; |
| | |
| | | private AiOpenAiApiSupport() { |
| | | } |
| | | |
| | | static OpenAiApi buildOpenAiApi(AiParam aiParam) { |
| | | public static OpenAiApi buildOpenAiApi(AiParam aiParam) { |
| | | int timeoutMs = aiParam.getTimeoutMs() == null ? AiDefaults.DEFAULT_TIMEOUT_MS : aiParam.getTimeoutMs(); |
| | | SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); |
| | | requestFactory.setConnectTimeout(timeoutMs); |
| | |
| | | .build(); |
| | | } |
| | | |
| | | static EndpointConfig resolveEndpointConfig(String rawBaseUrl) { |
| | | public static EndpointConfig resolveEndpointConfig(String rawBaseUrl) { |
| | | String normalizedBaseUrl = trimTrailingSlash(rawBaseUrl); |
| | | String lowerCaseBaseUrl = normalizedBaseUrl.toLowerCase(Locale.ROOT); |
| | | |
| | |
| | | return normalized; |
| | | } |
| | | |
| | | record EndpointConfig(String baseUrl, String completionsPath, String embeddingsPath) { |
| | | public record EndpointConfig(String baseUrl, String completionsPath, String embeddingsPath) { |
| | | } |
| | | } |
| | |
| | | import com.vincent.rsf.server.ai.dto.AiParamValidateResultDto; |
| | | import com.vincent.rsf.server.ai.entity.AiParam; |
| | | import com.vincent.rsf.server.ai.mapper.AiParamMapper; |
| | | import com.vincent.rsf.server.ai.store.AiConfigCacheStore; |
| | | import com.vincent.rsf.server.ai.store.AiConversationCacheStore; |
| | | import com.vincent.rsf.server.ai.service.AiParamService; |
| | | import com.vincent.rsf.server.system.enums.StatusType; |
| | | import lombok.RequiredArgsConstructor; |
| | |
| | | public class AiParamServiceImpl extends ServiceImpl<AiParamMapper, AiParam> implements AiParamService { |
| | | |
| | | private final AiParamValidationSupport aiParamValidationSupport; |
| | | private final AiRedisSupport aiRedisSupport; |
| | | private final AiConfigCacheStore aiConfigCacheStore; |
| | | private final AiConversationCacheStore aiConversationCacheStore; |
| | | |
| | | @Override |
| | | public AiParam getActiveParam(Long tenantId) { |
| | |
| | | if (!super.updateById(target)) { |
| | | throw new CoolException("设置默认 AI 参数失败"); |
| | | } |
| | | aiRedisSupport.evictTenantConfigCaches(tenantId); |
| | | aiConfigCacheStore.evictTenantConfigCaches(tenantId); |
| | | aiConversationCacheStore.evictTenantRuntimeCaches(tenantId); |
| | | return target; |
| | | } |
| | | |
| | |
| | | public boolean save(AiParam entity) { |
| | | boolean saved = super.save(entity); |
| | | if (saved && entity != null && entity.getTenantId() != null) { |
| | | aiRedisSupport.evictTenantConfigCaches(entity.getTenantId()); |
| | | aiConfigCacheStore.evictTenantConfigCaches(entity.getTenantId()); |
| | | aiConversationCacheStore.evictTenantRuntimeCaches(entity.getTenantId()); |
| | | } |
| | | return saved; |
| | | } |
| | |
| | | public boolean updateById(AiParam entity) { |
| | | boolean updated = super.updateById(entity); |
| | | if (updated && entity != null && entity.getTenantId() != null) { |
| | | aiRedisSupport.evictTenantConfigCaches(entity.getTenantId()); |
| | | aiConfigCacheStore.evictTenantConfigCaches(entity.getTenantId()); |
| | | aiConversationCacheStore.evictTenantRuntimeCaches(entity.getTenantId()); |
| | | } |
| | | return updated; |
| | | } |
| | |
| | | .map(AiParam::getTenantId) |
| | | .filter(java.util.Objects::nonNull) |
| | | .distinct() |
| | | .forEach(aiRedisSupport::evictTenantConfigCaches); |
| | | .forEach(tenantId -> { |
| | | aiConfigCacheStore.evictTenantConfigCaches(tenantId); |
| | | aiConversationCacheStore.evictTenantRuntimeCaches(tenantId); |
| | | }); |
| | | } |
| | | return removed; |
| | | } |
| | |
| | | import com.vincent.rsf.server.ai.dto.AiPromptPreviewRequest; |
| | | import com.vincent.rsf.server.ai.entity.AiPrompt; |
| | | import com.vincent.rsf.server.ai.mapper.AiPromptMapper; |
| | | import com.vincent.rsf.server.ai.store.AiConfigCacheStore; |
| | | import com.vincent.rsf.server.ai.store.AiConversationCacheStore; |
| | | import com.vincent.rsf.server.ai.service.AiPromptService; |
| | | import com.vincent.rsf.server.system.enums.StatusType; |
| | | import lombok.RequiredArgsConstructor; |
| | |
| | | public class AiPromptServiceImpl extends ServiceImpl<AiPromptMapper, AiPrompt> implements AiPromptService { |
| | | |
| | | private final AiPromptRenderSupport aiPromptRenderSupport; |
| | | private final AiRedisSupport aiRedisSupport; |
| | | private final AiConfigCacheStore aiConfigCacheStore; |
| | | private final AiConversationCacheStore aiConversationCacheStore; |
| | | |
| | | @Override |
| | | public AiPrompt getActivePrompt(String code, Long tenantId) { |
| | |
| | | public boolean save(AiPrompt entity) { |
| | | boolean saved = super.save(entity); |
| | | if (saved && entity != null && entity.getTenantId() != null) { |
| | | aiRedisSupport.evictTenantConfigCaches(entity.getTenantId()); |
| | | aiConfigCacheStore.evictTenantConfigCaches(entity.getTenantId()); |
| | | aiConversationCacheStore.evictTenantRuntimeCaches(entity.getTenantId()); |
| | | } |
| | | return saved; |
| | | } |
| | |
| | | public boolean updateById(AiPrompt entity) { |
| | | boolean updated = super.updateById(entity); |
| | | if (updated && entity != null && entity.getTenantId() != null) { |
| | | aiRedisSupport.evictTenantConfigCaches(entity.getTenantId()); |
| | | aiConfigCacheStore.evictTenantConfigCaches(entity.getTenantId()); |
| | | aiConversationCacheStore.evictTenantRuntimeCaches(entity.getTenantId()); |
| | | } |
| | | return updated; |
| | | } |
| | |
| | | .map(AiPrompt::getTenantId) |
| | | .filter(java.util.Objects::nonNull) |
| | | .distinct() |
| | | .forEach(aiRedisSupport::evictTenantConfigCaches); |
| | | .forEach(tenantId -> { |
| | | aiConfigCacheStore.evictTenantConfigCaches(tenantId); |
| | | aiConversationCacheStore.evictTenantRuntimeCaches(tenantId); |
| | | }); |
| | | } |
| | | return removed; |
| | | } |
| | |
| | | import com.vincent.rsf.server.ai.config.AiDefaults; |
| | | import com.vincent.rsf.server.ai.dto.AiMcpToolPreviewDto; |
| | | import com.vincent.rsf.server.ai.entity.AiMcpMount; |
| | | import com.vincent.rsf.server.ai.service.impl.mcp.BuiltinMcpToolCatalogProvider; |
| | | import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry; |
| | | import com.vincent.rsf.server.ai.tool.RsfWmsBaseTools; |
| | | import com.vincent.rsf.server.ai.tool.RsfWmsStockTools; |
| | |
| | | private final RsfWmsStockTools rsfWmsStockTools; |
| | | private final RsfWmsTaskTools rsfWmsTaskTools; |
| | | private final RsfWmsBaseTools rsfWmsBaseTools; |
| | | private final BuiltinMcpToolCatalogProvider builtinMcpToolCatalogProvider; |
| | | |
| | | /** |
| | | * 校验内置 MCP 编码是否合法。 |
| | |
| | | if (!StringUtils.hasText(builtinCode)) { |
| | | throw new CoolException("内置 MCP 编码不能为空"); |
| | | } |
| | | if (!supportedBuiltinCodes().contains(builtinCode)) { |
| | | if (!builtinMcpToolCatalogProvider.supportedBuiltinCodes().contains(builtinCode)) { |
| | | throw new CoolException("不支持的内置 MCP 编码: " + builtinCode); |
| | | } |
| | | } |
| | |
| | | @Override |
| | | public List<AiMcpToolPreviewDto> listBuiltinToolCatalog(String builtinCode) { |
| | | validateBuiltinCode(builtinCode); |
| | | if (AiDefaults.MCP_BUILTIN_RSF_WMS.equals(builtinCode)) { |
| | | return new ArrayList<>(catalogByBuiltinCode(builtinCode).values()); |
| | | } |
| | | return new ArrayList<>(catalogByBuiltinCode(builtinCode).values()); |
| | | return new ArrayList<>(builtinMcpToolCatalogProvider.getCatalog(builtinCode).values()); |
| | | } |
| | | |
| | | private List<ToolCallback> createValidatedCallbacks(Object toolBean, String builtinCode) { |
| | |
| | | * 2. 每个工具都必须出现在治理目录里 |
| | | */ |
| | | List<ToolCallback> callbacks = Arrays.asList(ToolCallbacks.from(toolBean)); |
| | | Map<String, AiMcpToolPreviewDto> catalog = catalogByBuiltinCode(builtinCode); |
| | | Map<String, AiMcpToolPreviewDto> catalog = builtinMcpToolCatalogProvider.getCatalog(builtinCode); |
| | | for (ToolCallback callback : callbacks) { |
| | | if (callback == null || callback.getToolDefinition() == null) { |
| | | continue; |
| | |
| | | return callbacks; |
| | | } |
| | | |
| | | private List<String> supportedBuiltinCodes() { |
| | | /** All built-in MCP codes that may be mounted in the current version. */ |
| | | return List.of(AiDefaults.MCP_BUILTIN_RSF_WMS); |
| | | } |
| | | |
| | | /** |
| | |  * Builds the governance catalog for a built-in MCP code. |
| | |  * The catalog is the shared source of truth for runtime validation and for the admin preview, |
| | |  * so it must stay in step with the tool implementations. |
| | |  * |
| | |  * @param builtinCode built-in MCP code to build the catalog for |
| | |  * @return catalog entries keyed by tool name, in insertion order |
| | |  * @throws CoolException when the code is not a supported built-in MCP code |
| | |  */ |
| | | private Map<String, AiMcpToolPreviewDto> catalogByBuiltinCode(String builtinCode) { |
| | |     if (!AiDefaults.MCP_BUILTIN_RSF_WMS.equals(builtinCode)) { |
| | |         throw new CoolException("不支持的内置 MCP 编码: " + builtinCode); |
| | |     } |
| | |     Map<String, AiMcpToolPreviewDto> entries = new LinkedHashMap<>(); |
| | |     entries.put("rsf_query_available_inventory", buildCatalogItem( |
| | |             "rsf_query_available_inventory", |
| | |             "库存查询", |
| | |             "查询指定物料当前可用于出库的库存明细。", |
| | |             "必须提供物料编码或物料名称,并且最多返回 50 条库存记录。", |
| | |             List.of("查询物料 MAT001 当前可出库库存", "按物料名称查询托盘库存明细") |
| | |     )); |
| | |     entries.put("rsf_query_station_list", buildCatalogItem( |
| | |             "rsf_query_station_list", |
| | |             "库存查询", |
| | |             "查询指定作业类型可用的设备站点。", |
| | |             "必须提供站点类型列表,类型数量最多 10 个,最多返回 50 个站点。", |
| | |             List.of("查询入库和出库作业可用站点", "列出 AGV_PICK 类型的作业站点") |
| | |     )); |
| | |     entries.put("rsf_query_task", buildCatalogItem( |
| | |             "rsf_query_task", |
| | |             "任务查询", |
| | |             "按任务号、状态、类型或站点条件查询任务;支持从自然语言中自动提取任务号,精确命中时返回任务明细。", |
| | |             "过滤条件均为可选,不传过滤条件时默认返回最近任务,最多返回 50 条记录。", |
| | |             List.of("查询最近 10 条任务", "查询任务号 TASK24001 的详情") |
| | |     )); |
| | |     entries.put("rsf_query_warehouses", buildCatalogItem( |
| | |             "rsf_query_warehouses", |
| | |             "基础资料", |
| | |             "查询仓库基础信息。", |
| | |             "至少提供仓库编码或名称,最多返回 50 条仓库记录。", |
| | |             List.of("查询编码包含 WH 的仓库", "按仓库名称查询仓库地址") |
| | |     )); |
| | |     entries.put("rsf_query_bas_stations", buildCatalogItem( |
| | |             "rsf_query_bas_stations", |
| | |             "基础资料", |
| | |             "查询基础站点信息。", |
| | |             "至少提供站点编号、站点名称或使用状态之一,最多返回 50 条站点记录。", |
| | |             List.of("查询使用中的基础站点", "按站点编号查询基础站点") |
| | |     )); |
| | |     entries.put("rsf_query_dict_data", buildCatalogItem( |
| | |             "rsf_query_dict_data", |
| | |             "基础资料", |
| | |             "查询指定字典类型下的字典数据。", |
| | |             "必须提供字典类型编码,最多返回 100 条字典记录。", |
| | |             List.of("查询 task_status 字典", "按字典标签过滤 task_type 字典数据") |
| | |     )); |
| | |     return entries; |
| | | } |
| | | |
| | | private AiMcpToolPreviewDto buildCatalogItem(String name, String toolGroup, String toolPurpose, |
| | | String queryBoundary, List<String> exampleQuestions) { |
| | | /** Single factory for tool catalog entries, so that different tool groups do not end up with inconsistent field styles. */ |
| | | return AiMcpToolPreviewDto.builder() |
| | | .name(name) |
| | | .toolGroup(toolGroup) |
| | | .toolPurpose(toolPurpose) |
| | | .queryBoundary(queryBoundary) |
| | | .exampleQuestions(exampleQuestions) |
| | | .build(); |
| | | } |
| | | } |
| | |
| | | package com.vincent.rsf.server.ai.service.impl; |
| | | |
| | | import com.fasterxml.jackson.core.type.TypeReference; |
| | | import com.fasterxml.jackson.databind.ObjectMapper; |
| | | import com.vincent.rsf.framework.exception.CoolException; |
| | | import com.vincent.rsf.server.ai.config.AiDefaults; |
| | | import com.vincent.rsf.server.ai.entity.AiMcpMount; |
| | | import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry; |
| | | import com.vincent.rsf.server.ai.service.MountedToolCallback; |
| | | import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory; |
| | | import io.modelcontextprotocol.client.McpClient; |
| | | import io.modelcontextprotocol.client.McpSyncClient; |
| | | import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; |
| | | import io.modelcontextprotocol.client.transport.ServerParameters; |
| | | import io.modelcontextprotocol.client.transport.StdioClientTransport; |
| | | import io.modelcontextprotocol.json.jackson.JacksonMcpJsonMapper; |
| | | import io.modelcontextprotocol.spec.McpSchema; |
| | | import com.vincent.rsf.server.ai.service.impl.mcp.McpClientFactory; |
| | | import lombok.RequiredArgsConstructor; |
| | | import lombok.extern.slf4j.Slf4j; |
| | | import org.springframework.ai.mcp.SyncMcpToolCallbackProvider; |
| | |
| | | import org.springframework.stereotype.Service; |
| | | import org.springframework.util.StringUtils; |
| | | |
| | | import java.time.Duration; |
| | | import java.util.ArrayList; |
| | | import java.util.Arrays; |
| | | import java.util.Collections; |
| | | import java.util.LinkedHashSet; |
| | | import java.util.LinkedHashMap; |
| | | import java.util.List; |
| | | import java.util.Map; |
| | | |
| | | @Slf4j |
| | | @Service |
| | | @RequiredArgsConstructor |
| | | public class McpMountRuntimeFactoryImpl implements McpMountRuntimeFactory { |
| | | |
| | | private final ObjectMapper objectMapper; |
| | | private final BuiltinMcpToolRegistry builtinMcpToolRegistry; |
| | | private final McpClientFactory mcpClientFactory; |
| | | |
| | | /** |
| | | * 把一组 MCP 挂载记录解析成一次对话可直接使用的运行时对象。 |
| | |
| | | mountedNames.add(mount.getName()); |
| | | continue; |
| | | } |
| | | McpSyncClient client = createClient(mount); |
| | | McpSyncClient client = mcpClientFactory.createClient(mount); |
| | | client.initialize(); |
| | | client.listTools(); |
| | | clients.add(client); |
| | |
| | | } |
| | | if (!duplicateNames.isEmpty()) { |
| | | throw new CoolException("MCP 工具名称重复,请调整挂载配置: " + String.join(", ", duplicateNames)); |
| | | } |
| | | } |
| | | |
| | | /** |
| | |  * Creates an MCP sync client for the transport configured on the mount. |
| | |  * Only the transport layer is set up here; tool de-duplication and error aggregation are not |
| | |  * this method's job. |
| | |  * |
| | |  * @param mount mount record supplying transport type, command/args/env or server URL/endpoint/headers, |
| | |  *              and the request timeout (AiDefaults.DEFAULT_TIMEOUT_MS when unset) |
| | |  * @return a built, not yet initialized, sync client |
| | |  * @throws CoolException when the transport type is unsupported, the SSE server URL is blank, |
| | |  *                       or a JSON config column cannot be parsed |
| | |  */ |
| | | private McpSyncClient createClient(AiMcpMount mount) { |
| | |     long timeoutMs = mount.getRequestTimeoutMs() == null |
| | |             ? AiDefaults.DEFAULT_TIMEOUT_MS |
| | |             : mount.getRequestTimeoutMs(); |
| | |     Duration timeout = Duration.ofMillis(timeoutMs); |
| | |     JacksonMcpJsonMapper jsonMapper = new JacksonMcpJsonMapper(objectMapper); |
| | |     // The same client identity is reported for both transports. |
| | |     McpSchema.Implementation clientInfo = new McpSchema.Implementation("rsf-ai-client", "RSF AI Client", "1.0.0"); |
| | |     if (AiDefaults.MCP_TRANSPORT_STDIO.equals(mount.getTransportType())) { |
| | |         ServerParameters.Builder stdioParameters = ServerParameters.builder(mount.getCommand()); |
| | |         List<String> commandArgs = readStringList(mount.getArgsJson()); |
| | |         if (!commandArgs.isEmpty()) { |
| | |             stdioParameters.args(commandArgs); |
| | |         } |
| | |         Map<String, String> environment = readStringMap(mount.getEnvJson()); |
| | |         if (!environment.isEmpty()) { |
| | |             stdioParameters.env(environment); |
| | |         } |
| | |         StdioClientTransport stdioTransport = new StdioClientTransport(stdioParameters.build(), jsonMapper); |
| | |         stdioTransport.setStdErrorHandler(line -> log.warn("MCP STDIO stderr [{}]: {}", mount.getName(), line)); |
| | |         return McpClient.sync(stdioTransport) |
| | |                 .requestTimeout(timeout) |
| | |                 .initializationTimeout(timeout) |
| | |                 .clientInfo(clientInfo) |
| | |                 .build(); |
| | |     } |
| | |     if (!AiDefaults.MCP_TRANSPORT_SSE_HTTP.equals(mount.getTransportType())) { |
| | |         throw new CoolException("不支持的 MCP 传输类型: " + mount.getTransportType()); |
| | |     } |
| | |     if (!StringUtils.hasText(mount.getServerUrl())) { |
| | |         throw new CoolException("MCP 服务地址不能为空"); |
| | |     } |
| | |     HttpClientSseClientTransport.Builder sseTransport = HttpClientSseClientTransport.builder(mount.getServerUrl()) |
| | |             .jsonMapper(jsonMapper) |
| | |             .connectTimeout(timeout); |
| | |     if (StringUtils.hasText(mount.getEndpoint())) { |
| | |         sseTransport.sseEndpoint(mount.getEndpoint()); |
| | |     } |
| | |     Map<String, String> requestHeaders = readStringMap(mount.getHeadersJson()); |
| | |     if (!requestHeaders.isEmpty()) { |
| | |         sseTransport.customizeRequest(request -> requestHeaders.forEach((name, value) -> request.header(name, value))); |
| | |     } |
| | |     return McpClient.sync(sseTransport.build()) |
| | |             .requestTimeout(timeout) |
| | |             .initializationTimeout(timeout) |
| | |             .clientInfo(clientInfo) |
| | |             .build(); |
| | | } |
| | | |
| | | /** |
| | |  * Parses a JSON array column from the mount table (for example the STDIO args) into a string list. |
| | |  * |
| | |  * @param json raw JSON text; may be null or blank |
| | |  * @return the parsed list, or an empty list when the text is blank or is the JSON literal {@code null}; never null |
| | |  * @throws CoolException when the text is not a valid JSON array of strings |
| | |  */ |
| | | private List<String> readStringList(String json) { |
| | |     if (!StringUtils.hasText(json)) { |
| | |         return Collections.emptyList(); |
| | |     } |
| | |     try { |
| | |         List<String> result = objectMapper.readValue(json, new TypeReference<List<String>>() { |
| | |         }); |
| | |         // A JSON literal null deserializes to Java null; callers call isEmpty() on the result, |
| | |         // so normalise it to an empty list exactly as readStringMap does. |
| | |         return result == null ? Collections.emptyList() : result; |
| | |     } catch (Exception e) { |
| | |         throw new CoolException("解析 MCP 列表配置失败: " + e.getMessage()); |
| | |     } |
| | | } |
| | | |
| | | /** |
| | |  * Parses a JSON object column from the mount table (for example headers or environment variables) |
| | |  * into a string-to-string map. |
| | |  * |
| | |  * @param json raw JSON text; may be null or blank |
| | |  * @return the parsed map in document order, or an empty map when the text is blank or parses to null |
| | |  * @throws CoolException when the text cannot be parsed |
| | |  */ |
| | | private Map<String, String> readStringMap(String json) { |
| | |     if (!StringUtils.hasText(json)) { |
| | |         return Collections.emptyMap(); |
| | |     } |
| | |     Map<String, String> parsed; |
| | |     try { |
| | |         parsed = objectMapper.readValue(json, new TypeReference<LinkedHashMap<String, String>>() { |
| | |         }); |
| | |     } catch (Exception e) { |
| | |         throw new CoolException("解析 MCP Map 配置失败: " + e.getMessage()); |
| | |     } |
| | |     if (parsed == null) { |
| | |         return Collections.emptyMap(); |
| | |     } |
| | |     return parsed; |
| | | } |
| | | |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.service.impl.chat; |
| | | |
| | | import com.vincent.rsf.server.ai.dto.AiChatErrorDto; |
| | | import com.vincent.rsf.server.ai.enums.AiErrorCategory; |
| | | import com.vincent.rsf.server.ai.exception.AiChatException; |
| | | import com.vincent.rsf.server.ai.service.AiCallLogService; |
| | | import com.vincent.rsf.server.ai.store.AiStreamStateStore; |
| | | import lombok.RequiredArgsConstructor; |
| | | import lombok.extern.slf4j.Slf4j; |
| | | import org.springframework.stereotype.Component; |
| | | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; |
| | | |
| | | import java.time.Instant; |
| | | |
| | | @Slf4j |
| | | @Component |
| | | @RequiredArgsConstructor |
| | | public class AiChatFailureHandler { |
| | | |
| | | private final AiSseEventPublisher aiSseEventPublisher; |
| | | private final AiCallLogService aiCallLogService; |
| | | private final AiStreamStateStore aiStreamStateStore; |
| | | |
| | | public AiChatException buildAiException(String code, AiErrorCategory category, String stage, String message, Throwable cause) { |
| | | return new AiChatException(code, category, stage, message, cause); |
| | | } |
| | | |
| | | public void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt, |
| | | Long firstTokenAt, AiChatException exception, Long callLogId, |
| | | long toolSuccessCount, long toolFailureCount, |
| | | AiThinkingTraceEmitter thinkingTraceEmitter, |
| | | Long tenantId, Long userId, String promptCode) { |
| | | if (isClientAbortException(exception)) { |
| | | log.warn("AI chat aborted by client, requestId={}, sessionId={}, stage={}, message={}", |
| | | requestId, sessionId, exception.getStage(), exception.getMessage()); |
| | | if (thinkingTraceEmitter != null) { |
| | | thinkingTraceEmitter.markTerminated("ABORTED"); |
| | | } |
| | | aiSseEventPublisher.emitSafely(emitter, "status", |
| | | aiSseEventPublisher.buildTerminalStatus(requestId, sessionId, "ABORTED", model, startedAt, firstTokenAt)); |
| | | aiCallLogService.failCallLog( |
| | | callLogId, |
| | | "ABORTED", |
| | | exception.getCategory().name(), |
| | | exception.getStage(), |
| | | exception.getMessage(), |
| | | System.currentTimeMillis() - startedAt, |
| | | aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAt), |
| | | toolSuccessCount, |
| | | toolFailureCount |
| | | ); |
| | | aiStreamStateStore.markStreamState(requestId, tenantId, userId, sessionId, promptCode, "ABORTED", exception.getMessage()); |
| | | emitter.completeWithError(exception); |
| | | return; |
| | | } |
| | | log.error("AI chat failed, requestId={}, sessionId={}, category={}, stage={}, message={}", |
| | | requestId, sessionId, exception.getCategory(), exception.getStage(), exception.getMessage(), exception); |
| | | if (thinkingTraceEmitter != null) { |
| | | thinkingTraceEmitter.markTerminated("FAILED"); |
| | | } |
| | | aiSseEventPublisher.emitSafely(emitter, "status", |
| | | aiSseEventPublisher.buildTerminalStatus(requestId, sessionId, "FAILED", model, startedAt, firstTokenAt)); |
| | | aiSseEventPublisher.emitSafely(emitter, "error", AiChatErrorDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .code(exception.getCode()) |
| | | .category(exception.getCategory().name()) |
| | | .stage(exception.getStage()) |
| | | .message(exception.getMessage()) |
| | | .timestamp(Instant.now().toEpochMilli()) |
| | | .build()); |
| | | aiCallLogService.failCallLog( |
| | | callLogId, |
| | | "FAILED", |
| | | exception.getCategory().name(), |
| | | exception.getStage(), |
| | | exception.getMessage(), |
| | | System.currentTimeMillis() - startedAt, |
| | | aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAt), |
| | | toolSuccessCount, |
| | | toolFailureCount |
| | | ); |
| | | aiStreamStateStore.markStreamState(requestId, tenantId, userId, sessionId, promptCode, "FAILED", exception.getMessage()); |
| | | emitter.completeWithError(exception); |
| | | } |
| | | |
| | | private boolean isClientAbortException(Throwable throwable) { |
| | | Throwable current = throwable; |
| | | while (current != null) { |
| | | String message = current.getMessage(); |
| | | if (message != null) { |
| | | String normalized = message.toLowerCase(); |
| | | if (normalized.contains("broken pipe") |
| | | || normalized.contains("connection reset") |
| | | || normalized.contains("forcibly closed") |
| | | || normalized.contains("abort")) { |
| | | return true; |
| | | } |
| | | } |
| | | current = current.getCause(); |
| | | } |
| | | return false; |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.service.impl.chat; |
| | | |
| | | import com.vincent.rsf.server.ai.dto.AiChatMemoryDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatMessageDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatModelOptionDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatRequest; |
| | | import com.vincent.rsf.server.ai.dto.AiChatStatusDto; |
| | | import com.vincent.rsf.server.ai.dto.AiResolvedConfig; |
| | | import com.vincent.rsf.server.ai.entity.AiCallLog; |
| | | import com.vincent.rsf.server.ai.entity.AiChatSession; |
| | | import com.vincent.rsf.server.ai.enums.AiErrorCategory; |
| | | import com.vincent.rsf.server.ai.exception.AiChatException; |
| | | import com.vincent.rsf.server.ai.service.AiCallLogService; |
| | | import com.vincent.rsf.server.ai.service.AiChatMemoryService; |
| | | import com.vincent.rsf.server.ai.service.AiConfigResolverService; |
| | | import com.vincent.rsf.server.ai.service.AiParamService; |
| | | import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory; |
| | | import com.vincent.rsf.server.ai.store.AiChatRateLimiter; |
| | | import com.vincent.rsf.server.ai.store.AiStreamStateStore; |
| | | import lombok.RequiredArgsConstructor; |
| | | import lombok.extern.slf4j.Slf4j; |
| | | import org.springframework.ai.chat.metadata.ChatResponseMetadata; |
| | | import org.springframework.ai.chat.model.ChatResponse; |
| | | import org.springframework.ai.chat.prompt.Prompt; |
| | | import org.springframework.ai.openai.OpenAiChatModel; |
| | | import org.springframework.ai.tool.ToolCallback; |
| | | import org.springframework.stereotype.Component; |
| | | import org.springframework.util.StringUtils; |
| | | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; |
| | | import reactor.core.publisher.Flux; |
| | | |
| | | import java.time.Instant; |
| | | import java.util.List; |
| | | import java.util.concurrent.atomic.AtomicLong; |
| | | import java.util.concurrent.atomic.AtomicReference; |
| | | |
| | | @Slf4j |
| | | @Component |
| | | @RequiredArgsConstructor |
| | | public class AiChatOrchestrator { |
| | | |
| | | private final AiConfigResolverService aiConfigResolverService; |
| | | private final AiChatMemoryService aiChatMemoryService; |
| | | private final AiParamService aiParamService; |
| | | private final McpMountRuntimeFactory mcpMountRuntimeFactory; |
| | | private final AiCallLogService aiCallLogService; |
| | | private final AiChatRateLimiter aiChatRateLimiter; |
| | | private final AiStreamStateStore aiStreamStateStore; |
| | | private final AiChatRuntimeAssembler aiChatRuntimeAssembler; |
| | | private final AiPromptMessageBuilder aiPromptMessageBuilder; |
| | | private final AiOpenAiChatModelFactory aiOpenAiChatModelFactory; |
| | | private final AiToolObservationService aiToolObservationService; |
| | | private final AiSseEventPublisher aiSseEventPublisher; |
| | | private final AiChatFailureHandler aiChatFailureHandler; |
| | | |
| | | public void executeStream(AiChatRequest request, Long userId, Long tenantId, SseEmitter emitter) { |
| | | String requestId = request.getRequestId(); |
| | | long startedAt = System.currentTimeMillis(); |
| | | AtomicReference<Long> firstTokenAtRef = new AtomicReference<>(); |
| | | AtomicLong toolCallSequence = new AtomicLong(0); |
| | | AtomicLong toolSuccessCount = new AtomicLong(0); |
| | | AtomicLong toolFailureCount = new AtomicLong(0); |
| | | Long sessionId = request.getSessionId(); |
| | | Long callLogId = null; |
| | | String model = null; |
| | | String resolvedPromptCode = request.getPromptCode(); |
| | | AiThinkingTraceEmitter thinkingTraceEmitter = null; |
| | | try { |
| | | ensureIdentity(userId, tenantId); |
| | | AiResolvedConfig config = resolveConfig(request, tenantId); |
| | | List<AiChatModelOptionDto> modelOptions = aiParamService.listChatModelOptions(tenantId); |
| | | resolvedPromptCode = config.getPromptCode(); |
| | | if (!aiChatRateLimiter.allowChatRequest(tenantId, userId, config.getPromptCode())) { |
| | | throw aiChatFailureHandler.buildAiException("AI_RATE_LIMITED", AiErrorCategory.REQUEST, "RATE_LIMIT", |
| | | "当前提问过于频繁,请稍后再试", null); |
| | | } |
| | | final String resolvedModel = config.getAiParam().getModel(); |
| | | model = resolvedModel; |
| | | AiChatSession session = resolveSession(request, userId, tenantId, config.getPromptCode()); |
| | | sessionId = session.getId(); |
| | | aiStreamStateStore.markStreamState(requestId, tenantId, userId, sessionId, config.getPromptCode(), "RUNNING", null); |
| | | AiChatMemoryDto memory = loadMemory(userId, tenantId, config.getPromptCode(), session.getId()); |
| | | List<AiChatMessageDto> mergedMessages = aiPromptMessageBuilder.mergeMessages(memory.getShortMemoryMessages(), request.getMessages()); |
| | | AiCallLog callLog = aiCallLogService.startCallLog( |
| | | requestId, |
| | | session.getId(), |
| | | userId, |
| | | tenantId, |
| | | config.getPromptCode(), |
| | | config.getPrompt().getName(), |
| | | config.getAiParam().getModel(), |
| | | config.getMcpMounts().size(), |
| | | config.getMcpMounts().size(), |
| | | config.getMcpMounts().stream().map(item -> item.getName()).toList() |
| | | ); |
| | | callLogId = callLog.getId(); |
| | | try (McpMountRuntimeFactory.McpMountRuntime runtime = createRuntime(config, userId)) { |
| | | aiSseEventPublisher.emitStrict(emitter, "start", aiChatRuntimeAssembler.buildRuntimeSnapshot( |
| | | requestId, |
| | | session.getId(), |
| | | config, |
| | | modelOptions, |
| | | runtime.getMountedCount(), |
| | | runtime.getMountedNames(), |
| | | runtime.getErrors(), |
| | | memory |
| | | )); |
| | | aiSseEventPublisher.emitSafely(emitter, "status", AiChatStatusDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(session.getId()) |
| | | .status("STARTED") |
| | | .model(resolvedModel) |
| | | .timestamp(Instant.now().toEpochMilli()) |
| | | .elapsedMs(0L) |
| | | .build()); |
| | | log.info("AI chat started, requestId={}, userId={}, tenantId={}, sessionId={}, model={}", |
| | | requestId, userId, tenantId, session.getId(), resolvedModel); |
| | | thinkingTraceEmitter = new AiThinkingTraceEmitter(aiSseEventPublisher, emitter, requestId, session.getId()); |
| | | thinkingTraceEmitter.startAnalyze(); |
| | | AiThinkingTraceEmitter activeThinkingTraceEmitter = thinkingTraceEmitter; |
| | | |
| | | ToolCallback[] observableToolCallbacks = aiToolObservationService.wrapToolCallbacks( |
| | | runtime.getToolCallbacks(), emitter, requestId, session.getId(), toolCallSequence, |
| | | toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, activeThinkingTraceEmitter |
| | | ); |
| | | Prompt prompt = new Prompt( |
| | | aiPromptMessageBuilder.buildPromptMessages(memory, mergedMessages, config.getPrompt(), request.getMetadata()), |
| | | aiOpenAiChatModelFactory.buildChatOptions(config.getAiParam(), observableToolCallbacks, userId, tenantId, |
| | | requestId, session.getId(), request.getMetadata()) |
| | | ); |
| | | OpenAiChatModel chatModel = aiOpenAiChatModelFactory.createChatModel(config.getAiParam()); |
| | | if (Boolean.FALSE.equals(config.getAiParam().getStreamingEnabled())) { |
| | | ChatResponse response = invokeChatCall(chatModel, prompt); |
| | | String content = extractContent(response); |
| | | aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), content); |
| | | if (StringUtils.hasText(content)) { |
| | | aiSseEventPublisher.markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt, activeThinkingTraceEmitter); |
| | | aiSseEventPublisher.emitStrict(emitter, "delta", aiSseEventPublisher.buildMessagePayload("requestId", requestId, "content", content)); |
| | | } |
| | | activeThinkingTraceEmitter.completeCurrentPhase(); |
| | | aiSseEventPublisher.emitDone(emitter, requestId, response.getMetadata(), config.getAiParam().getModel(), |
| | | session.getId(), startedAt, firstTokenAtRef.get()); |
| | | aiSseEventPublisher.emitSafely(emitter, "status", |
| | | aiSseEventPublisher.buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get())); |
| | | aiCallLogService.completeCallLog( |
| | | callLogId, |
| | | "COMPLETED", |
| | | System.currentTimeMillis() - startedAt, |
| | | aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()), |
| | | response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getPromptTokens(), |
| | | response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getCompletionTokens(), |
| | | response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getTotalTokens(), |
| | | toolSuccessCount.get(), |
| | | toolFailureCount.get() |
| | | ); |
| | | aiStreamStateStore.markStreamState(requestId, tenantId, userId, session.getId(), config.getPromptCode(), "COMPLETED", null); |
| | | log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}", |
| | | requestId, session.getId(), System.currentTimeMillis() - startedAt, |
| | | aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAtRef.get())); |
| | | emitter.complete(); |
| | | return; |
| | | } |
| | | |
| | | Flux<ChatResponse> responseFlux = invokeChatStream(chatModel, prompt); |
| | | AtomicReference<ChatResponseMetadata> lastMetadata = new AtomicReference<>(); |
| | | StringBuilder assistantContent = new StringBuilder(); |
| | | try { |
| | | responseFlux.doOnNext(response -> { |
| | | lastMetadata.set(response.getMetadata()); |
| | | String content = extractContent(response); |
| | | if (StringUtils.hasText(content)) { |
| | | aiSseEventPublisher.markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt, activeThinkingTraceEmitter); |
| | | assistantContent.append(content); |
| | | aiSseEventPublisher.emitStrict(emitter, "delta", |
| | | aiSseEventPublisher.buildMessagePayload("requestId", requestId, "content", content)); |
| | | } |
| | | }) |
| | | .blockLast(); |
| | | } catch (Exception e) { |
| | | throw aiChatFailureHandler.buildAiException("AI_MODEL_STREAM_ERROR", AiErrorCategory.MODEL, "MODEL_STREAM", |
| | | e == null ? "AI 模型流式调用失败" : e.getMessage(), e); |
| | | } |
| | | aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), assistantContent.toString()); |
| | | activeThinkingTraceEmitter.completeCurrentPhase(); |
| | | aiSseEventPublisher.emitDone(emitter, requestId, lastMetadata.get(), config.getAiParam().getModel(), |
| | | session.getId(), startedAt, firstTokenAtRef.get()); |
| | | aiSseEventPublisher.emitSafely(emitter, "status", |
| | | aiSseEventPublisher.buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get())); |
| | | aiCallLogService.completeCallLog( |
| | | callLogId, |
| | | "COMPLETED", |
| | | System.currentTimeMillis() - startedAt, |
| | | aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()), |
| | | lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getPromptTokens(), |
| | | lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getCompletionTokens(), |
| | | lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getTotalTokens(), |
| | | toolSuccessCount.get(), |
| | | toolFailureCount.get() |
| | | ); |
| | | aiStreamStateStore.markStreamState(requestId, tenantId, userId, session.getId(), config.getPromptCode(), "COMPLETED", null); |
| | | log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}", |
| | | requestId, session.getId(), System.currentTimeMillis() - startedAt, |
| | | aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAtRef.get())); |
| | | emitter.complete(); |
| | | } |
| | | } catch (AiChatException e) { |
| | | aiChatFailureHandler.handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e, |
| | | callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter, |
| | | tenantId, userId, resolvedPromptCode); |
| | | } catch (Exception e) { |
| | | aiChatFailureHandler.handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), |
| | | aiChatFailureHandler.buildAiException("AI_INTERNAL_ERROR", AiErrorCategory.INTERNAL, "INTERNAL", |
| | | e == null ? "AI 对话失败" : e.getMessage(), e), |
| | | callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter, |
| | | tenantId, userId, resolvedPromptCode); |
| | | } finally { |
| | | log.debug("AI chat stream finished, requestId={}", requestId); |
| | | } |
| | | } |
| | | |
| | | private void ensureIdentity(Long userId, Long tenantId) { |
| | | if (userId == null) { |
| | | throw aiChatFailureHandler.buildAiException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "AUTH_VALIDATE", "当前登录用户不存在", null); |
| | | } |
| | | if (tenantId == null) { |
| | | throw aiChatFailureHandler.buildAiException("AI_AUTH_TENANT_MISSING", AiErrorCategory.AUTH, "AUTH_VALIDATE", "当前租户不存在", null); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * Resolves the AI configuration for a request by delegating to |
| | | * aiConfigResolverService.resolve(promptCode, tenantId, aiParamId). |
| | | * Any exception raised while doing so is converted into the exception built by |
| | | * aiChatFailureHandler with code AI_CONFIG_RESOLVE_ERROR, keeping the original as cause. |
| | | * |
| | | * @param request chat request supplying the prompt code and AI parameter id |
| | | * @param tenantId tenant the configuration is resolved for |
| | | * @return the resolved configuration |
| | | */ |
| | | private AiResolvedConfig resolveConfig(AiChatRequest request, Long tenantId) { |
| | | try { |
| | | return aiConfigResolverService.resolve(request.getPromptCode(), tenantId, request.getAiParamId()); |
| | | } catch (Exception e) { |
| | | // A caught exception is never null, but its message may be; use the default text in that case. |
| | | String message = e.getMessage(); |
| | | throw aiChatFailureHandler.buildAiException("AI_CONFIG_RESOLVE_ERROR", AiErrorCategory.CONFIG, "CONFIG_RESOLVE", |
| | | message == null || message.isBlank() ? "AI 配置解析失败" : message, e); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * Resolves the chat session for a request through aiChatMemoryService.resolveSession, |
| | | * passing the title seed derived from the request messages. |
| | | * Any exception (including one raised while deriving the title seed) is converted into the |
| | | * exception built by aiChatFailureHandler with code AI_SESSION_RESOLVE_ERROR, keeping the original as cause. |
| | | * |
| | | * @param request chat request supplying the session id and messages |
| | | * @param userId id of the current user |
| | | * @param tenantId id of the current tenant |
| | | * @param promptCode prompt code the session belongs to |
| | | * @return the resolved session |
| | | */ |
| | | private AiChatSession resolveSession(AiChatRequest request, Long userId, Long tenantId, String promptCode) { |
| | | try { |
| | | return aiChatMemoryService.resolveSession(userId, tenantId, promptCode, request.getSessionId(), |
| | | aiPromptMessageBuilder.resolveTitleSeed(request.getMessages())); |
| | | } catch (Exception e) { |
| | | // A caught exception is never null, but its message may be; use the default text in that case. |
| | | String message = e.getMessage(); |
| | | throw aiChatFailureHandler.buildAiException("AI_SESSION_RESOLVE_ERROR", AiErrorCategory.REQUEST, "SESSION_RESOLVE", |
| | | message == null || message.isBlank() ? "AI 会话解析失败" : message, e); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * Loads the conversation memory of a session through aiChatMemoryService.getMemory. |
| | | * Any exception is converted into the exception built by aiChatFailureHandler with code |
| | | * AI_MEMORY_LOAD_ERROR, keeping the original as cause. |
| | | * |
| | | * @param userId id of the current user |
| | | * @param tenantId id of the current tenant |
| | | * @param promptCode prompt code of the session |
| | | * @param sessionId id of the session whose memory is loaded |
| | | * @return the loaded memory |
| | | */ |
| | | private AiChatMemoryDto loadMemory(Long userId, Long tenantId, String promptCode, Long sessionId) { |
| | | try { |
| | | return aiChatMemoryService.getMemory(userId, tenantId, promptCode, sessionId); |
| | | } catch (Exception e) { |
| | | // A caught exception is never null, but its message may be; use the default text in that case. |
| | | String message = e.getMessage(); |
| | | throw aiChatFailureHandler.buildAiException("AI_MEMORY_LOAD_ERROR", AiErrorCategory.REQUEST, "MEMORY_LOAD", |
| | | message == null || message.isBlank() ? "AI 会话记忆加载失败" : message, e); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * Creates the MCP mount runtime for the configured mounts through mcpMountRuntimeFactory.create. |
| | | * Any exception is converted into the exception built by aiChatFailureHandler with code |
| | | * AI_MCP_MOUNT_ERROR, keeping the original as cause. |
| | | * |
| | | * @param config resolved configuration supplying the MCP mounts |
| | | * @param userId id of the current user |
| | | * @return the created runtime; the visible caller closes it with try-with-resources |
| | | */ |
| | | private McpMountRuntimeFactory.McpMountRuntime createRuntime(AiResolvedConfig config, Long userId) { |
| | | try { |
| | | return mcpMountRuntimeFactory.create(config.getMcpMounts(), userId); |
| | | } catch (Exception e) { |
| | | // A caught exception is never null, but its message may be; use the default text in that case. |
| | | String message = e.getMessage(); |
| | | throw aiChatFailureHandler.buildAiException("AI_MCP_MOUNT_ERROR", AiErrorCategory.MCP, "MCP_MOUNT", |
| | | message == null || message.isBlank() ? "MCP 挂载失败" : message, e); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * Performs a blocking (non-streaming) model call via chatModel.call(prompt). |
| | | * Any exception is converted into the exception built by aiChatFailureHandler with code |
| | | * AI_MODEL_CALL_ERROR, keeping the original as cause. |
| | | * |
| | | * @param chatModel model to invoke |
| | | * @param prompt prompt to send |
| | | * @return the model response |
| | | */ |
| | | private ChatResponse invokeChatCall(OpenAiChatModel chatModel, Prompt prompt) { |
| | | try { |
| | | return chatModel.call(prompt); |
| | | } catch (Exception e) { |
| | | // A caught exception is never null, but its message may be; use the default text in that case. |
| | | String message = e.getMessage(); |
| | | throw aiChatFailureHandler.buildAiException("AI_MODEL_CALL_ERROR", AiErrorCategory.MODEL, "MODEL_CALL", |
| | | message == null || message.isBlank() ? "AI 模型调用失败" : message, e); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * Obtains the streaming response publisher via chatModel.stream(prompt). |
| | | * Only exceptions thrown by the stream(...) call itself are converted here (code AI_MODEL_STREAM_ERROR, |
| | | * third argument "MODEL_STREAM_INIT"); this method does not consume the returned Flux, so errors |
| | | * signalled while it is consumed are left to the caller. |
| | | * |
| | | * @param chatModel model to invoke |
| | | * @param prompt prompt to send |
| | | * @return the response Flux returned by the model |
| | | */ |
| | | private Flux<ChatResponse> invokeChatStream(OpenAiChatModel chatModel, Prompt prompt) { |
| | | try { |
| | | return chatModel.stream(prompt); |
| | | } catch (Exception e) { |
| | | // A caught exception is never null, but its message may be; use the default text in that case. |
| | | String message = e.getMessage(); |
| | | throw aiChatFailureHandler.buildAiException("AI_MODEL_STREAM_ERROR", AiErrorCategory.MODEL, "MODEL_STREAM_INIT", |
| | | message == null || message.isBlank() ? "AI 模型流式调用失败" : message, e); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * Returns the text of the response's result output, or null when the response, |
| | | * its result, or the result's output is null. |
| | | */ |
| | | private String extractContent(ChatResponse response) { |
| | | if (response != null && response.getResult() != null && response.getResult().getOutput() != null) { |
| | | return response.getResult().getOutput().getText(); |
| | | } |
| | | return null; |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.service.impl.chat; |
| | | |
| | | import com.vincent.rsf.server.ai.dto.AiChatMemoryDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatModelOptionDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatRuntimeDto; |
| | | import com.vincent.rsf.server.ai.dto.AiResolvedConfig; |
| | | import org.springframework.stereotype.Component; |
| | | |
| | | import java.util.List; |
| | | |
| | | /** |
| | | * Assembles the runtime snapshot DTO describing one chat request. |
| | | */ |
| | | @Component |
| | | public class AiChatRuntimeAssembler { |
| | | |
| | | /** |
| | | * Builds an AiChatRuntimeDto from the request identifiers, the resolved configuration, |
| | | * the MCP mount outcome and the loaded memory. |
| | | * |
| | | * @param requestId id of the current request |
| | | * @param sessionId id of the chat session |
| | | * @param config resolved configuration; its AI parameter, prompt and MCP mounts are read |
| | | * @param modelOptions model options copied into the snapshot as-is |
| | | * @param mountedMcpCount number of mounted MCP entries as reported by the caller |
| | | * @param mountedMcpNames names of the mounted MCP entries as reported by the caller |
| | | * @param mountErrors mount errors as reported by the caller |
| | | * @param memory loaded memory; summary, facts, recent message count and persisted messages are read |
| | | * @return the assembled snapshot |
| | | */ |
| | | public AiChatRuntimeDto buildRuntimeSnapshot(String requestId, Long sessionId, AiResolvedConfig config, |
| | | List<AiChatModelOptionDto> modelOptions, Integer mountedMcpCount, |
| | | List<String> mountedMcpNames, List<String> mountErrors, |
| | | AiChatMemoryDto memory) { |
| | | var aiParam = config.getAiParam(); |
| | | return AiChatRuntimeDto.builder() |
| | | // request identity |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | // values taken from the resolved configuration |
| | | .aiParamId(aiParam.getId()) |
| | | .promptCode(config.getPromptCode()) |
| | | .promptName(config.getPrompt().getName()) |
| | | .model(aiParam.getModel()) |
| | | .modelOptions(modelOptions) |
| | | // MCP mount outcome |
| | | .configuredMcpCount(config.getMcpMounts().size()) |
| | | .mountedMcpCount(mountedMcpCount) |
| | | .mountedMcpNames(mountedMcpNames) |
| | | .mountErrors(mountErrors) |
| | | // values taken from the loaded memory |
| | | .memorySummary(memory.getMemorySummary()) |
| | | .memoryFacts(memory.getMemoryFacts()) |
| | | .recentMessageCount(memory.getRecentMessageCount()) |
| | | .persistedMessages(memory.getPersistedMessages()) |
| | | .build(); |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.service.impl.chat; |
| | | |
| | | import com.vincent.rsf.framework.exception.CoolException; |
| | | import com.vincent.rsf.server.ai.entity.AiParam; |
| | | import com.vincent.rsf.server.ai.enums.AiErrorCategory; |
| | | import com.vincent.rsf.server.ai.exception.AiChatException; |
| | | import com.vincent.rsf.server.ai.service.impl.AiOpenAiApiSupport; |
| | | import io.micrometer.observation.ObservationRegistry; |
| | | import lombok.RequiredArgsConstructor; |
| | | import org.springframework.ai.model.tool.DefaultToolCallingManager; |
| | | import org.springframework.ai.model.tool.ToolCallingManager; |
| | | import org.springframework.ai.openai.OpenAiChatModel; |
| | | import org.springframework.ai.openai.OpenAiChatOptions; |
| | | import org.springframework.ai.openai.api.OpenAiApi; |
| | | import org.springframework.ai.tool.ToolCallback; |
| | | import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; |
| | | import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver; |
| | | import org.springframework.ai.util.json.schema.SchemaType; |
| | | import org.springframework.context.support.GenericApplicationContext; |
| | | import org.springframework.stereotype.Component; |
| | | |
| | | import java.util.Arrays; |
| | | import java.util.LinkedHashMap; |
| | | import java.util.Map; |
| | | |
| | | /** |
| | | * Factory for the OpenAI chat model and for the per-request chat options used by the AI chat flow. |
| | | */ |
| | | @Component |
| | | @RequiredArgsConstructor |
| | | public class AiOpenAiChatModelFactory { |
| | | |
| | | private final GenericApplicationContext applicationContext; |
| | | private final ObservationRegistry observationRegistry; |
| | | |
| | | /** |
| | | * Creates a new OpenAiChatModel from the given parameter record. |
| | | * The API client comes from AiOpenAiApiSupport.buildOpenAiApi; default options carry the model, |
| | | * temperature, topP and maxTokens of the record with streamUsage enabled. |
| | | * |
| | | * @param aiParam parameter record supplying connection and sampling settings |
| | | * @return a newly constructed chat model |
| | | */ |
| | | public OpenAiChatModel createChatModel(AiParam aiParam) { |
| | | OpenAiApi openAiApi = AiOpenAiApiSupport.buildOpenAiApi(aiParam); |
| | | ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder() |
| | | .observationRegistry(observationRegistry) |
| | | .toolCallbackResolver(new SpringBeanToolCallbackResolver(applicationContext, SchemaType.OPEN_API_SCHEMA)) |
| | | // NOTE(review): the boolean is presumably "alwaysThrow"; false would hand tool errors back to the model - confirm. |
| | | .toolExecutionExceptionProcessor(new DefaultToolExecutionExceptionProcessor(false)) |
| | | .build(); |
| | | return new OpenAiChatModel( |
| | | openAiApi, |
| | | OpenAiChatOptions.builder() |
| | | .model(aiParam.getModel()) |
| | | .temperature(aiParam.getTemperature()) |
| | | .topP(aiParam.getTopP()) |
| | | .maxTokens(aiParam.getMaxTokens()) |
| | | .streamUsage(true) |
| | | .build(), |
| | | toolCallingManager, |
| | | // The retry template is limited to a single attempt. |
| | | org.springframework.retry.support.RetryTemplate.builder().maxAttempts(1).build(), |
| | | observationRegistry |
| | | ); |
| | | } |
| | | |
| | | /** |
| | | * Builds the chat options for one request. |
| | | * The tool context always contains userId, tenantId, requestId and sessionId as passed in. |
| | | * Metadata entries are stringified (null becomes "") and added both to the options metadata and |
| | | * to the tool context, but a metadata key never replaces one of the identity entries. |
| | | * |
| | | * @param aiParam parameter record supplying model and sampling settings |
| | | * @param toolCallbacks tool callbacks to register; null or empty registers none |
| | | * @param userId id of the current user, must not be null |
| | | * @param tenantId id of the current tenant |
| | | * @param requestId id of the current request |
| | | * @param sessionId id of the chat session |
| | | * @param metadata optional caller-supplied metadata, may be null |
| | | * @return the assembled options |
| | | * @throws AiChatException with code AI_AUTH_USER_MISSING when userId is null |
| | | */ |
| | | public OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Long tenantId, |
| | | String requestId, Long sessionId, Map<String, Object> metadata) { |
| | | if (userId == null) { |
| | | throw new AiChatException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "OPTIONS_BUILD", "当前登录用户不存在", null); |
| | | } |
| | | OpenAiChatOptions.Builder builder = OpenAiChatOptions.builder() |
| | | .model(aiParam.getModel()) |
| | | .temperature(aiParam.getTemperature()) |
| | | .topP(aiParam.getTopP()) |
| | | .maxTokens(aiParam.getMaxTokens()) |
| | | .streamUsage(true) |
| | | .user(String.valueOf(userId)); |
| | | if (toolCallbacks != null && toolCallbacks.length > 0) { |
| | | builder.toolCallbacks(Arrays.stream(toolCallbacks).toList()); |
| | | } |
| | | Map<String, Object> toolContext = new LinkedHashMap<>(); |
| | | toolContext.put("userId", userId); |
| | | toolContext.put("tenantId", tenantId); |
| | | toolContext.put("requestId", requestId); |
| | | toolContext.put("sessionId", sessionId); |
| | | Map<String, String> metadataMap = new LinkedHashMap<>(); |
| | | if (metadata != null) { |
| | | metadata.forEach((key, value) -> { |
| | | String normalized = value == null ? "" : String.valueOf(value); |
| | | metadataMap.put(key, normalized); |
| | | // Caller-supplied metadata must not overwrite the identity entries put above. |
| | | // containsKey (not putIfAbsent) is used so that an identity entry holding null stays protected too. |
| | | if (!toolContext.containsKey(key)) { |
| | | toolContext.put(key, normalized); |
| | | } |
| | | }); |
| | | } |
| | | builder.toolContext(toolContext); |
| | | if (!metadataMap.isEmpty()) { |
| | | builder.metadata(metadataMap); |
| | | } |
| | | return builder.build(); |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.service.impl.chat; |
| | | |
| | | import com.vincent.rsf.framework.common.Cools; |
| | | import com.vincent.rsf.framework.exception.CoolException; |
| | | import com.vincent.rsf.server.ai.dto.AiChatMemoryDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatMessageDto; |
| | | import com.vincent.rsf.server.ai.entity.AiPrompt; |
| | | import org.springframework.ai.chat.messages.AssistantMessage; |
| | | import org.springframework.ai.chat.messages.Message; |
| | | import org.springframework.ai.chat.messages.SystemMessage; |
| | | import org.springframework.ai.chat.messages.UserMessage; |
| | | import org.springframework.stereotype.Component; |
| | | import org.springframework.util.StringUtils; |
| | | |
| | | import java.util.ArrayList; |
| | | import java.util.List; |
| | | import java.util.Map; |
| | | import java.util.Objects; |
| | | |
| | | /** |
| | | * Helpers that turn chat DTOs into prompt messages, merge message lists and pick a title seed. |
| | | */ |
| | | @Component |
| | | public class AiPromptMessageBuilder { |
| | | |
| | | /** |
| | | * Builds the ordered prompt message list. |
| | | * System messages come first: the prompt's system prompt, then the memory summary, then the memory |
| | | * facts (each only when it has text). Source messages follow in order; entries that are null, have |
| | | * no text, or carry the role "system" are skipped. The user prompt template is applied only to the |
| | | * last source message whose role is "user". Roles other than "assistant" become user messages. |
| | | * |
| | | * @param memory loaded memory, may be null |
| | | * @param sourceMessages messages to convert, must not be empty |
| | | * @param aiPrompt prompt definition supplying the system prompt and the user prompt template |
| | | * @param metadata optional values for template placeholders, may be null |
| | | * @return the prompt messages |
| | | * @throws CoolException when sourceMessages is empty or no user message results |
| | | */ |
| | | public List<Message> buildPromptMessages(AiChatMemoryDto memory, List<AiChatMessageDto> sourceMessages, AiPrompt aiPrompt, |
| | | Map<String, Object> metadata) { |
| | | if (Cools.isEmpty(sourceMessages)) { |
| | | throw new CoolException("对话消息不能为空"); |
| | | } |
| | | List<Message> messages = new ArrayList<>(); |
| | | if (StringUtils.hasText(aiPrompt.getSystemPrompt())) { |
| | | messages.add(new SystemMessage(aiPrompt.getSystemPrompt())); |
| | | } |
| | | if (memory != null && StringUtils.hasText(memory.getMemorySummary())) { |
| | | messages.add(new SystemMessage("历史摘要:\n" + memory.getMemorySummary())); |
| | | } |
| | | if (memory != null && StringUtils.hasText(memory.getMemoryFacts())) { |
| | | messages.add(new SystemMessage("关键事实:\n" + memory.getMemoryFacts())); |
| | | } |
| | | // Locate the last user-role entry; only that one is rendered through the template. |
| | | int lastUserIndex = -1; |
| | | for (int i = 0; i < sourceMessages.size(); i++) { |
| | | AiChatMessageDto item = sourceMessages.get(i); |
| | | if (item != null && "user".equalsIgnoreCase(item.getRole())) { |
| | | lastUserIndex = i; |
| | | } |
| | | } |
| | | for (int i = 0; i < sourceMessages.size(); i++) { |
| | | AiChatMessageDto item = sourceMessages.get(i); |
| | | if (item == null || !StringUtils.hasText(item.getContent())) { |
| | | continue; |
| | | } |
| | | // Locale.ROOT keeps role matching independent of the JVM default locale (under a Turkish |
| | | // locale "SYSTEM" and "ASSISTANT" would not lower-case to "system"/"assistant") and agrees |
| | | // with the equalsIgnoreCase check used when locating the last user message. |
| | | String role = item.getRole() == null ? "user" : item.getRole().toLowerCase(java.util.Locale.ROOT); |
| | | if ("system".equals(role)) { |
| | | continue; |
| | | } |
| | | String content = item.getContent(); |
| | | if ("user".equals(role) && i == lastUserIndex) { |
| | | content = renderUserPrompt(aiPrompt.getUserPromptTemplate(), content, metadata); |
| | | } |
| | | if ("assistant".equals(role)) { |
| | | messages.add(new AssistantMessage(content)); |
| | | } else { |
| | | messages.add(new UserMessage(content)); |
| | | } |
| | | } |
| | | if (messages.stream().noneMatch(item -> item instanceof UserMessage)) { |
| | | throw new CoolException("至少需要一条用户消息"); |
| | | } |
| | | return messages; |
| | | } |
| | | |
| | | /** |
| | | * Concatenates two message lists into a new list: all of persistedMessages first, then all of |
| | | * memoryMessages. Empty or null lists contribute nothing. |
| | | * NOTE(review): the parameter names may not reflect what callers pass (a visible caller passes the |
| | | * short-memory messages first and the request messages second) - confirm and consider renaming. |
| | | * |
| | | * @param persistedMessages messages placed first |
| | | * @param memoryMessages messages placed after them |
| | | * @return the merged list |
| | | * @throws CoolException when the merged list is empty |
| | | */ |
| | | public List<AiChatMessageDto> mergeMessages(List<AiChatMessageDto> persistedMessages, List<AiChatMessageDto> memoryMessages) { |
| | | List<AiChatMessageDto> merged = new ArrayList<>(); |
| | | if (!Cools.isEmpty(persistedMessages)) { |
| | | merged.addAll(persistedMessages); |
| | | } |
| | | if (!Cools.isEmpty(memoryMessages)) { |
| | | merged.addAll(memoryMessages); |
| | | } |
| | | if (merged.isEmpty()) { |
| | | throw new CoolException("对话消息不能为空"); |
| | | } |
| | | return merged; |
| | | } |
| | | |
| | | /** |
| | | * Returns the content of the last message whose role is "user" (case-insensitive) and that has text. |
| | | * |
| | | * @param messages messages to scan from the end |
| | | * @return the content of the matching message |
| | | * @throws CoolException when messages is empty or contains no such message |
| | | */ |
| | | public String resolveTitleSeed(List<AiChatMessageDto> messages) { |
| | | if (Cools.isEmpty(messages)) { |
| | | throw new CoolException("对话消息不能为空"); |
| | | } |
| | | for (int i = messages.size() - 1; i >= 0; i--) { |
| | | AiChatMessageDto item = messages.get(i); |
| | | if (item != null && "user".equalsIgnoreCase(item.getRole()) && StringUtils.hasText(item.getContent())) { |
| | | return item.getContent(); |
| | | } |
| | | } |
| | | throw new CoolException("至少需要一条用户消息"); |
| | | } |
| | | |
| | | /** |
| | | * Renders the user prompt template for one message. |
| | | * Without a template the content is returned unchanged. Otherwise "{{input}}" and "{input}" are |
| | | * replaced by the content, then "{{key}}" and "{key}" are replaced for every metadata entry |
| | | * (null values become ""). If nothing changed, the content is appended to the template after a blank line. |
| | | * NOTE(review): substitution is sequential, so placeholder-like text inside the content or inside an |
| | | * earlier metadata value is itself subject to the later replacements. |
| | | */ |
| | | private String renderUserPrompt(String userPromptTemplate, String content, Map<String, Object> metadata) { |
| | | if (!StringUtils.hasText(userPromptTemplate)) { |
| | | return content; |
| | | } |
| | | String rendered = userPromptTemplate |
| | | .replace("{{input}}", content) |
| | | .replace("{input}", content); |
| | | if (metadata != null) { |
| | | for (Map.Entry<String, Object> entry : metadata.entrySet()) { |
| | | String value = entry.getValue() == null ? "" : String.valueOf(entry.getValue()); |
| | | rendered = rendered.replace("{{" + entry.getKey() + "}}", value); |
| | | rendered = rendered.replace("{" + entry.getKey() + "}", value); |
| | | } |
| | | } |
| | | if (Objects.equals(rendered, userPromptTemplate)) { |
| | | return userPromptTemplate + "\n\n" + content; |
| | | } |
| | | return rendered; |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.service.impl.chat; |
| | | |
| | | import com.fasterxml.jackson.databind.ObjectMapper; |
| | | import com.vincent.rsf.framework.exception.CoolException; |
| | | import com.vincent.rsf.server.ai.dto.AiChatDoneDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatStatusDto; |
| | | import com.vincent.rsf.server.ai.enums.AiErrorCategory; |
| | | import com.vincent.rsf.server.ai.exception.AiChatException; |
| | | import lombok.RequiredArgsConstructor; |
| | | import lombok.extern.slf4j.Slf4j; |
| | | import org.springframework.ai.chat.metadata.ChatResponseMetadata; |
| | | import org.springframework.ai.chat.metadata.Usage; |
| | | import org.springframework.http.MediaType; |
| | | import org.springframework.stereotype.Component; |
| | | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; |
| | | |
| | | import java.io.IOException; |
| | | import java.time.Instant; |
| | | import java.util.LinkedHashMap; |
| | | import java.util.Map; |
| | | import java.util.concurrent.atomic.AtomicReference; |
| | | |
| | | @Slf4j |
| | | @Component |
| | | @RequiredArgsConstructor |
| | | public class AiSseEventPublisher { |
| | | |
| | | private final ObjectMapper objectMapper; |
| | | |
| | | public void markFirstToken(AtomicReference<Long> firstTokenAtRef, SseEmitter emitter, String requestId, |
| | | Long sessionId, String model, long startedAt, AiThinkingTraceEmitter thinkingTraceEmitter) { |
| | | if (!firstTokenAtRef.compareAndSet(null, System.currentTimeMillis())) { |
| | | return; |
| | | } |
| | | if (thinkingTraceEmitter != null) { |
| | | thinkingTraceEmitter.startAnswer(); |
| | | } |
| | | emitSafely(emitter, "status", AiChatStatusDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .status("FIRST_TOKEN") |
| | | .model(model) |
| | | .timestamp(Instant.now().toEpochMilli()) |
| | | .elapsedMs(System.currentTimeMillis() - startedAt) |
| | | .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAtRef.get())) |
| | | .build()); |
| | | } |
| | | |
| | | public void emitDone(SseEmitter emitter, String requestId, ChatResponseMetadata metadata, String fallbackModel, |
| | | Long sessionId, long startedAt, Long firstTokenAt) { |
| | | Usage usage = metadata == null ? null : metadata.getUsage(); |
| | | emitStrict(emitter, "done", AiChatDoneDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .model(metadata != null && metadata.getModel() != null && !metadata.getModel().isBlank() ? metadata.getModel() : fallbackModel) |
| | | .elapsedMs(System.currentTimeMillis() - startedAt) |
| | | .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAt)) |
| | | .promptTokens(usage == null ? null : usage.getPromptTokens()) |
| | | .completionTokens(usage == null ? null : usage.getCompletionTokens()) |
| | | .totalTokens(usage == null ? null : usage.getTotalTokens()) |
| | | .build()); |
| | | } |
| | | |
| | | public AiChatStatusDto buildTerminalStatus(String requestId, Long sessionId, String status, String model, long startedAt, Long firstTokenAt) { |
| | | return AiChatStatusDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .status(status) |
| | | .model(model) |
| | | .timestamp(Instant.now().toEpochMilli()) |
| | | .elapsedMs(System.currentTimeMillis() - startedAt) |
| | | .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAt)) |
| | | .build(); |
| | | } |
| | | |
| | | public Long resolveFirstTokenLatency(long startedAt, Long firstTokenAt) { |
| | | return firstTokenAt == null ? null : Math.max(0L, firstTokenAt - startedAt); |
| | | } |
| | | |
| | | public Map<String, String> buildMessagePayload(String... keyValues) { |
| | | Map<String, String> payload = new LinkedHashMap<>(); |
| | | if (keyValues == null || keyValues.length == 0) { |
| | | return payload; |
| | | } |
| | | if (keyValues.length % 2 != 0) { |
| | | throw new CoolException("消息载荷参数必须成对出现"); |
| | | } |
| | | for (int i = 0; i < keyValues.length; i += 2) { |
| | | payload.put(keyValues[i], keyValues[i + 1] == null ? "" : keyValues[i + 1]); |
| | | } |
| | | return payload; |
| | | } |
| | | |
| | | public void emitStrict(SseEmitter emitter, String eventName, Object payload) { |
| | | try { |
| | | String data = objectMapper.writeValueAsString(payload); |
| | | emitter.send(SseEmitter.event() |
| | | .name(eventName) |
| | | .data(data, MediaType.APPLICATION_JSON)); |
| | | } catch (IOException e) { |
| | | throw new AiChatException("AI_SSE_EMIT_ERROR", AiErrorCategory.STREAM, "SSE_EMIT", "SSE 输出失败: " + e.getMessage(), e); |
| | | } |
| | | } |
| | | |
| | | public void emitSafely(SseEmitter emitter, String eventName, Object payload) { |
| | | try { |
| | | emitStrict(emitter, eventName, payload); |
| | | } catch (Exception e) { |
| | | log.warn("AI SSE event emit skipped, eventName={}, message={}", eventName, e.getMessage()); |
| | | } |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.service.impl.chat; |
| | | |
| | | import com.vincent.rsf.server.ai.dto.AiChatThinkingEventDto; |
| | | import org.springframework.util.StringUtils; |
| | | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; |
| | | |
| | | import java.time.Instant; |
| | | import java.util.Objects; |
| | | |
| | | /** |
| | | * Tracks the current "thinking" phase of one chat request and publishes "thinking" SSE events |
| | | * as the phase starts, updates, completes or terminates. Phases used here are ANALYZE, TOOL_CALL and ANSWER. |
| | | */ |
| | | public class AiThinkingTraceEmitter { |
| | | |
| | | private final AiSseEventPublisher aiSseEventPublisher; |
| | | private final SseEmitter emitter; |
| | | private final String requestId; |
| | | private final Long sessionId; |
| | | private String currentPhase; |
| | | private String currentStatus; |
| | | |
| | | public AiThinkingTraceEmitter(AiSseEventPublisher aiSseEventPublisher, SseEmitter emitter, String requestId, Long sessionId) { |
| | | this.aiSseEventPublisher = aiSseEventPublisher; |
| | | this.emitter = emitter; |
| | | this.requestId = requestId; |
| | | this.sessionId = sessionId; |
| | | } |
| | | |
| | | /** Enters the ANALYZE phase and emits its STARTED event; does nothing once any phase has been set. */ |
| | | public void startAnalyze() { |
| | | if (currentPhase == null) { |
| | | currentPhase = "ANALYZE"; |
| | | currentStatus = "STARTED"; |
| | | emitThinkingEvent("ANALYZE", "STARTED", "正在分析问题", |
| | | "已接收你的问题,正在理解意图并判断是否需要调用工具。", null); |
| | | } |
| | | } |
| | | |
| | | /** Switches to the TOOL_CALL phase (STARTED event), then emits an UPDATED event naming the tool. */ |
| | | public void onToolStart(String toolName, String toolCallId) { |
| | | switchPhase("TOOL_CALL", "STARTED", "正在调用工具", "已判断需要调用工具,正在查询相关信息。", null); |
| | | currentStatus = "UPDATED"; |
| | | String detail = "正在调用工具 " + safeLabel(toolName, "未知工具") + " 获取所需信息。"; |
| | | emitThinkingEvent("TOOL_CALL", "UPDATED", "正在调用工具", detail, toolCallId); |
| | | } |
| | | |
| | | /** Sets the phase to TOOL_CALL with status FAILED or UPDATED and emits the matching event. */ |
| | | public void onToolResult(String toolName, String toolCallId, boolean failed) { |
| | | String resultStatus = failed ? "FAILED" : "UPDATED"; |
| | | currentPhase = "TOOL_CALL"; |
| | | currentStatus = resultStatus; |
| | | String title; |
| | | String detail; |
| | | if (failed) { |
| | | title = "工具调用失败"; |
| | | detail = "工具 " + safeLabel(toolName, "未知工具") + " 调用失败,正在评估失败影响并整理可用信息。"; |
| | | } else { |
| | | title = "工具调用完成"; |
| | | detail = "工具 " + safeLabel(toolName, "未知工具") + " 已返回结果,正在继续分析并提炼关键信息。"; |
| | | } |
| | | emitThinkingEvent("TOOL_CALL", resultStatus, title, detail, toolCallId); |
| | | } |
| | | |
| | | /** Switches to the ANSWER phase and emits its STARTED event. */ |
| | | public void startAnswer() { |
| | | switchPhase("ANSWER", "STARTED", "正在整理答案", "已完成分析,正在组织最终回复内容。", null); |
| | | } |
| | | |
| | | /** Emits COMPLETED for the current phase unless no phase is set or its status is already terminal. */ |
| | | public void completeCurrentPhase() { |
| | | boolean phaseOpen = StringUtils.hasText(currentPhase) && !isTerminalStatus(currentStatus); |
| | | if (phaseOpen) { |
| | | currentStatus = "COMPLETED"; |
| | | emitThinkingEvent(currentPhase, "COMPLETED", resolveCompleteTitle(currentPhase), |
| | | resolveCompleteContent(currentPhase), null); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * Ends the current phase with the given status unless no phase is set or its status is already terminal. |
| | | * The "aborted" wording is used for ABORTED, the "failed" wording for any other status. |
| | | */ |
| | | public void markTerminated(String terminalStatus) { |
| | | boolean phaseOpen = StringUtils.hasText(currentPhase) && !isTerminalStatus(currentStatus); |
| | | if (!phaseOpen) { |
| | | return; |
| | | } |
| | | currentStatus = terminalStatus; |
| | | String title; |
| | | String detail; |
| | | if ("ABORTED".equals(terminalStatus)) { |
| | | title = "思考已中止"; |
| | | detail = "本轮对话已被中止,思考过程提前结束。"; |
| | | } else { |
| | | title = "思考失败"; |
| | | detail = "本轮对话在生成答案前失败,当前思考过程已停止。"; |
| | | } |
| | | emitThinkingEvent(currentPhase, terminalStatus, title, detail, null); |
| | | } |
| | | |
| | | /** Completes the current phase when it differs from nextPhase, then records and emits the next phase. */ |
| | | private void switchPhase(String nextPhase, String nextStatus, String title, String content, String toolCallId) { |
| | | boolean phaseChanges = !Objects.equals(currentPhase, nextPhase); |
| | | if (phaseChanges) { |
| | | completeCurrentPhase(); |
| | | } |
| | | currentPhase = nextPhase; |
| | | currentStatus = nextStatus; |
| | | emitThinkingEvent(nextPhase, nextStatus, title, content, toolCallId); |
| | | } |
| | | |
| | | /** Sends one "thinking" event through the publisher's best-effort emit. */ |
| | | private void emitThinkingEvent(String phase, String status, String title, String content, String toolCallId) { |
| | | AiChatThinkingEventDto event = AiChatThinkingEventDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .phase(phase) |
| | | .status(status) |
| | | .title(title) |
| | | .content(content) |
| | | .toolCallId(toolCallId) |
| | | .timestamp(Instant.now().toEpochMilli()) |
| | | .build(); |
| | | aiSseEventPublisher.emitSafely(emitter, "thinking", event); |
| | | } |
| | | |
| | | /** True for COMPLETED, FAILED and ABORTED; false otherwise, including null. */ |
| | | private boolean isTerminalStatus(String status) { |
| | | if (status == null) { |
| | | return false; |
| | | } |
| | | return status.equals("COMPLETED") || status.equals("FAILED") || status.equals("ABORTED"); |
| | | } |
| | | |
| | | /** Title of the COMPLETED event for a phase; anything other than ANSWER or TOOL_CALL gets the analysis title. */ |
| | | private String resolveCompleteTitle(String phase) { |
| | | return "ANSWER".equals(phase) |
| | | ? "答案整理完成" |
| | | : "TOOL_CALL".equals(phase) ? "工具分析完成" : "问题分析完成"; |
| | | } |
| | | |
| | | /** Content of the COMPLETED event for a phase; anything other than ANSWER or TOOL_CALL gets the analysis text. */ |
| | | private String resolveCompleteContent(String phase) { |
| | | return "ANSWER".equals(phase) |
| | | ? "最终答复已生成完成。" |
| | | : "TOOL_CALL".equals(phase) ? "工具调用阶段已结束,相关信息已整理完毕。" : "问题意图和处理方向已分析完成。"; |
| | | } |
| | | |
| | | /** Returns value when it has text, otherwise fallback. */ |
| | | private String safeLabel(String value, String fallback) { |
| | | if (StringUtils.hasText(value)) { |
| | | return value; |
| | | } |
| | | return fallback; |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.service.impl.chat; |
| | | |
| | | import com.vincent.rsf.framework.exception.CoolException; |
| | | import com.vincent.rsf.server.ai.dto.AiChatToolEventDto; |
| | | import com.vincent.rsf.server.ai.service.AiCallLogService; |
| | | import com.vincent.rsf.server.ai.service.MountedToolCallback; |
| | | import com.vincent.rsf.server.ai.store.AiCachedToolResult; |
| | | import com.vincent.rsf.server.ai.store.AiToolResultStore; |
| | | import lombok.RequiredArgsConstructor; |
| | | import org.springframework.ai.chat.model.ToolContext; |
| | | import org.springframework.ai.tool.ToolCallback; |
| | | import org.springframework.stereotype.Component; |
| | | import org.springframework.util.StringUtils; |
| | | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; |
| | | |
| | | import java.util.ArrayList; |
| | | import java.util.List; |
| | | import java.util.concurrent.atomic.AtomicLong; |
| | | |
| | | @Component |
| | | @RequiredArgsConstructor |
| | | public class AiToolObservationService { |
| | | |
| | | private final AiSseEventPublisher aiSseEventPublisher; |
| | | private final AiToolResultStore aiToolResultStore; |
| | | private final AiCallLogService aiCallLogService; |
| | | |
| | | public ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, SseEmitter emitter, String requestId, |
| | | Long sessionId, AtomicLong toolCallSequence, |
| | | AtomicLong toolSuccessCount, AtomicLong toolFailureCount, |
| | | Long callLogId, Long userId, Long tenantId, |
| | | AiThinkingTraceEmitter thinkingTraceEmitter) { |
| | | if (toolCallbacks == null || toolCallbacks.length == 0) { |
| | | return toolCallbacks; |
| | | } |
| | | List<ToolCallback> wrappedCallbacks = new ArrayList<>(); |
| | | for (ToolCallback callback : toolCallbacks) { |
| | | if (callback == null) { |
| | | continue; |
| | | } |
| | | wrappedCallbacks.add(new ObservableToolCallback(callback, emitter, requestId, sessionId, toolCallSequence, |
| | | toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, thinkingTraceEmitter)); |
| | | } |
| | | return wrappedCallbacks.toArray(new ToolCallback[0]); |
| | | } |
| | | |
| | | /** |
| | | * Produces a single-line summary of a tool payload: trims it, collapses each whitespace run |
| | | * to one space and cuts the result to at most {@code maxLength} characters. |
| | | * |
| | | * @return the summary, or {@code null} when {@code content} has no text |
| | | */ |
| | | private String summarizeToolPayload(String content, int maxLength) { |
| | | if (StringUtils.hasText(content)) { |
| | | // "\\s+" already matches CR and LF, so one regex pass collapses every whitespace run. |
| | | String collapsed = content.trim().replaceAll("\\s+", " "); |
| | | if (collapsed.length() <= maxLength) { |
| | | return collapsed; |
| | | } |
| | | return collapsed.substring(0, maxLength); |
| | | } |
| | | return null; |
| | | } |
| | | |
| | | /** |
| | | * Decorator around one {@link ToolCallback} that reports every invocation: it emits SSE tool events, |
| | | * notifies the optional thinking-trace emitter, stores the outcome in the tool result store, |
| | | * updates the shared success/failure counters and writes an MCP call log entry. |
| | | */ |
| | | private class ObservableToolCallback implements ToolCallback { |
| | | |
| | | private final ToolCallback delegate; |
| | | private final SseEmitter emitter; |
| | | private final String requestId; |
| | | private final Long sessionId; |
| | | private final AtomicLong toolCallSequence; |
| | | private final AtomicLong toolSuccessCount; |
| | | private final AtomicLong toolFailureCount; |
| | | private final Long callLogId; |
| | | private final Long userId; |
| | | private final Long tenantId; |
| | | private final AiThinkingTraceEmitter thinkingTraceEmitter; |
| | | |
| | | private ObservableToolCallback(ToolCallback delegate, SseEmitter emitter, String requestId, |
| | | Long sessionId, AtomicLong toolCallSequence, |
| | | AtomicLong toolSuccessCount, AtomicLong toolFailureCount, |
| | | Long callLogId, Long userId, Long tenantId, |
| | | AiThinkingTraceEmitter thinkingTraceEmitter) { |
| | | this.delegate = delegate; |
| | | this.emitter = emitter; |
| | | this.requestId = requestId; |
| | | this.sessionId = sessionId; |
| | | this.toolCallSequence = toolCallSequence; |
| | | this.toolSuccessCount = toolSuccessCount; |
| | | this.toolFailureCount = toolFailureCount; |
| | | this.callLogId = callLogId; |
| | | this.userId = userId; |
| | | this.tenantId = tenantId; |
| | | this.thinkingTraceEmitter = thinkingTraceEmitter; |
| | | } |
| | | |
| | | @Override |
| | | public org.springframework.ai.tool.definition.ToolDefinition getToolDefinition() { |
| | | return delegate.getToolDefinition(); |
| | | } |
| | | |
| | | @Override |
| | | public org.springframework.ai.tool.metadata.ToolMetadata getToolMetadata() { |
| | | return delegate.getToolMetadata(); |
| | | } |
| | | |
| | | @Override |
| | | public String call(String toolInput) { |
| | | return call(toolInput, null); |
| | | } |
| | | |
| | | /** |
| | | * Runs the wrapped tool, or replays a stored result for the same tenant/request/tool/input. |
| | | * |
| | | * @param toolInput raw tool input forwarded to the delegate |
| | | * @param toolContext optional context; when {@code null} the single-argument delegate call is used |
| | | * @return the tool output (fresh or replayed) |
| | | * @throws CoolException when a stored failure is replayed |
| | | * @throws RuntimeException whatever the delegate throws, rethrown after the failure is reported |
| | | */ |
| | | @Override |
| | | public String call(String toolInput, ToolContext toolContext) { |
| | | String toolName = delegate.getToolDefinition() == null ? "unknown" : delegate.getToolDefinition().name(); |
| | | String mountName = delegate instanceof MountedToolCallback ? ((MountedToolCallback) delegate).getMountName() : null; |
| | | String toolCallId = requestId + "-tool-" + toolCallSequence.incrementAndGet(); |
| | | long startedAt = System.currentTimeMillis(); |
| | | AiCachedToolResult cachedToolResult = aiToolResultStore.getToolResult(tenantId, requestId, toolName, toolInput); |
| | | if (cachedToolResult != null) { |
| | | return replayCachedResult(cachedToolResult, toolName, mountName, toolCallId, toolInput); |
| | | } |
| | | if (thinkingTraceEmitter != null) { |
| | | thinkingTraceEmitter.onToolStart(toolName, toolCallId); |
| | | } |
| | | aiSseEventPublisher.emitSafely(emitter, "tool_start", AiChatToolEventDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .toolCallId(toolCallId) |
| | | .toolName(toolName) |
| | | .mountName(mountName) |
| | | .status("STARTED") |
| | | .inputSummary(summarizeToolPayload(toolInput, 400)) |
| | | .timestamp(startedAt) |
| | | .build()); |
| | | String output; |
| | | // Only the delegate call is guarded: a bookkeeping error after a successful call must not be |
| | | // reported (and stored) as a tool failure. |
| | | try { |
| | | output = toolContext == null ? delegate.call(toolInput) : delegate.call(toolInput, toolContext); |
| | | } catch (RuntimeException e) { |
| | | recordFailure(toolName, mountName, toolCallId, toolInput, e.getMessage(), |
| | | System.currentTimeMillis() - startedAt); |
| | | throw e; |
| | | } |
| | | recordSuccess(toolName, mountName, toolCallId, toolInput, output, System.currentTimeMillis() - startedAt); |
| | | return output; |
| | | } |
| | | |
| | | /** Reports a stored result as a zero-duration call, then returns its output or throws its error. */ |
| | | private String replayCachedResult(AiCachedToolResult cachedToolResult, String toolName, String mountName, |
| | | String toolCallId, String toolInput) { |
| | | aiSseEventPublisher.emitSafely(emitter, "tool_result", AiChatToolEventDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .toolCallId(toolCallId) |
| | | .toolName(toolName) |
| | | .mountName(mountName) |
| | | .status(cachedToolResult.isSuccess() ? "COMPLETED" : "FAILED") |
| | | .inputSummary(summarizeToolPayload(toolInput, 400)) |
| | | .outputSummary(summarizeToolPayload(cachedToolResult.getOutput(), 600)) |
| | | .errorMessage(cachedToolResult.getErrorMessage()) |
| | | .durationMs(0L) |
| | | .timestamp(System.currentTimeMillis()) |
| | | .build()); |
| | | if (thinkingTraceEmitter != null) { |
| | | thinkingTraceEmitter.onToolResult(toolName, toolCallId, !cachedToolResult.isSuccess()); |
| | | } |
| | | if (cachedToolResult.isSuccess()) { |
| | | toolSuccessCount.incrementAndGet(); |
| | | aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, |
| | | "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(cachedToolResult.getOutput(), 600), |
| | | null, 0L, userId, tenantId); |
| | | return cachedToolResult.getOutput(); |
| | | } |
| | | toolFailureCount.incrementAndGet(); |
| | | aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, |
| | | "FAILED", summarizeToolPayload(toolInput, 400), null, cachedToolResult.getErrorMessage(), |
| | | 0L, userId, tenantId); |
| | | throw new CoolException(cachedToolResult.getErrorMessage()); |
| | | } |
| | | |
| | | /** Emits the result event, stores the output, bumps the success counter and logs the call. */ |
| | | private void recordSuccess(String toolName, String mountName, String toolCallId, String toolInput, |
| | | String output, long durationMs) { |
| | | aiSseEventPublisher.emitSafely(emitter, "tool_result", AiChatToolEventDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .toolCallId(toolCallId) |
| | | .toolName(toolName) |
| | | .mountName(mountName) |
| | | .status("COMPLETED") |
| | | .inputSummary(summarizeToolPayload(toolInput, 400)) |
| | | .outputSummary(summarizeToolPayload(output, 600)) |
| | | .durationMs(durationMs) |
| | | .timestamp(System.currentTimeMillis()) |
| | | .build()); |
| | | if (thinkingTraceEmitter != null) { |
| | | thinkingTraceEmitter.onToolResult(toolName, toolCallId, false); |
| | | } |
| | | aiToolResultStore.cacheToolResult(tenantId, requestId, toolName, toolInput, true, output, null); |
| | | toolSuccessCount.incrementAndGet(); |
| | | aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, |
| | | "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(output, 600), |
| | | null, durationMs, userId, tenantId); |
| | | } |
| | | |
| | | /** Emits the error event, stores the failure, bumps the failure counter and logs the call. */ |
| | | private void recordFailure(String toolName, String mountName, String toolCallId, String toolInput, |
| | | String errorMessage, long durationMs) { |
| | | aiSseEventPublisher.emitSafely(emitter, "tool_error", AiChatToolEventDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .toolCallId(toolCallId) |
| | | .toolName(toolName) |
| | | .mountName(mountName) |
| | | .status("FAILED") |
| | | .inputSummary(summarizeToolPayload(toolInput, 400)) |
| | | .errorMessage(errorMessage) |
| | | .durationMs(durationMs) |
| | | .timestamp(System.currentTimeMillis()) |
| | | .build()); |
| | | if (thinkingTraceEmitter != null) { |
| | | thinkingTraceEmitter.onToolResult(toolName, toolCallId, true); |
| | | } |
| | | aiToolResultStore.cacheToolResult(tenantId, requestId, toolName, toolInput, false, null, errorMessage); |
| | | toolFailureCount.incrementAndGet(); |
| | | aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName, |
| | | "FAILED", summarizeToolPayload(toolInput, 400), null, errorMessage, |
| | | durationMs, userId, tenantId); |
| | | } |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.service.impl.conversation; |
| | | |
| | | import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper; |
| | | import com.vincent.rsf.framework.exception.CoolException; |
| | | import com.vincent.rsf.server.ai.dto.AiChatMessageDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatSessionDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest; |
| | | import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest; |
| | | import com.vincent.rsf.server.ai.entity.AiChatMessage; |
| | | import com.vincent.rsf.server.ai.entity.AiChatSession; |
| | | import com.vincent.rsf.server.ai.mapper.AiChatMessageMapper; |
| | | import com.vincent.rsf.server.ai.mapper.AiChatSessionMapper; |
| | | import com.vincent.rsf.server.system.enums.StatusType; |
| | | import lombok.RequiredArgsConstructor; |
| | | import org.springframework.stereotype.Service; |
| | | import org.springframework.transaction.annotation.Transactional; |
| | | import org.springframework.transaction.support.TransactionSynchronization; |
| | | import org.springframework.transaction.support.TransactionSynchronizationManager; |
| | | import org.springframework.util.StringUtils; |
| | | |
| | | import java.util.ArrayList; |
| | | import java.util.Date; |
| | | import java.util.HashSet; |
| | | import java.util.List; |
| | | import java.util.Set; |
| | | |
| | | @Service |
| | | @RequiredArgsConstructor |
| | | public class AiConversationCommandService { |
| | | |
| | | private final AiChatSessionMapper aiChatSessionMapper; |
| | | private final AiChatMessageMapper aiChatMessageMapper; |
| | | private final AiConversationQueryService aiConversationQueryService; |
| | | private final AiMemoryProfileService aiMemoryProfileService; |
| | | |
| | | /** |
| | | * Returns the caller's existing session when {@code sessionId} is given, otherwise inserts a new |
| | | * session titled from {@code titleSeed} and evicts the user's conversation caches once the |
| | | * surrounding transaction commits. |
| | | */ |
| | | @Transactional(rollbackFor = Exception.class) |
| | | public AiChatSession resolveSession(Long userId, Long tenantId, String promptCode, Long sessionId, String titleSeed) { |
| | | aiConversationQueryService.ensureIdentity(userId, tenantId); |
| | | String resolvedPromptCode = aiConversationQueryService.requirePromptCode(promptCode); |
| | | if (sessionId == null) { |
| | | Date createdAt = new Date(); |
| | | AiChatSession created = new AiChatSession(); |
| | | created.setTitle(aiConversationQueryService.buildSessionTitle(titleSeed)); |
| | | created.setPromptCode(resolvedPromptCode); |
| | | created.setUserId(userId); |
| | | created.setTenantId(tenantId); |
| | | created.setLastMessageTime(createdAt); |
| | | created.setPinned(0); |
| | | created.setStatus(StatusType.ENABLE.val); |
| | | created.setDeleted(0); |
| | | created.setCreateBy(userId); |
| | | created.setCreateTime(createdAt); |
| | | created.setUpdateBy(userId); |
| | | created.setUpdateTime(createdAt); |
| | | aiChatSessionMapper.insert(created); |
| | | Runnable evictCaches = () -> aiConversationQueryService.evictConversationCaches(tenantId, userId); |
| | | afterConversationMutationCommitted(evictCaches); |
| | | return created; |
| | | } |
| | | return aiConversationQueryService.getSession(sessionId, userId, tenantId, resolvedPromptCode); |
| | | } |
| | | |
| | | @Transactional(rollbackFor = Exception.class) |
| | | public void saveRound(AiChatSession session, Long userId, Long tenantId, List<AiChatMessageDto> memoryMessages, String assistantContent) { |
| | | if (session == null || session.getId() == null) { |
| | | throw new CoolException("AI 会话不存在"); |
| | | } |
| | | aiConversationQueryService.ensureIdentity(userId, tenantId); |
| | | List<AiChatMessageDto> normalizedMessages = normalizeMessages(memoryMessages); |
| | | if (normalizedMessages.isEmpty()) { |
| | | throw new CoolException("本轮没有可保存的对话消息"); |
| | | } |
| | | int nextSeqNo = aiConversationQueryService.findNextSeqNo(session.getId()); |
| | | Date now = new Date(); |
| | | List<AiChatMessage> records = new ArrayList<>(); |
| | | for (AiChatMessageDto message : normalizedMessages) { |
| | | records.add(buildMessageEntity(session.getId(), nextSeqNo++, message.getRole(), message.getContent(), userId, tenantId, now)); |
| | | } |
| | | if (StringUtils.hasText(assistantContent)) { |
| | | records.add(buildMessageEntity(session.getId(), nextSeqNo, "assistant", assistantContent, userId, tenantId, now)); |
| | | } |
| | | aiChatMessageMapper.insertBatch(records); |
| | | aiChatSessionMapper.updateById(new AiChatSession() |
| | | .setId(session.getId()) |
| | | .setTitle(aiConversationQueryService.resolveUpdatedTitle(session.getTitle(), normalizedMessages)) |
| | | .setLastMessageTime(now) |
| | | .setUpdateBy(userId) |
| | | .setUpdateTime(now)); |
| | | afterConversationMutationCommitted(() -> { |
| | | aiConversationQueryService.evictConversationCaches(tenantId, userId); |
| | | aiMemoryProfileService.scheduleMemoryProfileRefresh(session.getId(), userId, tenantId); |
| | | }); |
| | | } |
| | | |
| | | /** |
| | | * Soft-deletes a session owned by the caller together with its messages, then evicts the user's |
| | | * conversation caches after commit. |
| | | */ |
| | | @Transactional(rollbackFor = Exception.class) |
| | | public void removeSession(Long userId, Long tenantId, Long sessionId) { |
| | | aiConversationQueryService.ensureIdentity(userId, tenantId); |
| | | AiChatSession owned = aiConversationQueryService.requireOwnedSession(sessionId, userId, tenantId); |
| | | AiChatSession tombstone = new AiChatSession(); |
| | | tombstone.setId(owned.getId()); |
| | | tombstone.setDeleted(1); |
| | | tombstone.setUpdateBy(userId); |
| | | tombstone.setUpdateTime(new Date()); |
| | | aiChatSessionMapper.updateById(tombstone); |
| | | aiChatMessageMapper.softDeleteBySessionId(sessionId); |
| | | Runnable evictCaches = () -> aiConversationQueryService.evictConversationCaches(tenantId, userId); |
| | | afterConversationMutationCommitted(evictCaches); |
| | | } |
| | | |
| | | /** |
| | | * Renames a session owned by the caller. The requested title goes through |
| | | * {@code buildSessionTitle} before it is stored; caches are evicted after commit. |
| | | * |
| | | * @return the session DTO as reloaded through the session list |
| | | * @throws CoolException when the request or its title is empty, or the session is not accessible |
| | | */ |
| | | @Transactional(rollbackFor = Exception.class) |
| | | public AiChatSessionDto renameSession(Long userId, Long tenantId, Long sessionId, AiChatSessionRenameRequest request) { |
| | | aiConversationQueryService.ensureIdentity(userId, tenantId); |
| | | String requestedTitle = request == null ? null : request.getTitle(); |
| | | if (!StringUtils.hasText(requestedTitle)) { |
| | | throw new CoolException("会话标题不能为空"); |
| | | } |
| | | AiChatSession owned = aiConversationQueryService.requireOwnedSession(sessionId, userId, tenantId); |
| | | AiChatSession titlePatch = new AiChatSession(); |
| | | titlePatch.setId(sessionId); |
| | | titlePatch.setTitle(aiConversationQueryService.buildSessionTitle(requestedTitle)); |
| | | titlePatch.setUpdateBy(userId); |
| | | titlePatch.setUpdateTime(new Date()); |
| | | aiChatSessionMapper.updateById(titlePatch); |
| | | Runnable evictCaches = () -> aiConversationQueryService.evictConversationCaches(tenantId, userId); |
| | | afterConversationMutationCommitted(evictCaches); |
| | | return reloadSessionDto(sessionId, userId, tenantId, owned.getPromptCode()); |
| | | } |
| | | |
| | | @Transactional(rollbackFor = Exception.class) |
| | | public AiChatSessionDto pinSession(Long userId, Long tenantId, Long sessionId, AiChatSessionPinRequest request) { |
| | | aiConversationQueryService.ensureIdentity(userId, tenantId); |
| | | if (request == null || request.getPinned() == null) { |
| | | throw new CoolException("置顶状态不能为空"); |
| | | } |
| | | AiChatSession session = aiConversationQueryService.requireOwnedSession(sessionId, userId, tenantId); |
| | | aiChatSessionMapper.updateById(new AiChatSession() |
| | | .setId(sessionId) |
| | | .setPinned(Boolean.TRUE.equals(request.getPinned()) ? 1 : 0) |
| | | .setUpdateBy(userId) |
| | | .setUpdateTime(new Date())); |
| | | afterConversationMutationCommitted(() -> aiConversationQueryService.evictConversationCaches(tenantId, userId)); |
| | | return reloadSessionDto(sessionId, userId, tenantId, session.getPromptCode()); |
| | | } |
| | | |
| | | @Transactional(rollbackFor = Exception.class) |
| | | public void clearSessionMemory(Long userId, Long tenantId, Long sessionId) { |
| | | aiConversationQueryService.ensureIdentity(userId, tenantId); |
| | | AiChatSession session = aiConversationQueryService.requireOwnedSession(sessionId, userId, tenantId); |
| | | aiChatMessageMapper.softDeleteBySessionId(sessionId); |
| | | aiChatSessionMapper.updateById(new AiChatSession() |
| | | .setId(sessionId) |
| | | .setMemorySummary(null) |
| | | .setMemoryFacts(null) |
| | | .setUpdateBy(userId) |
| | | .setUpdateTime(new Date()) |
| | | .setLastMessageTime(session.getCreateTime())); |
| | | afterConversationMutationCommitted(() -> aiConversationQueryService.evictConversationCaches(tenantId, userId)); |
| | | } |
| | | |
| | | /** |
| | | * Keeps only the latest round of a session owned by the caller and soft-deletes every earlier |
| | | * message. Does nothing when the session has no messages; otherwise cache eviction and the memory |
| | | * profile refresh run after the transaction commits. |
| | | * |
| | | * @throws CoolException when the identity is incomplete or the session is not accessible |
| | | */ |
| | | @Transactional(rollbackFor = Exception.class) |
| | | public void retainLatestRound(Long userId, Long tenantId, Long sessionId) { |
| | | aiConversationQueryService.ensureIdentity(userId, tenantId); |
| | | aiConversationQueryService.requireOwnedSession(sessionId, userId, tenantId); |
| | | List<AiChatMessage> records = aiConversationQueryService.listMessageRecords(sessionId); |
| | | if (records.isEmpty()) { |
| | | return; |
| | | } |
| | | List<AiChatMessage> retained = aiConversationQueryService.tailMessageRecordsByRounds(records, 1); |
| | | // Hash lookup keeps the split linear in the number of messages. |
| | | Set<Long> retainedIds = new HashSet<>(); |
| | | for (AiChatMessage kept : retained) { |
| | | retainedIds.add(kept.getId()); |
| | | } |
| | | List<Long> deletedIds = records.stream() |
| | | .map(AiChatMessage::getId) |
| | | .filter(id -> !retainedIds.contains(id)) |
| | | .toList(); |
| | | if (!deletedIds.isEmpty()) { |
| | | aiChatMessageMapper.softDeleteByIds(deletedIds); |
| | | } |
| | | afterConversationMutationCommitted(() -> { |
| | | aiConversationQueryService.evictConversationCaches(tenantId, userId); |
| | | aiMemoryProfileService.scheduleMemoryProfileRefresh(sessionId, userId, tenantId); |
| | | }); |
| | | } |
| | | |
| | | /** |
| | | * Reloads one session as a DTO through the session list of the given prompt. |
| | | * <p> |
| | | * {@code listSessions} serves from the conversation cache, while callers only schedule cache |
| | | * eviction for after commit. The caches are therefore evicted here first; otherwise a cache hit |
| | | * would hand back the state from before the update that was just written. |
| | | * |
| | | * @throws CoolException when the session is not in the caller's list |
| | | */ |
| | | private AiChatSessionDto reloadSessionDto(Long sessionId, Long userId, Long tenantId, String promptCode) { |
| | | aiConversationQueryService.evictConversationCaches(tenantId, userId); |
| | | return aiConversationQueryService.listSessions(userId, tenantId, promptCode, null).stream() |
| | | .filter(item -> sessionId.equals(item.getSessionId())) |
| | | .findFirst() |
| | | .orElseThrow(() -> new CoolException("AI 会话不存在或无权访问")); |
| | | } |
| | | |
| | | /** |
| | | * Copies the storable messages of a round: drops {@code null} items, items without text and |
| | | * system messages, trims the content and maps every role to either "assistant" or "user" |
| | | * (a missing or unknown role becomes "user"). |
| | | * <p> |
| | | * Roles are compared with {@code equalsIgnoreCase}, which does not depend on the default locale |
| | | * (a locale-sensitive {@code toLowerCase()} turns "ASSISTANT" into a different string under a |
| | | * Turkish locale). |
| | | * |
| | | * @return a new, possibly empty, mutable list; never {@code null} |
| | | */ |
| | | private List<AiChatMessageDto> normalizeMessages(List<AiChatMessageDto> memoryMessages) { |
| | | List<AiChatMessageDto> normalized = new ArrayList<>(); |
| | | if (memoryMessages == null || memoryMessages.isEmpty()) { |
| | | return normalized; |
| | | } |
| | | for (AiChatMessageDto item : memoryMessages) { |
| | | if (item == null || !StringUtils.hasText(item.getContent())) { |
| | | continue; |
| | | } |
| | | String role = item.getRole(); |
| | | if ("system".equalsIgnoreCase(role)) { |
| | | continue; |
| | | } |
| | | AiChatMessageDto normalizedItem = new AiChatMessageDto(); |
| | | normalizedItem.setRole("assistant".equalsIgnoreCase(role) ? "assistant" : "user"); |
| | | normalizedItem.setContent(item.getContent().trim()); |
| | | normalized.add(normalizedItem); |
| | | } |
| | | return normalized; |
| | | } |
| | | |
| | | /** |
| | | * Creates a non-deleted message row for the given session position; the content length is the |
| | | * string length of {@code content}, or 0 when it is {@code null}. |
| | | */ |
| | | private AiChatMessage buildMessageEntity(Long sessionId, int seqNo, String role, String content, Long userId, Long tenantId, Date createTime) { |
| | | int contentLength = 0; |
| | | if (content != null) { |
| | | contentLength = content.length(); |
| | | } |
| | | AiChatMessage row = new AiChatMessage(); |
| | | row.setSessionId(sessionId); |
| | | row.setSeqNo(seqNo); |
| | | row.setRole(role); |
| | | row.setContent(content); |
| | | row.setContentLength(contentLength); |
| | | row.setUserId(userId); |
| | | row.setTenantId(tenantId); |
| | | row.setDeleted(0); |
| | | row.setCreateBy(userId); |
| | | row.setCreateTime(createTime); |
| | | return row; |
| | | } |
| | | |
| | | /** |
| | | * Runs {@code action} after the current transaction commits. When transaction synchronization |
| | | * is not active the action runs immediately; a {@code null} action is ignored. |
| | | */ |
| | | private void afterConversationMutationCommitted(Runnable action) { |
| | | if (action != null) { |
| | | if (TransactionSynchronizationManager.isSynchronizationActive()) { |
| | | TransactionSynchronization onCommit = new TransactionSynchronization() { |
| | | @Override |
| | | public void afterCommit() { |
| | | action.run(); |
| | | } |
| | | }; |
| | | TransactionSynchronizationManager.registerSynchronization(onCommit); |
| | | } else { |
| | | action.run(); |
| | | } |
| | | } |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.service.impl.conversation; |
| | | |
| | | import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; |
| | | import com.vincent.rsf.framework.common.Cools; |
| | | import com.vincent.rsf.framework.exception.CoolException; |
| | | import com.vincent.rsf.server.ai.config.AiDefaults; |
| | | import com.vincent.rsf.server.ai.dto.AiChatMemoryDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatMessageDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatSessionDto; |
| | | import com.vincent.rsf.server.ai.entity.AiChatMessage; |
| | | import com.vincent.rsf.server.ai.entity.AiChatSession; |
| | | import com.vincent.rsf.server.ai.mapper.AiChatMessageMapper; |
| | | import com.vincent.rsf.server.ai.mapper.AiChatSessionMapper; |
| | | import com.vincent.rsf.server.ai.store.AiConversationCacheStore; |
| | | import com.vincent.rsf.server.system.enums.StatusType; |
| | | import lombok.RequiredArgsConstructor; |
| | | import org.springframework.stereotype.Service; |
| | | import org.springframework.util.StringUtils; |
| | | |
| | | import java.util.ArrayList; |
| | | import java.util.LinkedHashMap; |
| | | import java.util.List; |
| | | import java.util.Map; |
| | | |
| | | @Service |
| | | @RequiredArgsConstructor |
| | | public class AiConversationQueryService { |
| | | |
| | | private final AiChatSessionMapper aiChatSessionMapper; |
| | | private final AiChatMessageMapper aiChatMessageMapper; |
| | | private final AiConversationCacheStore aiConversationCacheStore; |
| | | |
| | | /** |
| | | * Returns the conversation memory for a user and prompt. |
| | | * <p> |
| | | * The cache is consulted first. On a miss the session is loaded: the one identified by |
| | | * {@code sessionId}, or the user's latest session for the prompt when {@code sessionId} is |
| | | * {@code null}. Without a session an empty memory is returned. The result is cached before |
| | | * it is returned. |
| | | * |
| | | * @throws CoolException when the identity or prompt code is missing, or the requested session |
| | | * cannot be loaded for this user |
| | | */ |
| | | public AiChatMemoryDto getMemory(Long userId, Long tenantId, String promptCode, Long sessionId) { |
| | | ensureIdentity(userId, tenantId); |
| | | String resolvedPromptCode = requirePromptCode(promptCode); |
| | | AiChatMemoryDto cached = aiConversationCacheStore.getMemory(tenantId, userId, resolvedPromptCode, sessionId); |
| | | if (cached != null) { |
| | | return cached; |
| | | } |
| | | AiChatSession session = sessionId == null |
| | | ? findLatestSession(userId, tenantId, resolvedPromptCode) |
| | | : getSession(sessionId, userId, tenantId, resolvedPromptCode); |
| | | AiChatMemoryDto memory; |
| | | if (session == null) { |
| | | // Only reachable for sessionId == null: getSession throws instead of returning null. |
| | | memory = AiChatMemoryDto.builder() |
| | | .sessionId(null) |
| | | .memorySummary(null) |
| | | .memoryFacts(null) |
| | | .recentMessageCount(0) |
| | | .persistedMessages(List.of()) |
| | | .shortMemoryMessages(List.of()) |
| | | .build(); |
| | | aiConversationCacheStore.cacheMemory(tenantId, userId, resolvedPromptCode, sessionId, memory); |
| | | return memory; |
| | | } |
| | | List<AiChatMessageDto> persistedMessages = listMessages(session.getId()); |
| | | List<AiChatMessageDto> shortMemoryMessages = tailMessagesByRounds(persistedMessages, AiDefaults.MEMORY_RECENT_ROUNDS); |
| | | memory = AiChatMemoryDto.builder() |
| | | .sessionId(session.getId()) |
| | | .memorySummary(session.getMemorySummary()) |
| | | .memoryFacts(session.getMemoryFacts()) |
| | | .recentMessageCount(shortMemoryMessages.size()) |
| | | .persistedMessages(persistedMessages) |
| | | .shortMemoryMessages(shortMemoryMessages) |
| | | .build(); |
| | | // Cache under the concrete session id ... |
| | | aiConversationCacheStore.cacheMemory(tenantId, userId, resolvedPromptCode, session.getId(), memory); |
| | | // ... and also under the null-session key when the caller did not name a session. |
| | | if (sessionId == null || !session.getId().equals(sessionId)) { |
| | | aiConversationCacheStore.cacheMemory(tenantId, userId, resolvedPromptCode, null, memory); |
| | | } |
| | | return memory; |
| | | } |
| | | |
| | | public List<AiChatSessionDto> listSessions(Long userId, Long tenantId, String promptCode, String keyword) { |
| | | ensureIdentity(userId, tenantId); |
| | | String resolvedPromptCode = requirePromptCode(promptCode); |
| | | List<AiChatSessionDto> cached = aiConversationCacheStore.getSessionList(tenantId, userId, resolvedPromptCode, keyword); |
| | | if (cached != null) { |
| | | return cached; |
| | | } |
| | | List<AiChatSession> sessions = aiChatSessionMapper.selectList(new LambdaQueryWrapper<AiChatSession>() |
| | | .eq(AiChatSession::getUserId, userId) |
| | | .eq(AiChatSession::getTenantId, tenantId) |
| | | .eq(AiChatSession::getPromptCode, resolvedPromptCode) |
| | | .eq(AiChatSession::getDeleted, 0) |
| | | .eq(AiChatSession::getStatus, StatusType.ENABLE.val) |
| | | .like(StringUtils.hasText(keyword), AiChatSession::getTitle, keyword == null ? null : keyword.trim()) |
| | | .orderByDesc(AiChatSession::getPinned) |
| | | .orderByDesc(AiChatSession::getLastMessageTime) |
| | | .orderByDesc(AiChatSession::getId)); |
| | | if (Cools.isEmpty(sessions)) { |
| | | aiConversationCacheStore.cacheSessionList(tenantId, userId, resolvedPromptCode, keyword, List.of()); |
| | | return List.of(); |
| | | } |
| | | List<Long> sessionIds = sessions.stream().map(AiChatSession::getId).toList(); |
| | | Map<Long, AiChatMessage> latestMessageMap = new LinkedHashMap<>(); |
| | | for (AiChatMessage message : aiChatMessageMapper.selectLatestMessagesBySessionIds(sessionIds)) { |
| | | latestMessageMap.put(message.getSessionId(), message); |
| | | } |
| | | List<AiChatSessionDto> result = new ArrayList<>(); |
| | | for (AiChatSession session : sessions) { |
| | | result.add(buildSessionDto(session, latestMessageMap.get(session.getId()))); |
| | | } |
| | | aiConversationCacheStore.cacheSessionList(tenantId, userId, resolvedPromptCode, keyword, result); |
| | | return result; |
| | | } |
| | | |
| | | /** |
| | | * Evicts the conversation caches of one user by delegating to the cache store. |
| | | */ |
| | | public void evictConversationCaches(Long tenantId, Long userId) { |
| | | aiConversationCacheStore.evictUserConversationCaches(tenantId, userId); |
| | | } |
| | | |
| | | /** |
| | | * Finds the user's enabled, non-deleted session for the prompt with the newest last-message time |
| | | * (ties broken by highest id). |
| | | * |
| | | * @return the session, or {@code null} when the mapper finds none |
| | | */ |
| | | public AiChatSession findLatestSession(Long userId, Long tenantId, String promptCode) { |
| | | LambdaQueryWrapper<AiChatSession> newestFirst = new LambdaQueryWrapper<AiChatSession>() |
| | | .eq(AiChatSession::getUserId, userId) |
| | | .eq(AiChatSession::getTenantId, tenantId) |
| | | .eq(AiChatSession::getPromptCode, promptCode) |
| | | .eq(AiChatSession::getDeleted, 0) |
| | | .eq(AiChatSession::getStatus, StatusType.ENABLE.val) |
| | | .orderByDesc(AiChatSession::getLastMessageTime) |
| | | .orderByDesc(AiChatSession::getId) |
| | | .last("limit 1"); |
| | | return aiChatSessionMapper.selectOne(newestFirst); |
| | | } |
| | | |
| | | /** |
| | | * Loads an enabled, non-deleted session by id that belongs to the user, tenant and prompt. |
| | | * |
| | | * @throws CoolException when no such session exists |
| | | */ |
| | | public AiChatSession getSession(Long sessionId, Long userId, Long tenantId, String promptCode) { |
| | | LambdaQueryWrapper<AiChatSession> ownedByPrompt = new LambdaQueryWrapper<AiChatSession>() |
| | | .eq(AiChatSession::getId, sessionId) |
| | | .eq(AiChatSession::getUserId, userId) |
| | | .eq(AiChatSession::getTenantId, tenantId) |
| | | .eq(AiChatSession::getPromptCode, promptCode) |
| | | .eq(AiChatSession::getDeleted, 0) |
| | | .eq(AiChatSession::getStatus, StatusType.ENABLE.val) |
| | | .last("limit 1"); |
| | | AiChatSession found = aiChatSessionMapper.selectOne(ownedByPrompt); |
| | | if (found != null) { |
| | | return found; |
| | | } |
| | | throw new CoolException("AI 会话不存在或无权访问"); |
| | | } |
| | | |
| | | /** |
| | | * Loads an enabled, non-deleted session by id that belongs to the user and tenant, regardless |
| | | * of its prompt. |
| | | * |
| | | * @throws CoolException when {@code sessionId} is {@code null} or no such session exists |
| | | */ |
| | | public AiChatSession requireOwnedSession(Long sessionId, Long userId, Long tenantId) { |
| | | if (sessionId == null) { |
| | | throw new CoolException("AI 会话 ID 不能为空"); |
| | | } |
| | | LambdaQueryWrapper<AiChatSession> ownedById = new LambdaQueryWrapper<AiChatSession>() |
| | | .eq(AiChatSession::getId, sessionId) |
| | | .eq(AiChatSession::getUserId, userId) |
| | | .eq(AiChatSession::getTenantId, tenantId) |
| | | .eq(AiChatSession::getDeleted, 0) |
| | | .eq(AiChatSession::getStatus, StatusType.ENABLE.val) |
| | | .last("limit 1"); |
| | | AiChatSession found = aiChatSessionMapper.selectOne(ownedById); |
| | | if (found != null) { |
| | | return found; |
| | | } |
| | | throw new CoolException("AI 会话不存在或无权访问"); |
| | | } |
| | | |
| | | /** |
| | | * Returns the session's stored messages as role/content DTOs in stored order, leaving out |
| | | * records whose content has no text. |
| | | */ |
| | | public List<AiChatMessageDto> listMessages(Long sessionId) { |
| | | List<AiChatMessage> stored = listMessageRecords(sessionId); |
| | | if (Cools.isEmpty(stored)) { |
| | | return List.of(); |
| | | } |
| | | List<AiChatMessageDto> converted = new ArrayList<>(); |
| | | for (AiChatMessage row : stored) { |
| | | String content = row.getContent(); |
| | | if (StringUtils.hasText(content)) { |
| | | AiChatMessageDto dto = new AiChatMessageDto(); |
| | | dto.setRole(row.getRole()); |
| | | dto.setContent(content); |
| | | converted.add(dto); |
| | | } |
| | | } |
| | | return converted; |
| | | } |
| | | |
| | | /** |
| | | * Selects the non-deleted message records of a session ordered by sequence number, then id. |
| | | */ |
| | | public List<AiChatMessage> listMessageRecords(Long sessionId) { |
| | | LambdaQueryWrapper<AiChatMessage> inSequence = new LambdaQueryWrapper<AiChatMessage>() |
| | | .eq(AiChatMessage::getSessionId, sessionId) |
| | | .eq(AiChatMessage::getDeleted, 0) |
| | | .orderByAsc(AiChatMessage::getSeqNo) |
| | | .orderByAsc(AiChatMessage::getId); |
| | | return aiChatMessageMapper.selectList(inSequence); |
| | | } |
| | | |
| | | /** |
| | | * Computes the next sequence number of a session: one more than the highest sequence number |
| | | * among its non-deleted messages, or 1 when there is no such message (or it has no number). |
| | | */ |
| | | public int findNextSeqNo(Long sessionId) { |
| | | LambdaQueryWrapper<AiChatMessage> highestFirst = new LambdaQueryWrapper<AiChatMessage>() |
| | | .eq(AiChatMessage::getSessionId, sessionId) |
| | | .eq(AiChatMessage::getDeleted, 0) |
| | | .orderByDesc(AiChatMessage::getSeqNo) |
| | | .orderByDesc(AiChatMessage::getId) |
| | | .last("limit 1"); |
| | | AiChatMessage newest = aiChatMessageMapper.selectOne(highestFirst); |
| | | if (newest == null || newest.getSeqNo() == null) { |
| | | return 1; |
| | | } |
| | | return newest.getSeqNo() + 1; |
| | | } |
| | | |
| | | /** |
| | | * Keeps {@code currentTitle} when it has text; otherwise derives a title from the first user |
| | | * message that has text. |
| | | * |
| | | * @return the title, or {@code null} when neither source yields one |
| | | */ |
| | | public String resolveUpdatedTitle(String currentTitle, List<AiChatMessageDto> memoryMessages) { |
| | | if (StringUtils.hasText(currentTitle)) { |
| | | return currentTitle; |
| | | } |
| | | return memoryMessages.stream() |
| | | .filter(item -> "user".equals(item.getRole()) && StringUtils.hasText(item.getContent())) |
| | | .findFirst() |
| | | .map(item -> buildSessionTitle(item.getContent())) |
| | | .orElse(null); |
| | | } |
| | | |
| | | /** |
| | | * Derives a session title from free text: collapses whitespace, keeps only the part before the |
| | | * first sentence break found by {@code findSummaryBreakIndex}, and limits it to 48 characters. |
| | | * |
| | | * @throws CoolException when {@code titleSeed} has no text |
| | | */ |
| | | public String buildSessionTitle(String titleSeed) { |
| | | if (!StringUtils.hasText(titleSeed)) { |
| | | throw new CoolException("AI 会话标题不能为空"); |
| | | } |
| | | // "\\s+" already matches CR and LF, so one regex pass collapses every whitespace run. |
| | | String collapsed = titleSeed.trim().replaceAll("\\s+", " "); |
| | | int breakAt = findSummaryBreakIndex(collapsed); |
| | | String title = breakAt > 0 ? collapsed.substring(0, breakAt).trim() : collapsed; |
| | | if (title.length() <= 48) { |
| | | return title; |
| | | } |
| | | return title.substring(0, 48); |
| | | } |
| | | |
| | | /** |
| | | * Returns a copy of the tail of {@code source} that starts at the {@code rounds}-th user message |
| | | * counted from the end; the whole list is copied when it holds fewer user messages. |
| | | * An empty source or a non-positive {@code rounds} yields an empty list. |
| | | */ |
| | | public List<AiChatMessageDto> tailMessagesByRounds(List<AiChatMessageDto> source, int rounds) { |
| | | if (Cools.isEmpty(source) || rounds <= 0) { |
| | | return List.of(); |
| | | } |
| | | int from = 0; |
| | | int seenUserTurns = 0; |
| | | for (int idx = source.size() - 1; idx >= 0; idx--) { |
| | | AiChatMessageDto candidate = source.get(idx); |
| | | boolean userTurn = candidate != null && "user".equalsIgnoreCase(candidate.getRole()); |
| | | if (userTurn && ++seenUserTurns >= rounds) { |
| | | from = idx; |
| | | break; |
| | | } |
| | | } |
| | | return new ArrayList<>(source.subList(from, source.size())); |
| | | } |
| | | |
| | | /** |
| | | * Record-level variant of the round tail: returns a copy of the records starting at the |
| | | * {@code rounds}-th user message counted from the end, or all records when there are fewer |
| | | * user messages. An empty source or a non-positive {@code rounds} yields an empty list. |
| | | */ |
| | | public List<AiChatMessage> tailMessageRecordsByRounds(List<AiChatMessage> source, int rounds) { |
| | | if (Cools.isEmpty(source) || rounds <= 0) { |
| | | return List.of(); |
| | | } |
| | | int start = source.size(); |
| | | int pendingRounds = rounds; |
| | | // Walk backwards until enough user messages have been passed or the list is exhausted. |
| | | while (start > 0 && pendingRounds > 0) { |
| | | start--; |
| | | AiChatMessage candidate = source.get(start); |
| | | if (candidate != null && "user".equalsIgnoreCase(candidate.getRole())) { |
| | | pendingRounds--; |
| | | } |
| | | } |
| | | return new ArrayList<>(source.subList(start, source.size())); |
| | | } |
| | | |
| | | /** |
| | | * Flattens text to a single line: trims it, collapses each whitespace run to one space and cuts |
| | | * the result to at most {@code maxLength} characters. |
| | | * |
| | | * @return the compacted text, or {@code null} when {@code content} has no text |
| | | */ |
| | | public String compactText(String content, int maxLength) { |
| | | if (StringUtils.hasText(content)) { |
| | | // "\\s+" already matches CR and LF, so one regex pass collapses every whitespace run. |
| | | String collapsed = content.trim().replaceAll("\\s+", " "); |
| | | if (collapsed.length() <= maxLength) { |
| | | return collapsed; |
| | | } |
| | | return collapsed.substring(0, maxLength); |
| | | } |
| | | return null; |
| | | } |
| | | |
| | | /** |
| | | * Returns {@code promptCode} unchanged when it has text. |
| | | * |
| | | * @throws CoolException when it is {@code null} or blank |
| | | */ |
| | | public String requirePromptCode(String promptCode) { |
| | | if (StringUtils.hasText(promptCode)) { |
| | | return promptCode; |
| | | } |
| | | throw new CoolException("Prompt 编码不能为空"); |
| | | } |
| | | |
| | | public void ensureIdentity(Long userId, Long tenantId) { |
| | | if (userId == null) { |
| | | throw new CoolException("当前登录用户不存在"); |
| | | } |
| | | if (tenantId == null) { |
| | | throw new CoolException("当前租户不存在"); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * Finds the earliest sentence break in {@code title}. |
| | | * <p> |
| | | * Recognized breaks are the ideographic full stop, the ASCII ".", "!" and "?", and the full-width |
| | | * exclamation and question marks. The full-width marks are written as unicode escapes; the list |
| | | * previously repeated the ASCII "!" and "?" instead, so text ending a sentence with a full-width |
| | | * mark was never cut. A separator whose first occurrence is at index 0 is ignored. |
| | | * |
| | | * @return the smallest positive index of a separator's first occurrence, or -1 when there is none |
| | | */ |
| | | private int findSummaryBreakIndex(String title) { |
| | | String[] separators = {"。", "\uFF01", "\uFF1F", ".", "!", "?"}; |
| | | int result = -1; |
| | | for (String separator : separators) { |
| | | int index = title.indexOf(separator); |
| | | if (index > 0 && (result < 0 || index < result)) { |
| | | result = index; |
| | | } |
| | | } |
| | | return result; |
| | | } |
| | | |
| | | /** |
| | | * Maps a session row and its latest message to the session DTO. |
| | | * The session counts as pinned only when its pinned flag is non-null and equal to 1. |
| | | * |
| | | * @param session session record, must not be {@code null} |
| | | * @param lastMessage latest message of the session, may be {@code null} |
| | | * @return the populated DTO |
| | | */ |
| | | private AiChatSessionDto buildSessionDto(AiChatSession session, AiChatMessage lastMessage) { |
| | | boolean pinnedFlag = session.getPinned() != null && session.getPinned() == 1; |
| | | String preview = buildLastMessagePreview(lastMessage); |
| | | return AiChatSessionDto.builder() |
| | | .sessionId(session.getId()) |
| | | .title(session.getTitle()) |
| | | .promptCode(session.getPromptCode()) |
| | | .pinned(pinnedFlag) |
| | | .lastMessagePreview(preview) |
| | | .lastMessageTime(session.getLastMessageTime()) |
| | | .build(); |
| | | } |
| | | |
| | | /** |
| | | * Builds the one-line preview shown for a session's latest message. |
| | | * <p> |
| | | * The content is trimmed, whitespace runs are collapsed to single spaces, a speaker prefix is |
| | | * prepended ("AI: " for the assistant role, "你: " otherwise) and the result is cut to at most |
| | | * 80 chars. The cut never splits a UTF-16 surrogate pair, so the preview may be 79 chars long |
| | | * when char 80 would have been the first half of a supplementary character. |
| | | * |
| | | * @param message latest message, may be {@code null} |
| | | * @return the preview, or {@code null} when there is no message or it has no text |
| | | */ |
| | | private String buildLastMessagePreview(AiChatMessage message) { |
| | | if (message == null || !StringUtils.hasText(message.getContent())) { |
| | | return null; |
| | | } |
| | | // "\\s+" already matches CR and LF, so one pass collapses every whitespace run. |
| | | String preview = message.getContent().trim().replaceAll("\\s+", " "); |
| | | String prefix = "assistant".equalsIgnoreCase(message.getRole()) ? "AI: " : "你: "; |
| | | String normalized = prefix + preview; |
| | | if (normalized.length() <= 80) { |
| | | return normalized; |
| | | } |
| | | // Avoid ending the preview on a dangling high surrogate. |
| | | int end = Character.isHighSurrogate(normalized.charAt(79)) ? 79 : 80; |
| | | return normalized.substring(0, end); |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.service.impl.conversation; |
| | | |
| | | import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; |
| | | import com.vincent.rsf.framework.common.Cools; |
| | | import com.vincent.rsf.server.ai.config.AiDefaults; |
| | | import com.vincent.rsf.server.ai.dto.AiChatMessageDto; |
| | | import com.vincent.rsf.server.ai.entity.AiChatMessage; |
| | | import com.vincent.rsf.server.ai.entity.AiChatSession; |
| | | import com.vincent.rsf.server.ai.mapper.AiChatMessageMapper; |
| | | import com.vincent.rsf.server.ai.mapper.AiChatSessionMapper; |
| | | import lombok.RequiredArgsConstructor; |
| | | import lombok.extern.slf4j.Slf4j; |
| | | import org.springframework.beans.factory.annotation.Qualifier; |
| | | import org.springframework.stereotype.Service; |
| | | import org.springframework.util.StringUtils; |
| | | |
| | | import java.util.Date; |
| | | import java.util.List; |
| | | import java.util.Set; |
| | | import java.util.concurrent.ConcurrentHashMap; |
| | | import java.util.concurrent.Executor; |
| | | |
| | | @Slf4j |
| | | @Service |
| | | @RequiredArgsConstructor |
| | | public class AiMemoryProfileService { |
| | | |
| | | private final AiConversationQueryService aiConversationQueryService; |
| | | private final AiChatSessionMapper aiChatSessionMapper; |
| | | private final AiChatMessageMapper aiChatMessageMapper; |
| | | @Qualifier("aiMemoryTaskExecutor") |
| | | private final Executor aiMemoryTaskExecutor; |
| | | |
| | | private final Set<Long> refreshingSessionIds = ConcurrentHashMap.newKeySet(); |
| | | private final Set<Long> pendingRefreshSessionIds = ConcurrentHashMap.newKeySet(); |
| | | |
| | | public void scheduleMemoryProfileRefresh(Long sessionId, Long userId, Long tenantId) { |
| | | if (sessionId == null) { |
| | | return; |
| | | } |
| | | if (!refreshingSessionIds.add(sessionId)) { |
| | | pendingRefreshSessionIds.add(sessionId); |
| | | return; |
| | | } |
| | | aiMemoryTaskExecutor.execute(() -> runMemoryProfileRefreshLoop(sessionId, userId, tenantId)); |
| | | } |
| | | |
| | | private void runMemoryProfileRefreshLoop(Long sessionId, Long userId, Long tenantId) { |
| | | try { |
| | | boolean shouldContinue; |
| | | do { |
| | | pendingRefreshSessionIds.remove(sessionId); |
| | | try { |
| | | refreshMemoryProfile(sessionId, userId); |
| | | aiConversationQueryService.evictConversationCaches(tenantId, userId); |
| | | } catch (Exception e) { |
| | | log.warn("AI memory profile refresh failed, sessionId={}, userId={}, tenantId={}, message={}", |
| | | sessionId, userId, tenantId, e.getMessage(), e); |
| | | } |
| | | shouldContinue = pendingRefreshSessionIds.remove(sessionId); |
| | | } while (shouldContinue); |
| | | } finally { |
| | | refreshingSessionIds.remove(sessionId); |
| | | if (pendingRefreshSessionIds.remove(sessionId) && refreshingSessionIds.add(sessionId)) { |
| | | aiMemoryTaskExecutor.execute(() -> runMemoryProfileRefreshLoop(sessionId, userId, tenantId)); |
| | | } |
| | | } |
| | | } |
| | | |
| | | private void refreshMemoryProfile(Long sessionId, Long userId) { |
| | | List<AiChatMessageDto> messages = aiConversationQueryService.listMessages(sessionId); |
| | | List<AiChatMessageDto> shortMemoryMessages = aiConversationQueryService.tailMessagesByRounds(messages, AiDefaults.MEMORY_RECENT_ROUNDS); |
| | | List<AiChatMessageDto> historyMessages = messages.size() > shortMemoryMessages.size() |
| | | ? messages.subList(0, messages.size() - shortMemoryMessages.size()) |
| | | : List.of(); |
| | | String memorySummary = historyMessages.size() >= AiDefaults.MEMORY_SUMMARY_TRIGGER_MESSAGES |
| | | ? buildMemorySummary(historyMessages) |
| | | : null; |
| | | String memoryFacts = buildMemoryFacts(messages); |
| | | AiChatMessage lastMessage = aiChatMessageMapper.selectOne(new LambdaQueryWrapper<AiChatMessage>() |
| | | .eq(AiChatMessage::getSessionId, sessionId) |
| | | .eq(AiChatMessage::getDeleted, 0) |
| | | .orderByDesc(AiChatMessage::getSeqNo) |
| | | .orderByDesc(AiChatMessage::getId) |
| | | .last("limit 1")); |
| | | aiChatSessionMapper.updateById(new AiChatSession() |
| | | .setId(sessionId) |
| | | .setMemorySummary(memorySummary) |
| | | .setMemoryFacts(memoryFacts) |
| | | .setLastMessageTime(lastMessage == null ? null : lastMessage.getCreateTime()) |
| | | .setUpdateBy(userId) |
| | | .setUpdateTime(new Date())); |
| | | } |
| | | |
| | | private String buildMemorySummary(List<AiChatMessageDto> historyMessages) { |
| | | StringBuilder builder = new StringBuilder("较早对话摘要:\n"); |
| | | for (AiChatMessageDto item : historyMessages) { |
| | | if (item == null || !StringUtils.hasText(item.getContent())) { |
| | | continue; |
| | | } |
| | | String prefix = "assistant".equalsIgnoreCase(item.getRole()) ? "- AI: " : "- 用户: "; |
| | | String content = aiConversationQueryService.compactText(item.getContent(), 120); |
| | | if (!StringUtils.hasText(content)) { |
| | | continue; |
| | | } |
| | | builder.append(prefix).append(content).append("\n"); |
| | | if (builder.length() >= AiDefaults.MEMORY_SUMMARY_MAX_LENGTH) { |
| | | break; |
| | | } |
| | | } |
| | | return aiConversationQueryService.compactText(builder.toString(), AiDefaults.MEMORY_SUMMARY_MAX_LENGTH); |
| | | } |
| | | |
| | | private String buildMemoryFacts(List<AiChatMessageDto> messages) { |
| | | if (Cools.isEmpty(messages)) { |
| | | return null; |
| | | } |
| | | StringBuilder builder = new StringBuilder("关键事实:\n"); |
| | | int userFacts = 0; |
| | | for (int i = messages.size() - 1; i >= 0 && userFacts < 4; i--) { |
| | | AiChatMessageDto item = messages.get(i); |
| | | if (item == null || !"user".equalsIgnoreCase(item.getRole()) || !StringUtils.hasText(item.getContent())) { |
| | | continue; |
| | | } |
| | | builder.append("- 用户关注: ").append(aiConversationQueryService.compactText(item.getContent(), 100)).append("\n"); |
| | | userFacts++; |
| | | } |
| | | return userFacts == 0 ? null : aiConversationQueryService.compactText(builder.toString(), AiDefaults.MEMORY_FACTS_MAX_LENGTH); |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.service.impl.mcp; |
| | | |
| | | import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; |
| | | import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper; |
| | | import com.fasterxml.jackson.databind.ObjectMapper; |
| | | import com.vincent.rsf.framework.exception.CoolException; |
| | | import com.vincent.rsf.server.ai.config.AiDefaults; |
| | | import com.vincent.rsf.server.ai.dto.AiMcpConnectivityTestDto; |
| | | import com.vincent.rsf.server.ai.dto.AiMcpToolPreviewDto; |
| | | import com.vincent.rsf.server.ai.dto.AiMcpToolTestDto; |
| | | import com.vincent.rsf.server.ai.dto.AiMcpToolTestRequest; |
| | | import com.vincent.rsf.server.ai.entity.AiMcpMount; |
| | | import com.vincent.rsf.server.ai.mapper.AiMcpMountMapper; |
| | | import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry; |
| | | import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory; |
| | | import lombok.RequiredArgsConstructor; |
| | | import org.springframework.ai.chat.model.ToolContext; |
| | | import org.springframework.ai.tool.ToolCallback; |
| | | import org.springframework.stereotype.Service; |
| | | import org.springframework.util.StringUtils; |
| | | |
| | | import java.text.SimpleDateFormat; |
| | | import java.util.ArrayList; |
| | | import java.util.Arrays; |
| | | import java.util.Date; |
| | | import java.util.LinkedHashMap; |
| | | import java.util.List; |
| | | import java.util.Map; |
| | | |
| | | /** |
| | | * Administrative operations on MCP mounts: tool preview, connectivity test and single-tool test. |
| | | * Each operation builds a short-lived runtime for the given mount and, when the mount has an id, |
| | | * records the outcome in the mount's health columns. |
| | | */ |
| | | @Service |
| | | @RequiredArgsConstructor |
| | | public class AiMcpAdminService { |
| | | |
| | | private final AiMcpMountMapper aiMcpMountMapper; |
| | | private final BuiltinMcpToolRegistry builtinMcpToolRegistry; |
| | | private final McpMountRuntimeFactory mcpMountRuntimeFactory; |
| | | private final ObjectMapper objectMapper; |
| | | |
| | | /** |
| | | * Resolves the tools exposed by a mount and returns them as preview DTOs. |
| | | * For builtin-transport mounts the previews are enriched from the builtin tool catalog. |
| | | * |
| | | * @param mount mount to inspect; its id may be {@code null}, in which case no health status is stored |
| | | * @param userId user on whose behalf the runtime is created |
| | | * @return the resolved tools |
| | | * @throws CoolException when the runtime reports errors or tool resolution fails |
| | | */ |
| | | public List<AiMcpToolPreviewDto> previewTools(AiMcpMount mount, Long userId) { |
| | | long startedAt = System.currentTimeMillis(); |
| | | try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(List.of(mount), userId)) { |
| | | List<AiMcpToolPreviewDto> tools = buildToolPreviewDtos(runtime.getToolCallbacks(), |
| | | AiDefaults.MCP_TRANSPORT_BUILTIN.equals(mount.getTransportType()) |
| | | ? builtinMcpToolRegistry.listBuiltinToolCatalog(mount.getBuiltinCode()) |
| | | : List.of()); |
| | | if (!runtime.getErrors().isEmpty()) { |
| | | String message = String.join(";", runtime.getErrors()); |
| | | if (mount.getId() != null) { |
| | | updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY, message, System.currentTimeMillis() - startedAt); |
| | | } |
| | | throw new CoolException(message); |
| | | } |
| | | if (mount.getId() != null) { |
| | | updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_HEALTHY, |
| | | "工具解析成功,共 " + tools.size() + " 个工具", System.currentTimeMillis() - startedAt); |
| | | } |
| | | return tools; |
| | | } catch (CoolException e) { |
| | | // Already a business error (health was stored above where applicable); pass it through. |
| | | throw e; |
| | | } catch (Exception e) { |
| | | if (mount.getId() != null) { |
| | | updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY, |
| | | "工具解析失败: " + e.getMessage(), System.currentTimeMillis() - startedAt); |
| | | } |
| | | throw new CoolException("获取工具列表失败: " + e.getMessage()); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * Creates a runtime for the mount and reports whether initialization succeeded. |
| | | * |
| | | * @param mount mount to test |
| | | * @param userId user on whose behalf the runtime is created |
| | | * @param persistHealth when {@code true} and the mount has an id, the outcome is stored on the |
| | | * mount and the result is built from the reloaded row; otherwise the test |
| | | * is treated as a draft test and nothing is stored |
| | | * @return the test result |
| | | * @throws CoolException when a draft test fails with an exception, or the mount cannot be reloaded |
| | | */ |
| | | public AiMcpConnectivityTestDto testConnectivity(AiMcpMount mount, Long userId, boolean persistHealth) { |
| | | long startedAt = System.currentTimeMillis(); |
| | | try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(List.of(mount), userId)) { |
| | | long elapsedMs = System.currentTimeMillis() - startedAt; |
| | | if (!runtime.getErrors().isEmpty()) { |
| | | String message = String.join(";", runtime.getErrors()); |
| | | if (persistHealth && mount.getId() != null) { |
| | | updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY, message, elapsedMs); |
| | | AiMcpMount latest = requireMount(mount.getId(), mount.getTenantId()); |
| | | return buildConnectivityDto(latest, message, elapsedMs, runtime.getToolCallbacks().length); |
| | | } |
| | | return AiMcpConnectivityTestDto.builder() |
| | | .mountId(mount.getId()) |
| | | .mountName(mount.getName()) |
| | | .healthStatus(AiDefaults.MCP_HEALTH_UNHEALTHY) |
| | | .message(message) |
| | | .initElapsedMs(elapsedMs) |
| | | .toolCount(runtime.getToolCallbacks().length) |
| | | .testedAt(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date())) |
| | | .build(); |
| | | } |
| | | String message = persistHealth && mount.getId() != null |
| | | ? "连通性测试成功,解析出 " + runtime.getToolCallbacks().length + " 个工具" |
| | | : "草稿连通性测试成功,解析出 " + runtime.getToolCallbacks().length + " 个工具"; |
| | | if (persistHealth && mount.getId() != null) { |
| | | updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_HEALTHY, message, elapsedMs); |
| | | AiMcpMount latest = requireMount(mount.getId(), mount.getTenantId()); |
| | | return buildConnectivityDto(latest, message, elapsedMs, runtime.getToolCallbacks().length); |
| | | } |
| | | return AiMcpConnectivityTestDto.builder() |
| | | .mountId(mount.getId()) |
| | | .mountName(mount.getName()) |
| | | .healthStatus(AiDefaults.MCP_HEALTH_HEALTHY) |
| | | .message(message) |
| | | .initElapsedMs(elapsedMs) |
| | | .toolCount(runtime.getToolCallbacks().length) |
| | | .testedAt(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date())) |
| | | .build(); |
| | | } catch (CoolException e) { |
| | | throw e; |
| | | } catch (Exception e) { |
| | | long elapsedMs = System.currentTimeMillis() - startedAt; |
| | | String message = (persistHealth ? "连通性测试失败: " : "草稿连通性测试失败: ") + e.getMessage(); |
| | | if (persistHealth && mount.getId() != null) { |
| | | // For a persisted mount the failure is returned as a result instead of being thrown. |
| | | updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY, message, elapsedMs); |
| | | AiMcpMount latest = requireMount(mount.getId(), mount.getTenantId()); |
| | | return buildConnectivityDto(latest, message, elapsedMs, 0); |
| | | } |
| | | throw new CoolException(message); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * Invokes one tool of the mount with the given JSON input and returns its output. |
| | | * <p> |
| | | * The tool context always carries {@code userId} and {@code tenantId}; {@code mountId} is added |
| | | * only when the mount has an id, because an immutable {@code Map.of} rejects null values and |
| | | * would otherwise fail every test of a mount without an id. |
| | | * |
| | | * @param mount mount providing the tool; its id may be {@code null} |
| | | * @param userId current user, required |
| | | * @param tenantId current tenant, required |
| | | * @param request tool name and JSON input, both required; the input must be parseable JSON |
| | | * @return tool name, input and output |
| | | * @throws CoolException on invalid arguments, unknown tool or any failure while calling the tool |
| | | */ |
| | | public AiMcpToolTestDto testTool(AiMcpMount mount, Long userId, Long tenantId, AiMcpToolTestRequest request) { |
| | | if (userId == null) { |
| | | throw new CoolException("当前登录用户不存在"); |
| | | } |
| | | if (tenantId == null) { |
| | | throw new CoolException("当前租户不存在"); |
| | | } |
| | | if (request == null) { |
| | | throw new CoolException("工具测试参数不能为空"); |
| | | } |
| | | if (!StringUtils.hasText(request.getToolName())) { |
| | | throw new CoolException("工具名称不能为空"); |
| | | } |
| | | if (!StringUtils.hasText(request.getInputJson())) { |
| | | throw new CoolException("工具输入 JSON 不能为空"); |
| | | } |
| | | try { |
| | | // Parsed only to validate the syntax; the raw string is what gets passed to the tool. |
| | | objectMapper.readTree(request.getInputJson()); |
| | | } catch (Exception e) { |
| | | throw new CoolException("工具输入 JSON 格式错误: " + e.getMessage()); |
| | | } |
| | | long startedAt = System.currentTimeMillis(); |
| | | try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(List.of(mount), userId)) { |
| | | ToolCallback callback = Arrays.stream(runtime.getToolCallbacks()) |
| | | .filter(item -> item != null && item.getToolDefinition() != null) |
| | | .filter(item -> request.getToolName().equals(item.getToolDefinition().name())) |
| | | .findFirst() |
| | | .orElseThrow(() -> new CoolException("未找到要测试的工具: " + request.getToolName())); |
| | | Map<String, Object> contextValues = new LinkedHashMap<>(); |
| | | contextValues.put("userId", userId); |
| | | contextValues.put("tenantId", tenantId); |
| | | if (mount.getId() != null) { |
| | | contextValues.put("mountId", mount.getId()); |
| | | } |
| | | String output = callback.call(request.getInputJson(), new ToolContext(contextValues)); |
| | | if (mount.getId() != null) { |
| | | updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_HEALTHY, |
| | | "工具测试成功: " + request.getToolName(), System.currentTimeMillis() - startedAt); |
| | | } |
| | | return AiMcpToolTestDto.builder() |
| | | .toolName(request.getToolName()) |
| | | .inputJson(request.getInputJson()) |
| | | .output(output) |
| | | .build(); |
| | | } catch (CoolException e) { |
| | | if (mount.getId() != null) { |
| | | updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY, |
| | | "工具测试失败: " + e.getMessage(), System.currentTimeMillis() - startedAt); |
| | | } |
| | | throw e; |
| | | } catch (Exception e) { |
| | | if (mount.getId() != null) { |
| | | updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY, |
| | | "工具测试失败: " + e.getMessage(), System.currentTimeMillis() - startedAt); |
| | | } |
| | | throw new CoolException("工具测试失败: " + e.getMessage()); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * Loads a non-deleted mount by id within the given tenant. |
| | | * |
| | | * @param mountId mount id, required |
| | | * @param tenantId tenant id, required |
| | | * @return the mount |
| | | * @throws CoolException when an argument is {@code null} or no matching mount exists |
| | | */ |
| | | public AiMcpMount requireMount(Long mountId, Long tenantId) { |
| | | if (tenantId == null) { |
| | | throw new CoolException("当前租户不存在"); |
| | | } |
| | | if (mountId == null) { |
| | | throw new CoolException("MCP 挂载 ID 不能为空"); |
| | | } |
| | | AiMcpMount mount = aiMcpMountMapper.selectOne(new LambdaQueryWrapper<AiMcpMount>() |
| | | .eq(AiMcpMount::getId, mountId) |
| | | .eq(AiMcpMount::getTenantId, tenantId) |
| | | .eq(AiMcpMount::getDeleted, 0) |
| | | .last("limit 1")); |
| | | if (mount == null) { |
| | | throw new CoolException("MCP 挂载不存在"); |
| | | } |
| | | return mount; |
| | | } |
| | | |
| | | /** |
| | | * Converts tool callbacks to preview DTOs, copying group, purpose, query boundary and example |
| | | * questions from the catalog entry with the same tool name when one exists. |
| | | * Callbacks that are null or have no tool definition are skipped. |
| | | */ |
| | | private List<AiMcpToolPreviewDto> buildToolPreviewDtos(ToolCallback[] callbacks, List<AiMcpToolPreviewDto> governedCatalog) { |
| | | List<AiMcpToolPreviewDto> tools = new ArrayList<>(); |
| | | Map<String, AiMcpToolPreviewDto> catalogMap = new LinkedHashMap<>(); |
| | | for (AiMcpToolPreviewDto item : governedCatalog) { |
| | | if (item == null || !StringUtils.hasText(item.getName())) { |
| | | continue; |
| | | } |
| | | catalogMap.put(item.getName(), item); |
| | | } |
| | | for (ToolCallback callback : callbacks) { |
| | | if (callback == null || callback.getToolDefinition() == null) { |
| | | continue; |
| | | } |
| | | AiMcpToolPreviewDto governedItem = catalogMap.get(callback.getToolDefinition().name()); |
| | | tools.add(AiMcpToolPreviewDto.builder() |
| | | .name(callback.getToolDefinition().name()) |
| | | .description(callback.getToolDefinition().description()) |
| | | .inputSchema(callback.getToolDefinition().inputSchema()) |
| | | .returnDirect(callback.getToolMetadata() == null ? null : callback.getToolMetadata().returnDirect()) |
| | | .toolGroup(governedItem == null ? null : governedItem.getToolGroup()) |
| | | .toolPurpose(governedItem == null ? null : governedItem.getToolPurpose()) |
| | | .queryBoundary(governedItem == null ? null : governedItem.getQueryBoundary()) |
| | | .exampleQuestions(governedItem == null ? List.of() : governedItem.getExampleQuestions()) |
| | | .build()); |
| | | } |
| | | return tools; |
| | | } |
| | | |
| | | /** |
| | | * Stores health status, test message, elapsed time and the current time as last test time |
| | | * on the mount with the given id. |
| | | */ |
| | | private void updateHealthStatus(Long mountId, String healthStatus, String message, Long initElapsedMs) { |
| | | aiMcpMountMapper.update(null, new LambdaUpdateWrapper<AiMcpMount>() |
| | | .eq(AiMcpMount::getId, mountId) |
| | | .set(AiMcpMount::getHealthStatus, healthStatus) |
| | | .set(AiMcpMount::getLastTestTime, new Date()) |
| | | .set(AiMcpMount::getLastTestMessage, message) |
| | | .set(AiMcpMount::getLastInitElapsedMs, initElapsedMs)); |
| | | } |
| | | |
| | | /** |
| | | * Builds the connectivity result from a mount row, taking health status and test time from |
| | | * the row itself. |
| | | */ |
| | | private AiMcpConnectivityTestDto buildConnectivityDto(AiMcpMount mount, String message, Long initElapsedMs, Integer toolCount) { |
| | | return AiMcpConnectivityTestDto.builder() |
| | | .mountId(mount.getId()) |
| | | .mountName(mount.getName()) |
| | | .healthStatus(mount.getHealthStatus()) |
| | | .message(message) |
| | | .initElapsedMs(initElapsedMs) |
| | | .toolCount(toolCount) |
| | | .testedAt(mount.getLastTestTime$()) |
| | | .build(); |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.service.impl.mcp; |
| | | |
| | | import com.vincent.rsf.framework.exception.CoolException; |
| | | import com.vincent.rsf.server.ai.config.AiDefaults; |
| | | import com.vincent.rsf.server.ai.dto.AiMcpToolPreviewDto; |
| | | import org.springframework.stereotype.Component; |
| | | |
| | | import java.util.LinkedHashMap; |
| | | import java.util.List; |
| | | import java.util.Map; |
| | | |
| | | /** |
| | | * Supplies the governed tool catalog (group, purpose, query boundary, example questions) for |
| | | * builtin MCP servers. Only {@code AiDefaults.MCP_BUILTIN_RSF_WMS} is supported. |
| | | */ |
| | | @Component |
| | | public class BuiltinMcpToolCatalogProvider { |
| | | |
| | | /** |
| | | * @return the builtin codes this provider has a catalog for |
| | | */ |
| | | public List<String> supportedBuiltinCodes() { |
| | | return List.of(AiDefaults.MCP_BUILTIN_RSF_WMS); |
| | | } |
| | | |
| | | /** |
| | | * Builds a fresh, insertion-ordered catalog keyed by tool name. |
| | | * |
| | | * @param builtinCode builtin MCP code |
| | | * @return a new mutable map of tool name to catalog entry |
| | | * @throws CoolException when {@code builtinCode} is not supported |
| | | */ |
| | | public Map<String, AiMcpToolPreviewDto> getCatalog(String builtinCode) { |
| | | if (!AiDefaults.MCP_BUILTIN_RSF_WMS.equals(builtinCode)) { |
| | | throw new CoolException("不支持的内置 MCP 编码: " + builtinCode); |
| | | } |
| | | Map<String, AiMcpToolPreviewDto> entries = new LinkedHashMap<>(); |
| | | register(entries, |
| | | "rsf_query_available_inventory", |
| | | "库存查询", |
| | | "查询指定物料当前可用于出库的库存明细。", |
| | | "必须提供物料编码或物料名称,并且最多返回 50 条库存记录。", |
| | | List.of("查询物料 MAT001 当前可出库库存", "按物料名称查询托盘库存明细")); |
| | | register(entries, |
| | | "rsf_query_station_list", |
| | | "库存查询", |
| | | "查询指定作业类型可用的设备站点。", |
| | | "必须提供站点类型列表,类型数量最多 10 个,最多返回 50 个站点。", |
| | | List.of("查询入库和出库作业可用站点", "列出 AGV_PICK 类型的作业站点")); |
| | | register(entries, |
| | | "rsf_query_task", |
| | | "任务查询", |
| | | "按任务号、状态、类型或站点条件查询任务;支持从自然语言中自动提取任务号,精确命中时返回任务明细。", |
| | | "过滤条件均为可选,不传过滤条件时默认返回最近任务,最多返回 50 条记录。", |
| | | List.of("查询最近 10 条任务", "查询任务号 TASK24001 的详情")); |
| | | register(entries, |
| | | "rsf_query_warehouses", |
| | | "基础资料", |
| | | "查询仓库基础信息。", |
| | | "至少提供仓库编码或名称,最多返回 50 条仓库记录。", |
| | | List.of("查询编码包含 WH 的仓库", "按仓库名称查询仓库地址")); |
| | | register(entries, |
| | | "rsf_query_bas_stations", |
| | | "基础资料", |
| | | "查询基础站点信息。", |
| | | "至少提供站点编号、站点名称或使用状态之一,最多返回 50 条站点记录。", |
| | | List.of("查询使用中的基础站点", "按站点编号查询基础站点")); |
| | | register(entries, |
| | | "rsf_query_dict_data", |
| | | "基础资料", |
| | | "查询指定字典类型下的字典数据。", |
| | | "必须提供字典类型编码,最多返回 100 条字典记录。", |
| | | List.of("查询 task_status 字典", "按字典标签过滤 task_type 字典数据")); |
| | | return entries; |
| | | } |
| | | |
| | | /** Adds one catalog entry to {@code target}, keyed by the tool name. */ |
| | | private void register(Map<String, AiMcpToolPreviewDto> target, String name, String toolGroup, |
| | | String toolPurpose, String queryBoundary, List<String> exampleQuestions) { |
| | | target.put(name, buildCatalogItem(name, toolGroup, toolPurpose, queryBoundary, exampleQuestions)); |
| | | } |
| | | |
| | | /** Creates a catalog entry carrying only the governance fields and the tool name. */ |
| | | private AiMcpToolPreviewDto buildCatalogItem(String name, String toolGroup, String toolPurpose, |
| | | String queryBoundary, List<String> exampleQuestions) { |
| | | AiMcpToolPreviewDto item = AiMcpToolPreviewDto.builder() |
| | | .name(name) |
| | | .toolGroup(toolGroup) |
| | | .toolPurpose(toolPurpose) |
| | | .queryBoundary(queryBoundary) |
| | | .exampleQuestions(exampleQuestions) |
| | | .build(); |
| | | return item; |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.service.impl.mcp; |
| | | |
| | | import com.fasterxml.jackson.core.type.TypeReference; |
| | | import com.fasterxml.jackson.databind.ObjectMapper; |
| | | import com.vincent.rsf.framework.exception.CoolException; |
| | | import com.vincent.rsf.server.ai.config.AiDefaults; |
| | | import com.vincent.rsf.server.ai.entity.AiMcpMount; |
| | | import io.modelcontextprotocol.client.McpClient; |
| | | import io.modelcontextprotocol.client.McpSyncClient; |
| | | import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; |
| | | import io.modelcontextprotocol.client.transport.ServerParameters; |
| | | import io.modelcontextprotocol.client.transport.StdioClientTransport; |
| | | import io.modelcontextprotocol.json.jackson.JacksonMcpJsonMapper; |
| | | import io.modelcontextprotocol.spec.McpSchema; |
| | | import lombok.RequiredArgsConstructor; |
| | | import lombok.extern.slf4j.Slf4j; |
| | | import org.springframework.stereotype.Component; |
| | | import org.springframework.util.StringUtils; |
| | | |
| | | import java.time.Duration; |
| | | import java.util.Collections; |
| | | import java.util.LinkedHashMap; |
| | | import java.util.List; |
| | | import java.util.Map; |
| | | |
| | | @Slf4j |
| | | @Component |
| | | @RequiredArgsConstructor |
| | | public class McpClientFactory { |
| | | |
| | | private final ObjectMapper objectMapper; |
| | | |
| | | public McpSyncClient createClient(AiMcpMount mount) { |
| | | Duration timeout = Duration.ofMillis(mount.getRequestTimeoutMs() == null |
| | | ? AiDefaults.DEFAULT_TIMEOUT_MS |
| | | : mount.getRequestTimeoutMs()); |
| | | JacksonMcpJsonMapper jsonMapper = new JacksonMcpJsonMapper(objectMapper); |
| | | if (AiDefaults.MCP_TRANSPORT_STDIO.equals(mount.getTransportType())) { |
| | | ServerParameters.Builder parametersBuilder = ServerParameters.builder(mount.getCommand()); |
| | | List<String> args = readStringList(mount.getArgsJson()); |
| | | if (!args.isEmpty()) { |
| | | parametersBuilder.args(args); |
| | | } |
| | | Map<String, String> env = readStringMap(mount.getEnvJson()); |
| | | if (!env.isEmpty()) { |
| | | parametersBuilder.env(env); |
| | | } |
| | | StdioClientTransport transport = new StdioClientTransport(parametersBuilder.build(), jsonMapper); |
| | | transport.setStdErrorHandler(message -> log.warn("MCP STDIO stderr [{}]: {}", mount.getName(), message)); |
| | | return McpClient.sync(transport) |
| | | .requestTimeout(timeout) |
| | | .initializationTimeout(timeout) |
| | | .clientInfo(new McpSchema.Implementation("rsf-ai-client", "RSF AI Client", "1.0.0")) |
| | | .build(); |
| | | } |
| | | if (!AiDefaults.MCP_TRANSPORT_SSE_HTTP.equals(mount.getTransportType())) { |
| | | throw new CoolException("不支持的 MCP 传输类型: " + mount.getTransportType()); |
| | | } |
| | | |
| | | if (!StringUtils.hasText(mount.getServerUrl())) { |
| | | throw new CoolException("MCP 服务地址不能为空"); |
| | | } |
| | | HttpClientSseClientTransport.Builder transportBuilder = HttpClientSseClientTransport.builder(mount.getServerUrl()) |
| | | .jsonMapper(jsonMapper) |
| | | .connectTimeout(timeout); |
| | | if (StringUtils.hasText(mount.getEndpoint())) { |
| | | transportBuilder.sseEndpoint(mount.getEndpoint()); |
| | | } |
| | | Map<String, String> headers = readStringMap(mount.getHeadersJson()); |
| | | if (!headers.isEmpty()) { |
| | | transportBuilder.customizeRequest(builder -> headers.forEach(builder::header)); |
| | | } |
| | | return McpClient.sync(transportBuilder.build()) |
| | | .requestTimeout(timeout) |
| | | .initializationTimeout(timeout) |
| | | .clientInfo(new McpSchema.Implementation("rsf-ai-client", "RSF AI Client", "1.0.0")) |
| | | .build(); |
| | | } |
| | | |
| | | private List<String> readStringList(String json) { |
| | | if (!StringUtils.hasText(json)) { |
| | | return Collections.emptyList(); |
| | | } |
| | | try { |
| | | return objectMapper.readValue(json, new TypeReference<List<String>>() { |
| | | }); |
| | | } catch (Exception e) { |
| | | throw new CoolException("解析 MCP 列表配置失败: " + e.getMessage()); |
| | | } |
| | | } |
| | | |
| | | private Map<String, String> readStringMap(String json) { |
| | | if (!StringUtils.hasText(json)) { |
| | | return Collections.emptyMap(); |
| | | } |
| | | try { |
| | | Map<String, String> result = objectMapper.readValue(json, new TypeReference<LinkedHashMap<String, String>>() { |
| | | }); |
| | | return result == null ? Collections.emptyMap() : result; |
| | | } catch (Exception e) { |
| | | throw new CoolException("解析 MCP Map 配置失败: " + e.getMessage()); |
| | | } |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.store; |
| | | |
| | | import lombok.AllArgsConstructor; |
| | | import lombok.Builder; |
| | | import lombok.Data; |
| | | import lombok.NoArgsConstructor; |
| | | |
| | | /** |
| | | * Value object holding the outcome of a tool call in a form that can be cached: |
| | | * a success flag plus either the output or an error message. |
| | | */ |
| | | @Data |
| | | @Builder |
| | | @AllArgsConstructor |
| | | @NoArgsConstructor |
| | | public class AiCachedToolResult { |
| | | |
| | | /** Whether the tool call succeeded. */ |
| | | private boolean success; |
| | | |
| | | /** Output text of the tool call. */ |
| | | private String output; |
| | | |
| | | /** Error description of the tool call. */ |
| | | private String errorMessage; |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.store; |
| | | |
| | | import com.vincent.rsf.server.ai.config.AiDefaults; |
| | | import com.vincent.rsf.server.ai.store.support.AiRedisExecutor; |
| | | import com.vincent.rsf.server.ai.store.support.AiRedisKeys; |
| | | import lombok.RequiredArgsConstructor; |
| | | import org.springframework.stereotype.Component; |
| | | |
| | | import java.time.Instant; |
| | | import java.util.UUID; |
| | | |
| | | @Component |
| | | @RequiredArgsConstructor |
| | | public class AiChatRateLimiter { |
| | | |
| | | private final AiRedisExecutor aiRedisExecutor; |
| | | private final AiRedisKeys aiRedisKeys; |
| | | |
| | | public boolean allowChatRequest(Long tenantId, Long userId, String promptCode) { |
| | | String key = aiRedisKeys.buildRateLimitKey(tenantId, userId, promptCode); |
| | | long now = Instant.now().toEpochMilli(); |
| | | long windowStart = now - (AiDefaults.CHAT_RATE_LIMIT_WINDOW_SECONDS * 1000L); |
| | | Boolean allowed = aiRedisExecutor.execute(jedis -> { |
| | | jedis.zremrangeByScore(key, 0, windowStart); |
| | | long count = jedis.zcard(key); |
| | | if (count >= AiDefaults.CHAT_RATE_LIMIT_MAX_REQUESTS) { |
| | | jedis.expire(key, AiDefaults.CHAT_RATE_LIMIT_WINDOW_SECONDS); |
| | | return Boolean.FALSE; |
| | | } |
| | | jedis.zadd(key, now, now + ":" + UUID.randomUUID()); |
| | | jedis.expire(key, AiDefaults.CHAT_RATE_LIMIT_WINDOW_SECONDS); |
| | | return Boolean.TRUE; |
| | | }); |
| | | return Boolean.TRUE.equals(allowed); |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.store; |
| | | |
| | | import com.vincent.rsf.server.ai.config.AiDefaults; |
| | | import com.vincent.rsf.server.ai.dto.AiResolvedConfig; |
| | | import com.vincent.rsf.server.ai.store.support.AiRedisExecutor; |
| | | import com.vincent.rsf.server.ai.store.support.AiRedisIndexSupport; |
| | | import com.vincent.rsf.server.ai.store.support.AiRedisKeys; |
| | | import lombok.RequiredArgsConstructor; |
| | | import org.springframework.stereotype.Component; |
| | | |
| | | /** |
| | | * Redis-backed cache for resolved AI configurations. Every cached key is also remembered in a |
| | | * per-tenant index so that all of a tenant's entries can be evicted together. |
| | | */ |
| | | @Component |
| | | @RequiredArgsConstructor |
| | | public class AiConfigCacheStore { |
| | | |
| | | private final AiRedisExecutor aiRedisExecutor; |
| | | private final AiRedisIndexSupport aiRedisIndexSupport; |
| | | private final AiRedisKeys aiRedisKeys; |
| | | |
| | | /** Reads the cached configuration stored under the key for tenant, prompt code and AI param id. */ |
| | | public AiResolvedConfig getResolvedConfig(Long tenantId, String promptCode, Long aiParamId) { |
| | | String cacheKey = aiRedisKeys.buildConfigKey(tenantId, promptCode, aiParamId); |
| | | return aiRedisExecutor.readJson(cacheKey, AiResolvedConfig.class); |
| | | } |
| | | |
| | | /** |
| | | * Writes the configuration with {@code AiDefaults.CONFIG_CACHE_TTL_SECONDS} and remembers its |
| | | * key in the tenant's config index. |
| | | */ |
| | | public void cacheResolvedConfig(Long tenantId, String promptCode, Long aiParamId, AiResolvedConfig config) { |
| | | String cacheKey = aiRedisKeys.buildConfigKey(tenantId, promptCode, aiParamId); |
| | | aiRedisExecutor.writeJson(cacheKey, config, AiDefaults.CONFIG_CACHE_TTL_SECONDS); |
| | | String indexKey = aiRedisKeys.buildConfigIndexKey(tenantId); |
| | | aiRedisIndexSupport.remember(indexKey, cacheKey); |
| | | } |
| | | |
| | | /** Deletes every key tracked in the tenant's config index. */ |
| | | public void evictTenantConfigCaches(Long tenantId) { |
| | | String indexKey = aiRedisKeys.buildConfigIndexKey(tenantId); |
| | | aiRedisIndexSupport.deleteTrackedKeys(indexKey); |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.store; |
| | | |
| | | import com.fasterxml.jackson.core.type.TypeReference; |
| | | import com.vincent.rsf.server.ai.config.AiDefaults; |
| | | import com.vincent.rsf.server.ai.dto.AiChatMemoryDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatRuntimeDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatSessionDto; |
| | | import com.vincent.rsf.server.ai.store.support.AiRedisExecutor; |
| | | import com.vincent.rsf.server.ai.store.support.AiRedisIndexSupport; |
| | | import com.vincent.rsf.server.ai.store.support.AiRedisKeys; |
| | | import lombok.RequiredArgsConstructor; |
| | | import org.springframework.stereotype.Component; |
| | | |
| | | import java.util.List; |
| | | |
| | | @Component |
| | | @RequiredArgsConstructor |
| | | public class AiConversationCacheStore { |
| | | |
| | | private final AiRedisExecutor aiRedisExecutor; |
| | | private final AiRedisIndexSupport aiRedisIndexSupport; |
| | | private final AiRedisKeys aiRedisKeys; |
| | | |
| | | public AiChatRuntimeDto getRuntime(Long tenantId, Long userId, String promptCode, Long sessionId, Long aiParamId) { |
| | | return aiRedisExecutor.readJson(aiRedisKeys.buildRuntimeKey(tenantId, userId, promptCode, sessionId, aiParamId), AiChatRuntimeDto.class); |
| | | } |
| | | |
| | | public void cacheRuntime(Long tenantId, Long userId, String promptCode, Long sessionId, Long aiParamId, AiChatRuntimeDto runtime) { |
| | | String key = aiRedisKeys.buildRuntimeKey(tenantId, userId, promptCode, sessionId, aiParamId); |
| | | aiRedisExecutor.writeJson(key, runtime, AiDefaults.RUNTIME_CACHE_TTL_SECONDS); |
| | | rememberConversationKey(tenantId, userId, key); |
| | | aiRedisIndexSupport.remember(aiRedisKeys.buildTenantRuntimeIndexKey(tenantId), key); |
| | | } |
| | | |
| | | public AiChatMemoryDto getMemory(Long tenantId, Long userId, String promptCode, Long sessionId) { |
| | | return aiRedisExecutor.readJson(aiRedisKeys.buildMemoryKey(tenantId, userId, promptCode, sessionId), AiChatMemoryDto.class); |
| | | } |
| | | |
| | | public void cacheMemory(Long tenantId, Long userId, String promptCode, Long sessionId, AiChatMemoryDto memory) { |
| | | String key = aiRedisKeys.buildMemoryKey(tenantId, userId, promptCode, sessionId); |
| | | aiRedisExecutor.writeJson(key, memory, AiDefaults.MEMORY_CACHE_TTL_SECONDS); |
| | | rememberConversationKey(tenantId, userId, key); |
| | | } |
| | | |
| | | public List<AiChatSessionDto> getSessionList(Long tenantId, Long userId, String promptCode, String keyword) { |
| | | return aiRedisExecutor.readJson(aiRedisKeys.buildSessionsKey(tenantId, userId, promptCode, keyword), new TypeReference<List<AiChatSessionDto>>() { |
| | | }); |
| | | } |
| | | |
| | | public void cacheSessionList(Long tenantId, Long userId, String promptCode, String keyword, List<AiChatSessionDto> sessions) { |
| | | String key = aiRedisKeys.buildSessionsKey(tenantId, userId, promptCode, keyword); |
| | | aiRedisExecutor.writeJson(key, sessions, AiDefaults.SESSION_LIST_CACHE_TTL_SECONDS); |
| | | rememberConversationKey(tenantId, userId, key); |
| | | } |
| | | |
| | | public void evictUserConversationCaches(Long tenantId, Long userId) { |
| | | aiRedisIndexSupport.deleteTrackedKeys(aiRedisKeys.buildConversationIndexKey(tenantId, userId)); |
| | | } |
| | | |
| | | public void evictTenantRuntimeCaches(Long tenantId) { |
| | | aiRedisIndexSupport.deleteTrackedKeys(aiRedisKeys.buildTenantRuntimeIndexKey(tenantId)); |
| | | } |
| | | |
| | | private void rememberConversationKey(Long tenantId, Long userId, String key) { |
| | | aiRedisIndexSupport.remember(aiRedisKeys.buildConversationIndexKey(tenantId, userId), key); |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.store; |
| | | |
| | | import com.fasterxml.jackson.core.type.TypeReference; |
| | | import com.vincent.rsf.server.ai.config.AiDefaults; |
| | | import com.vincent.rsf.server.ai.dto.AiMcpConnectivityTestDto; |
| | | import com.vincent.rsf.server.ai.dto.AiMcpToolPreviewDto; |
| | | import com.vincent.rsf.server.ai.store.support.AiRedisExecutor; |
| | | import com.vincent.rsf.server.ai.store.support.AiRedisKeys; |
| | | import lombok.RequiredArgsConstructor; |
| | | import org.springframework.stereotype.Component; |
| | | |
| | | import java.util.List; |
| | | |
| | | @Component |
| | | @RequiredArgsConstructor |
| | | public class AiMcpCacheStore { |
| | | |
| | | private final AiRedisExecutor aiRedisExecutor; |
| | | private final AiRedisKeys aiRedisKeys; |
| | | |
| | | public List<AiMcpToolPreviewDto> getToolPreview(Long tenantId, Long mountId) { |
| | | return aiRedisExecutor.readJson(aiRedisKeys.buildMcpPreviewKey(tenantId, mountId), new TypeReference<List<AiMcpToolPreviewDto>>() { |
| | | }); |
| | | } |
| | | |
| | | public void cacheToolPreview(Long tenantId, Long mountId, List<AiMcpToolPreviewDto> tools) { |
| | | aiRedisExecutor.writeJson(aiRedisKeys.buildMcpPreviewKey(tenantId, mountId), tools, AiDefaults.MCP_PREVIEW_CACHE_TTL_SECONDS); |
| | | } |
| | | |
| | | public AiMcpConnectivityTestDto getConnectivity(Long tenantId, Long mountId) { |
| | | return aiRedisExecutor.readJson(aiRedisKeys.buildMcpHealthKey(tenantId, mountId), AiMcpConnectivityTestDto.class); |
| | | } |
| | | |
| | | public void cacheConnectivity(Long tenantId, Long mountId, AiMcpConnectivityTestDto connectivity) { |
| | | aiRedisExecutor.writeJson(aiRedisKeys.buildMcpHealthKey(tenantId, mountId), connectivity, AiDefaults.MCP_HEALTH_CACHE_TTL_SECONDS); |
| | | } |
| | | |
| | | public void evictMcpMountCaches(Long tenantId, Long mountId) { |
| | | if (mountId == null) { |
| | | return; |
| | | } |
| | | aiRedisExecutor.delete(aiRedisKeys.buildMcpPreviewKey(tenantId, mountId)); |
| | | aiRedisExecutor.delete(aiRedisKeys.buildMcpHealthKey(tenantId, mountId)); |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.store; |
| | | |
| | | import com.vincent.rsf.server.ai.dto.AiObserveStatsDto; |
| | | import com.vincent.rsf.server.ai.store.support.AiRedisExecutor; |
| | | import com.vincent.rsf.server.ai.store.support.AiRedisKeys; |
| | | import lombok.RequiredArgsConstructor; |
| | | import org.springframework.stereotype.Component; |
| | | import org.springframework.util.StringUtils; |
| | | |
| | | import java.util.LinkedHashMap; |
| | | import java.util.Map; |
| | | import java.util.function.Supplier; |
| | | |
| | | /** |
| | |  * Per-tenant AI observability counters kept in a Redis hash, plus two sorted sets that rank |
| | |  * tools by total invocations and by failed invocations. |
| | |  */ |
| | | @Component |
| | | @RequiredArgsConstructor |
| | | public class AiObserveStatsStore { |
| | | |
| | |     private static final String FIELD_CALL_COUNT = "callCount"; |
| | |     private static final String FIELD_SUCCESS_COUNT = "successCount"; |
| | |     private static final String FIELD_FAILURE_COUNT = "failureCount"; |
| | |     private static final String FIELD_ELAPSED_SUM = "elapsedSum"; |
| | |     private static final String FIELD_ELAPSED_COUNT = "elapsedCount"; |
| | |     private static final String FIELD_FIRST_TOKEN_SUM = "firstTokenSum"; |
| | |     private static final String FIELD_FIRST_TOKEN_COUNT = "firstTokenCount"; |
| | |     private static final String FIELD_TOTAL_TOKENS_SUM = "totalTokensSum"; |
| | |     private static final String FIELD_TOTAL_TOKENS_COUNT = "totalTokensCount"; |
| | |     private static final String FIELD_TOOL_CALL_COUNT = "toolCallCount"; |
| | |     private static final String FIELD_TOOL_SUCCESS_COUNT = "toolSuccessCount"; |
| | |     private static final String FIELD_TOOL_FAILURE_COUNT = "toolFailureCount"; |
| | | |
| | |     private final AiRedisExecutor aiRedisExecutor; |
| | |     private final AiRedisKeys aiRedisKeys; |
| | | |
| | |     /** Increments the tenant's call counter by one. */ |
| | |     public void recordObserveCallStarted(Long tenantId) { |
| | |         aiRedisExecutor.executeVoid(jedis -> jedis.hincrBy(aiRedisKeys.buildObserveStatsKey(tenantId), FIELD_CALL_COUNT, 1)); |
| | |     } |
| | | |
| | |     /** |
| | |      * Records the outcome of a finished call. |
| | |      * |
| | |      * @param status              "COMPLETED" counts as success, "FAILED" as failure; any other value changes neither counter |
| | |      * @param elapsedMs           added to the elapsed sum when non-null |
| | |      * @param firstTokenLatencyMs added to the first-token sum when non-null |
| | |      * @param totalTokens         added to the token sum when non-null |
| | |      */ |
| | |     public void recordObserveCallFinished(Long tenantId, String status, Long elapsedMs, Long firstTokenLatencyMs, Integer totalTokens) { |
| | |         aiRedisExecutor.executeVoid(jedis -> { |
| | |             String key = aiRedisKeys.buildObserveStatsKey(tenantId); |
| | |             if ("COMPLETED".equals(status)) { |
| | |                 jedis.hincrBy(key, FIELD_SUCCESS_COUNT, 1); |
| | |             } else if ("FAILED".equals(status)) { |
| | |                 jedis.hincrBy(key, FIELD_FAILURE_COUNT, 1); |
| | |             } |
| | |             // Each metric keeps its own sample count so averages only cover calls that reported it. |
| | |             if (elapsedMs != null) { |
| | |                 jedis.hincrBy(key, FIELD_ELAPSED_SUM, elapsedMs); |
| | |                 jedis.hincrBy(key, FIELD_ELAPSED_COUNT, 1); |
| | |             } |
| | |             if (firstTokenLatencyMs != null) { |
| | |                 jedis.hincrBy(key, FIELD_FIRST_TOKEN_SUM, firstTokenLatencyMs); |
| | |                 jedis.hincrBy(key, FIELD_FIRST_TOKEN_COUNT, 1); |
| | |             } |
| | |             if (totalTokens != null) { |
| | |                 jedis.hincrBy(key, FIELD_TOTAL_TOKENS_SUM, totalTokens.longValue()); |
| | |                 jedis.hincrBy(key, FIELD_TOTAL_TOKENS_COUNT, 1); |
| | |             } |
| | |         }); |
| | |     } |
| | | |
| | |     /** |
| | |      * Records one tool invocation: bumps the tool counters and, when {@code toolName} has text, |
| | |      * the tool ranking (and the failure ranking for status "FAILED"). |
| | |      */ |
| | |     public void recordObserveToolCall(Long tenantId, String toolName, String status) { |
| | |         aiRedisExecutor.executeVoid(jedis -> { |
| | |             String key = aiRedisKeys.buildObserveStatsKey(tenantId); |
| | |             jedis.hincrBy(key, FIELD_TOOL_CALL_COUNT, 1); |
| | |             if ("COMPLETED".equals(status)) { |
| | |                 jedis.hincrBy(key, FIELD_TOOL_SUCCESS_COUNT, 1); |
| | |             } else if ("FAILED".equals(status)) { |
| | |                 jedis.hincrBy(key, FIELD_TOOL_FAILURE_COUNT, 1); |
| | |             } |
| | |             if (StringUtils.hasText(toolName)) { |
| | |                 jedis.zincrby(aiRedisKeys.buildToolRankKey(tenantId), 1D, toolName); |
| | |                 if ("FAILED".equals(status)) { |
| | |                     jedis.zincrby(aiRedisKeys.buildToolFailRankKey(tenantId), 1D, toolName); |
| | |                 } |
| | |             } |
| | |         }); |
| | |     } |
| | | |
| | |     /** |
| | |      * Returns the tenant's stats from Redis; on a miss, loads them from {@code fallbackLoader} |
| | |      * and seeds Redis with the loaded snapshot. |
| | |      * |
| | |      * @return the cached or freshly loaded stats; null when the cache misses and the loader returns null |
| | |      */ |
| | |     public AiObserveStatsDto getObserveStats(Long tenantId, Supplier<AiObserveStatsDto> fallbackLoader) { |
| | |         // NOTE(review): the record* methods use HINCRBY, which creates the hash if it was never seeded; |
| | |         // from then on this method sees a hit and no longer consults fallbackLoader. Confirm that is intended. |
| | |         AiObserveStatsDto cached = readObserveStats(tenantId); |
| | |         if (cached != null) { |
| | |             return cached; |
| | |         } |
| | |         AiObserveStatsDto snapshot = fallbackLoader.get(); |
| | |         if (snapshot != null) { |
| | |             seedObserveStats(tenantId, snapshot); |
| | |         } |
| | |         return snapshot; |
| | |     } |
| | | |
| | |     /** Reads the stats hash and derives averages; returns null when the hash is absent or Redis is unavailable. */ |
| | |     private AiObserveStatsDto readObserveStats(Long tenantId) { |
| | |         // HGETALL yields an empty map for a missing key, so no separate EXISTS round trip is needed. |
| | |         Map<String, String> fields = aiRedisExecutor.execute(jedis -> jedis.hgetAll(aiRedisKeys.buildObserveStatsKey(tenantId))); |
| | |         if (fields == null || fields.isEmpty()) { |
| | |             return null; |
| | |         } |
| | |         long totalTokensSum = parseLong(fields.get(FIELD_TOTAL_TOKENS_SUM)); |
| | |         long toolCallCount = parseLong(fields.get(FIELD_TOOL_CALL_COUNT)); |
| | |         long toolSuccessCount = parseLong(fields.get(FIELD_TOOL_SUCCESS_COUNT)); |
| | |         return AiObserveStatsDto.builder() |
| | |                 .callCount(parseLong(fields.get(FIELD_CALL_COUNT))) |
| | |                 .successCount(parseLong(fields.get(FIELD_SUCCESS_COUNT))) |
| | |                 .failureCount(parseLong(fields.get(FIELD_FAILURE_COUNT))) |
| | |                 .avgElapsedMs(average(parseLong(fields.get(FIELD_ELAPSED_SUM)), parseLong(fields.get(FIELD_ELAPSED_COUNT)))) |
| | |                 .avgFirstTokenLatencyMs(average(parseLong(fields.get(FIELD_FIRST_TOKEN_SUM)), parseLong(fields.get(FIELD_FIRST_TOKEN_COUNT)))) |
| | |                 .totalTokens(totalTokensSum) |
| | |                 .avgTotalTokens(average(totalTokensSum, parseLong(fields.get(FIELD_TOTAL_TOKENS_COUNT)))) |
| | |                 .toolCallCount(toolCallCount) |
| | |                 .toolSuccessCount(toolSuccessCount) |
| | |                 .toolFailureCount(parseLong(fields.get(FIELD_TOOL_FAILURE_COUNT))) |
| | |                 .toolSuccessRate(toolCallCount == 0 ? 0D : (toolSuccessCount * 100D) / toolCallCount) |
| | |                 .build(); |
| | |     } |
| | | |
| | |     /** |
| | |      * Writes a snapshot into the stats hash. The snapshot only carries averages, so each sum is |
| | |      * rebuilt as average * callCount and each sample count is set to callCount. |
| | |      */ |
| | |     private void seedObserveStats(Long tenantId, AiObserveStatsDto snapshot) { |
| | |         aiRedisExecutor.executeVoid(jedis -> { |
| | |             String key = aiRedisKeys.buildObserveStatsKey(tenantId); |
| | |             long callCount = defaultLong(snapshot.getCallCount()); |
| | |             Map<String, String> values = new LinkedHashMap<>(); |
| | |             values.put(FIELD_CALL_COUNT, String.valueOf(callCount)); |
| | |             values.put(FIELD_SUCCESS_COUNT, String.valueOf(defaultLong(snapshot.getSuccessCount()))); |
| | |             values.put(FIELD_FAILURE_COUNT, String.valueOf(defaultLong(snapshot.getFailureCount()))); |
| | |             values.put(FIELD_ELAPSED_SUM, String.valueOf(defaultLong(snapshot.getAvgElapsedMs()) * callCount)); |
| | |             values.put(FIELD_ELAPSED_COUNT, String.valueOf(callCount)); |
| | |             values.put(FIELD_FIRST_TOKEN_SUM, String.valueOf(defaultLong(snapshot.getAvgFirstTokenLatencyMs()) * callCount)); |
| | |             values.put(FIELD_FIRST_TOKEN_COUNT, String.valueOf(callCount)); |
| | |             values.put(FIELD_TOTAL_TOKENS_SUM, String.valueOf(defaultLong(snapshot.getTotalTokens()))); |
| | |             values.put(FIELD_TOTAL_TOKENS_COUNT, String.valueOf(callCount)); |
| | |             values.put(FIELD_TOOL_CALL_COUNT, String.valueOf(defaultLong(snapshot.getToolCallCount()))); |
| | |             values.put(FIELD_TOOL_SUCCESS_COUNT, String.valueOf(defaultLong(snapshot.getToolSuccessCount()))); |
| | |             values.put(FIELD_TOOL_FAILURE_COUNT, String.valueOf(defaultLong(snapshot.getToolFailureCount()))); |
| | |             jedis.hset(key, values); |
| | |         }); |
| | |     } |
| | | |
| | |     /** Integer average of {@code sum} over {@code count}; 0 when there are no samples. */ |
| | |     private long average(long sum, long count) { |
| | |         return count == 0 ? 0L : sum / count; |
| | |     } |
| | | |
| | |     /** Parses a hash field as a long; blank or non-numeric values count as 0. */ |
| | |     private long parseLong(String source) { |
| | |         if (!StringUtils.hasText(source)) { |
| | |             return 0L; |
| | |         } |
| | |         try { |
| | |             return Long.parseLong(source); |
| | |         } catch (NumberFormatException ignored) { |
| | |             // A corrupted field must not break the whole stats read. |
| | |             return 0L; |
| | |         } |
| | |     } |
| | | |
| | |     /** Unboxes {@code value}, treating null as 0. */ |
| | |     private long defaultLong(Long value) { |
| | |         return value == null ? 0L : value; |
| | |     } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.store; |
| | | |
| | | import com.vincent.rsf.server.ai.config.AiDefaults; |
| | | import com.vincent.rsf.server.ai.store.support.AiRedisExecutor; |
| | | import com.vincent.rsf.server.ai.store.support.AiRedisKeys; |
| | | import lombok.AllArgsConstructor; |
| | | import lombok.Builder; |
| | | import lombok.Data; |
| | | import lombok.NoArgsConstructor; |
| | | import lombok.RequiredArgsConstructor; |
| | | import org.springframework.stereotype.Component; |
| | | import org.springframework.util.StringUtils; |
| | | |
| | | import java.time.Instant; |
| | | |
| | | @Component |
| | | @RequiredArgsConstructor |
| | | public class AiStreamStateStore { |
| | | |
| | | private final AiRedisExecutor aiRedisExecutor; |
| | | private final AiRedisKeys aiRedisKeys; |
| | | |
| | | public void markStreamState(String requestId, Long tenantId, Long userId, Long sessionId, String promptCode, |
| | | String status, String errorMessage) { |
| | | if (!StringUtils.hasText(requestId)) { |
| | | return; |
| | | } |
| | | aiRedisExecutor.writeJson(aiRedisKeys.buildStreamStateKey(tenantId, requestId), AiStreamState.builder() |
| | | .requestId(requestId) |
| | | .tenantId(tenantId) |
| | | .userId(userId) |
| | | .sessionId(sessionId) |
| | | .promptCode(promptCode) |
| | | .status(status) |
| | | .errorMessage(errorMessage) |
| | | .timestamp(Instant.now().toEpochMilli()) |
| | | .build(), AiDefaults.STREAM_STATE_TTL_SECONDS); |
| | | } |
| | | |
| | | @Data |
| | | @Builder |
| | | @NoArgsConstructor |
| | | @AllArgsConstructor |
| | | private static class AiStreamState { |
| | | |
| | | private String requestId; |
| | | |
| | | private Long tenantId; |
| | | |
| | | private Long userId; |
| | | |
| | | private Long sessionId; |
| | | |
| | | private String promptCode; |
| | | |
| | | private String status; |
| | | |
| | | private String errorMessage; |
| | | |
| | | private Long timestamp; |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.store; |
| | | |
| | | import com.vincent.rsf.server.ai.config.AiDefaults; |
| | | import com.vincent.rsf.server.ai.store.support.AiRedisExecutor; |
| | | import com.vincent.rsf.server.ai.store.support.AiRedisKeys; |
| | | import lombok.RequiredArgsConstructor; |
| | | import org.springframework.stereotype.Component; |
| | | |
| | | @Component |
| | | @RequiredArgsConstructor |
| | | public class AiToolResultStore { |
| | | |
| | | private final AiRedisExecutor aiRedisExecutor; |
| | | private final AiRedisKeys aiRedisKeys; |
| | | |
| | | public AiCachedToolResult getToolResult(Long tenantId, String requestId, String toolName, String toolInput) { |
| | | return aiRedisExecutor.readJson(aiRedisKeys.buildToolResultKey(tenantId, requestId, toolName, toolInput), AiCachedToolResult.class); |
| | | } |
| | | |
| | | public void cacheToolResult(Long tenantId, String requestId, String toolName, String toolInput, |
| | | boolean success, String output, String errorMessage) { |
| | | aiRedisExecutor.writeJson(aiRedisKeys.buildToolResultKey(tenantId, requestId, toolName, toolInput), |
| | | AiCachedToolResult.builder() |
| | | .success(success) |
| | | .output(output) |
| | | .errorMessage(errorMessage) |
| | | .build(), |
| | | AiDefaults.TOOL_RESULT_CACHE_TTL_SECONDS); |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.store.support; |
| | | |
| | | import com.fasterxml.jackson.core.type.TypeReference; |
| | | import com.fasterxml.jackson.databind.ObjectMapper; |
| | | import com.vincent.rsf.server.common.service.RedisService; |
| | | import lombok.RequiredArgsConstructor; |
| | | import lombok.extern.slf4j.Slf4j; |
| | | import org.springframework.stereotype.Component; |
| | | import org.springframework.util.StringUtils; |
| | | import redis.clients.jedis.Jedis; |
| | | |
| | | import java.util.function.Consumer; |
| | | import java.util.function.Function; |
| | | |
| | | /** |
| | |  * Best-effort Redis access for the AI module: every operation borrows a connection, runs the |
| | |  * action and closes the connection; failures are logged at WARN level and never rethrown. |
| | |  */ |
| | | @Slf4j |
| | | @Component |
| | | @RequiredArgsConstructor |
| | | public class AiRedisExecutor { |
| | | |
| | |     private final RedisService redisService; |
| | |     private final ObjectMapper objectMapper; |
| | | |
| | |     /** Reads {@code key} and deserializes it into {@code type}; see {@link #readJson(String, JsonReader)}. */ |
| | |     public <T> T readJson(String key, Class<T> type) { |
| | |         return readJson(key, value -> objectMapper.readValue(value, type)); |
| | |     } |
| | | |
| | |     /** Reads {@code key} and deserializes it into the generic type; see {@link #readJson(String, JsonReader)}. */ |
| | |     public <T> T readJson(String key, TypeReference<T> typeReference) { |
| | |         return readJson(key, value -> objectMapper.readValue(value, typeReference)); |
| | |     } |
| | | |
| | |     /** |
| | |      * Reads the string stored under {@code key} and converts it with {@code reader}. |
| | |      * |
| | |      * @return the converted value, or null when the key is missing or blank, Redis is unavailable, |
| | |      *         or conversion fails (in which case the unreadable entry is deleted) |
| | |      */ |
| | |     public <T> T readJson(String key, JsonReader<T> reader) { |
| | |         return execute(jedis -> { |
| | |             String value = jedis.get(key); |
| | |             if (!StringUtils.hasText(value)) { |
| | |                 return null; |
| | |             } |
| | |             try { |
| | |                 return reader.read(value); |
| | |             } catch (Exception e) { |
| | |                 log.warn("AI redis cache deserialize failed, key={}, message={}", key, e.getMessage()); |
| | |                 // Drop the entry so it is not re-read and re-logged on every subsequent lookup. |
| | |                 jedis.del(key); |
| | |                 return null; |
| | |             } |
| | |         }); |
| | |     } |
| | | |
| | |     /** |
| | |      * Serializes {@code value} to JSON and stores it under {@code key} with the given TTL. |
| | |      * A null value deletes the key instead; a value that cannot be serialized is logged and skipped. |
| | |      */ |
| | |     public void writeJson(String key, Object value, int ttlSeconds) { |
| | |         if (value == null) { |
| | |             delete(key); |
| | |             return; |
| | |         } |
| | |         // Serialize before borrowing a connection so Redis errors are not reported as serialization failures. |
| | |         String payload = serialize(key, value); |
| | |         if (payload == null) { |
| | |             return; |
| | |         } |
| | |         executeVoid(jedis -> jedis.setex(key, ttlSeconds, payload)); |
| | |     } |
| | | |
| | |     /** Deletes {@code key}. */ |
| | |     public void delete(String key) { |
| | |         executeVoid(jedis -> jedis.del(key)); |
| | |     } |
| | | |
| | |     /** |
| | |      * Runs {@code action} on a borrowed connection and returns its result. |
| | |      * |
| | |      * @return the action's result, or null when no connection is available or any step throws |
| | |      */ |
| | |     public <T> T execute(Function<Jedis, T> action) { |
| | |         // Borrowing inside the try keeps a failing getJedis() on the same best-effort path as the action itself. |
| | |         try (Jedis jedis = redisService.getJedis()) { |
| | |             if (jedis == null) { |
| | |                 return null; |
| | |             } |
| | |             return action.apply(jedis); |
| | |         } catch (Exception e) { |
| | |             log.warn("AI redis operation skipped, message={}", e.getMessage()); |
| | |             return null; |
| | |         } |
| | |     } |
| | | |
| | |     /** Runs {@code action} on a borrowed connection with the same failure handling as {@link #execute(Function)}. */ |
| | |     public void executeVoid(Consumer<Jedis> action) { |
| | |         execute(jedis -> { |
| | |             action.accept(jedis); |
| | |             return null; |
| | |         }); |
| | |     } |
| | | |
| | |     /** Converts {@code value} to JSON; logs and returns null when Jackson rejects it. */ |
| | |     private String serialize(String key, Object value) { |
| | |         try { |
| | |             return objectMapper.writeValueAsString(value); |
| | |         } catch (Exception e) { |
| | |             log.warn("AI redis cache serialize failed, key={}, message={}", key, e.getMessage()); |
| | |             return null; |
| | |         } |
| | |     } |
| | | |
| | |     /** Converts a cached string into a value; may throw to signal an unreadable payload. */ |
| | |     @FunctionalInterface |
| | |     public interface JsonReader<T> { |
| | |         T read(String value) throws Exception; |
| | |     } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.store.support; |
| | | |
| | | import lombok.RequiredArgsConstructor; |
| | | import org.springframework.stereotype.Component; |
| | | |
| | | import java.util.Set; |
| | | |
| | | /** |
| | |  * Maintains Redis sets that list cache keys, so a group of keys can be deleted through its index. |
| | |  */ |
| | | @Component |
| | | @RequiredArgsConstructor |
| | | public class AiRedisIndexSupport { |
| | | |
| | |     private final AiRedisExecutor aiRedisExecutor; |
| | | |
| | |     /** Adds {@code cacheKey} to the set at {@code indexKey}; ignored when either argument is null. */ |
| | |     public void remember(String indexKey, String cacheKey) { |
| | |         if (indexKey != null && cacheKey != null) { |
| | |             aiRedisExecutor.executeVoid(jedis -> jedis.sadd(indexKey, cacheKey)); |
| | |         } |
| | |     } |
| | | |
| | |     /** Removes {@code cacheKey} from the set at {@code indexKey}; ignored when either argument is null. */ |
| | |     public void forget(String indexKey, String cacheKey) { |
| | |         if (indexKey != null && cacheKey != null) { |
| | |             aiRedisExecutor.executeVoid(jedis -> jedis.srem(indexKey, cacheKey)); |
| | |         } |
| | |     } |
| | | |
| | |     /** Deletes every key listed in the set at {@code indexKey}, then the set itself; ignored for a null index. */ |
| | |     public void deleteTrackedKeys(String indexKey) { |
| | |         if (indexKey != null) { |
| | |             aiRedisExecutor.executeVoid(jedis -> { |
| | |                 Set<String> tracked = jedis.smembers(indexKey); |
| | |                 boolean hasTracked = tracked != null && !tracked.isEmpty(); |
| | |                 if (hasTracked) { |
| | |                     String[] trackedKeys = tracked.toArray(new String[0]); |
| | |                     jedis.del(trackedKeys); |
| | |                 } |
| | |                 jedis.del(indexKey); |
| | |             }); |
| | |         } |
| | |     } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.store.support; |
| | | |
| | | import org.springframework.stereotype.Component; |
| | | import org.springframework.util.StringUtils; |
| | | |
| | | import java.net.URLEncoder; |
| | | import java.nio.charset.StandardCharsets; |
| | | import java.security.MessageDigest; |
| | | |
| | | /** |
| | |  * Builds the Redis key names used by the AI module. Keys are a fixed prefix followed by |
| | |  * colon-separated segments; free-text segments are URL-encoded and tool input is hashed. |
| | |  */ |
| | | @Component |
| | | public class AiRedisKeys { |
| | | |
| | |     private static final String CONFIG_KEY_PREFIX = "AI:CONFIG:"; |
| | |     private static final String RUNTIME_KEY_PREFIX = "AI:RUNTIME:"; |
| | |     private static final String MEMORY_KEY_PREFIX = "AI:MEMORY:"; |
| | |     private static final String SESSIONS_KEY_PREFIX = "AI:SESSIONS:"; |
| | |     private static final String MCP_PREVIEW_KEY_PREFIX = "AI:MCP:PREVIEW:"; |
| | |     private static final String MCP_HEALTH_KEY_PREFIX = "AI:MCP:HEALTH:"; |
| | |     private static final String STREAM_STATE_KEY_PREFIX = "AI:STREAM:"; |
| | |     private static final String TOOL_RESULT_KEY_PREFIX = "AI:TOOL:RESULT:"; |
| | |     private static final String RATE_LIMIT_KEY_PREFIX = "AI:RATE:"; |
| | |     private static final String OBSERVE_STATS_KEY_PREFIX = "AI:OBSERVE:STATS:"; |
| | |     private static final String OBSERVE_TOOL_RANK_KEY_PREFIX = "AI:OBSERVE:TOOL:RANK:"; |
| | |     private static final String OBSERVE_TOOL_FAIL_RANK_KEY_PREFIX = "AI:OBSERVE:TOOL:FAIL:RANK:"; |
| | |     private static final String CONFIG_INDEX_PREFIX = "AI:IDX:CONFIG:"; |
| | |     private static final String CONVERSATION_INDEX_PREFIX = "AI:IDX:CONVERSATION:"; |
| | |     private static final String TENANT_RUNTIME_INDEX_PREFIX = "AI:IDX:RUNTIME:TENANT:"; |
| | | |
| | |     /** Key of the cached config for a tenant, prompt code and AI parameter set. */ |
| | |     public String buildConfigKey(Long tenantId, String promptCode, Long aiParamId) { |
| | |         return compose(CONFIG_KEY_PREFIX, tenantId, safeToken(promptCode), aiParamToken(aiParamId)); |
| | |     } |
| | | |
| | |     /** Key of the cached runtime snapshot for a user's session and AI parameter set. */ |
| | |     public String buildRuntimeKey(Long tenantId, Long userId, String promptCode, Long sessionId, Long aiParamId) { |
| | |         return compose(RUNTIME_KEY_PREFIX, tenantId, userId, safeToken(promptCode), sessionToken(sessionId), aiParamToken(aiParamId)); |
| | |     } |
| | | |
| | |     /** Key of the cached chat memory for a user's session. */ |
| | |     public String buildMemoryKey(Long tenantId, Long userId, String promptCode, Long sessionId) { |
| | |         return compose(MEMORY_KEY_PREFIX, tenantId, userId, safeToken(promptCode), sessionToken(sessionId)); |
| | |     } |
| | | |
| | |     /** Key of the cached session list for a user, prompt code and search keyword. */ |
| | |     public String buildSessionsKey(Long tenantId, Long userId, String promptCode, String keyword) { |
| | |         return compose(SESSIONS_KEY_PREFIX, tenantId, userId, safeToken(promptCode), safeToken(keyword)); |
| | |     } |
| | | |
| | |     /** Key of the cached MCP tool preview for a mount. */ |
| | |     public String buildMcpPreviewKey(Long tenantId, Long mountId) { |
| | |         return compose(MCP_PREVIEW_KEY_PREFIX, tenantId, mountId); |
| | |     } |
| | | |
| | |     /** Key of the cached MCP connectivity result for a mount. */ |
| | |     public String buildMcpHealthKey(Long tenantId, Long mountId) { |
| | |         return compose(MCP_HEALTH_KEY_PREFIX, tenantId, mountId); |
| | |     } |
| | | |
| | |     /** Key of the stream state for a request. */ |
| | |     public String buildStreamStateKey(Long tenantId, String requestId) { |
| | |         return compose(STREAM_STATE_KEY_PREFIX, tenantId, safeToken(requestId)); |
| | |     } |
| | | |
| | |     /** Key of a cached tool result; the tool input is represented by its digest. */ |
| | |     public String buildToolResultKey(Long tenantId, String requestId, String toolName, String toolInput) { |
| | |         return compose(TOOL_RESULT_KEY_PREFIX, tenantId, safeToken(requestId), safeToken(toolName), digest(toolInput)); |
| | |     } |
| | | |
| | |     /** Rate-limit key for a user and prompt code. */ |
| | |     public String buildRateLimitKey(Long tenantId, Long userId, String promptCode) { |
| | |         return compose(RATE_LIMIT_KEY_PREFIX, tenantId, userId, safeToken(promptCode)); |
| | |     } |
| | | |
| | |     /** Key of the tenant's observe-stats hash. */ |
| | |     public String buildObserveStatsKey(Long tenantId) { |
| | |         return compose(OBSERVE_STATS_KEY_PREFIX, tenantId); |
| | |     } |
| | | |
| | |     /** Key of the tenant's tool ranking. */ |
| | |     public String buildToolRankKey(Long tenantId) { |
| | |         return compose(OBSERVE_TOOL_RANK_KEY_PREFIX, tenantId); |
| | |     } |
| | | |
| | |     /** Key of the tenant's failed-tool ranking. */ |
| | |     public String buildToolFailRankKey(Long tenantId) { |
| | |         return compose(OBSERVE_TOOL_FAIL_RANK_KEY_PREFIX, tenantId); |
| | |     } |
| | | |
| | |     /** Key of the tenant's config index. */ |
| | |     public String buildConfigIndexKey(Long tenantId) { |
| | |         return compose(CONFIG_INDEX_PREFIX, tenantId); |
| | |     } |
| | | |
| | |     /** Key of the conversation index for a tenant/user pair. */ |
| | |     public String buildConversationIndexKey(Long tenantId, Long userId) { |
| | |         return compose(CONVERSATION_INDEX_PREFIX, tenantId, userId); |
| | |     } |
| | | |
| | |     /** Key of the tenant's runtime index. */ |
| | |     public String buildTenantRuntimeIndexKey(Long tenantId) { |
| | |         return compose(TENANT_RUNTIME_INDEX_PREFIX, tenantId); |
| | |     } |
| | | |
| | |     /** Session id as a key segment; {@code "LATEST"} when the id is null. */ |
| | |     public String sessionToken(Long sessionId) { |
| | |         if (sessionId == null) { |
| | |             return "LATEST"; |
| | |         } |
| | |         return String.valueOf(sessionId); |
| | |     } |
| | | |
| | |     /** AI parameter id as a key segment; {@code "DEFAULT"} when the id is null. */ |
| | |     public String aiParamToken(Long aiParamId) { |
| | |         if (aiParamId == null) { |
| | |             return "DEFAULT"; |
| | |         } |
| | |         return String.valueOf(aiParamId); |
| | |     } |
| | | |
| | |     /** Trimmed, URL-encoded (UTF-8) form of {@code source}; {@code "_"} when it is null or blank. */ |
| | |     public String safeToken(String source) { |
| | |         return StringUtils.hasText(source) |
| | |                 ? URLEncoder.encode(source.trim(), StandardCharsets.UTF_8) |
| | |                 : "_"; |
| | |     } |
| | | |
| | |     /** |
| | |      * Lower-case hex SHA-256 of {@code source} (null is hashed as the empty string). |
| | |      * Falls back to {@link #safeToken(String)} if hashing throws. |
| | |      */ |
| | |     public String digest(String source) { |
| | |         String input = source == null ? "" : source; |
| | |         try { |
| | |             byte[] hash = MessageDigest.getInstance("SHA-256").digest(input.getBytes(StandardCharsets.UTF_8)); |
| | |             StringBuilder hex = new StringBuilder(hash.length * 2); |
| | |             for (byte b : hash) { |
| | |                 hex.append(Character.forDigit((b >> 4) & 0xF, 16)); |
| | |                 hex.append(Character.forDigit(b & 0xF, 16)); |
| | |             } |
| | |             return hex.toString(); |
| | |         } catch (Exception e) { |
| | |             return safeToken(source); |
| | |         } |
| | |     } |
| | | |
| | |     /** Appends the segments to {@code prefix}, separated by colons; a null segment renders as "null". */ |
| | |     private static String compose(String prefix, Object... segments) { |
| | |         StringBuilder key = new StringBuilder(prefix); |
| | |         for (int i = 0; i < segments.length; i++) { |
| | |             if (i > 0) { |
| | |                 key.append(':'); |
| | |             } |
| | |             key.append(segments[i]); |
| | |         } |
| | |         return key.toString(); |
| | |     } |
| | | } |
| New file |
| | |
| | | <?xml version="1.0" encoding="UTF-8"?> |
| | | <!DOCTYPE mapper |
| | | PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" |
| | | "http://mybatis.org/dtd/mybatis-3-mapper.dtd"> |
| | | <mapper namespace="com.vincent.rsf.server.ai.mapper.AiChatMessageMapper"> |
| | | |
| | | <!-- Inserts every message in "list" with one multi-row INSERT. The id column is not listed, |
| | | so it is presumably generated by the database (confirm against the table definition). |
| | | NOTE(review): an empty list renders an INSERT without any value tuples; callers are expected to guard. --> |
| | | <insert id="insertBatch"> |
| | | insert into sys_ai_chat_message |
| | | (session_id, seq_no, role, content, content_length, user_id, tenant_id, deleted, create_by, create_time) |
| | | values |
| | | <foreach collection="list" item="item" separator=","> |
| | | (#{item.sessionId}, #{item.seqNo}, #{item.role}, #{item.content}, #{item.contentLength}, |
| | | #{item.userId}, #{item.tenantId}, #{item.deleted}, #{item.createBy}, #{item.createTime}) |
| | | </foreach> |
| | | </insert> |
| | | |
| | | <!-- Sets deleted = 1 on the rows whose id is in "ids" and that are not already deleted. |
| | | NOTE(review): an empty "ids" list renders "id in" with no list; callers are expected to guard. --> |
| | | <update id="softDeleteByIds"> |
| | | update sys_ai_chat_message |
| | | set deleted = 1 |
| | | where deleted = 0 |
| | | and id in |
| | | <foreach collection="ids" item="id" open="(" separator="," close=")"> |
| | | #{id} |
| | | </foreach> |
| | | </update> |
| | | |
| | | <!-- Sets deleted = 1 on every not-yet-deleted message of the given session. --> |
| | | <update id="softDeleteBySessionId"> |
| | | update sys_ai_chat_message |
| | | set deleted = 1 |
| | | where session_id = #{sessionId} |
| | | and deleted = 0 |
| | | </update> |
| | | |
| | | <!-- Returns one message per requested session: the non-deleted message with the highest seq_no. |
| | | The innermost query finds max(seq_no) per session; the middle query picks max(id) among the |
| | | non-deleted rows carrying that seq_no, so ties on seq_no resolve to the highest id; the outer |
| | | join loads the full row for that id. Sessions without non-deleted messages yield no row. --> |
| | | <select id="selectLatestMessagesBySessionIds" resultType="com.vincent.rsf.server.ai.entity.AiChatMessage"> |
| | | select m.id, |
| | | m.session_id, |
| | | m.seq_no, |
| | | m.role, |
| | | m.content, |
| | | m.content_length, |
| | | m.user_id, |
| | | m.tenant_id, |
| | | m.deleted, |
| | | m.create_by, |
| | | m.create_time |
| | | from sys_ai_chat_message m |
| | | inner join ( |
| | | select latest.session_id, max(candidate.id) as max_id |
| | | from sys_ai_chat_message candidate |
| | | inner join ( |
| | | select session_id, max(seq_no) as max_seq_no |
| | | from sys_ai_chat_message |
| | | where deleted = 0 |
| | | and session_id in |
| | | <foreach collection="sessionIds" item="sessionId" open="(" separator="," close=")"> |
| | | #{sessionId} |
| | | </foreach> |
| | | group by session_id |
| | | ) latest on latest.session_id = candidate.session_id |
| | | and latest.max_seq_no = candidate.seq_no |
| | | where candidate.deleted = 0 |
| | | group by latest.session_id |
| | | ) latest on latest.max_id = m.id |
| | | where m.deleted = 0 |
| | | </select> |
| | | </mapper> |
| New file |
| | |
| | | -- AI module performance indexes |
| | | -- MySQL repeatable script: create each index only when the same index name does not already exist. |
| | | -- If an existing index uses the same name but different columns, compare it manually before changing it. |
| | | |
| | | -- Remember the current schema once; every existence check below filters on it. |
| | | set @db_name = database(); |
| | | |
| | | -- sys_ai_chat_session: build either a no-op SELECT (index name already present) or the |
| | | -- CREATE INDEX statement, then run whichever was chosen through a prepared statement. |
| | | set @sql = if ( |
| | | exists ( |
| | | select 1 |
| | | from information_schema.statistics |
| | | where table_schema = @db_name |
| | | and table_name = 'sys_ai_chat_session' |
| | | and index_name = 'idx_sys_ai_chat_session_user_prompt_active' |
| | | ), |
| | | 'select ''skip idx_sys_ai_chat_session_user_prompt_active''', |
| | | 'create index idx_sys_ai_chat_session_user_prompt_active on sys_ai_chat_session (tenant_id, user_id, prompt_code, deleted, status, pinned, last_message_time, id)' |
| | | ); |
| | | prepare stmt from @sql; |
| | | execute stmt; |
| | | deallocate prepare stmt; |
| | | |
| | | -- sys_ai_chat_message: index on (session_id, deleted, seq_no, id), created only when no index |
| | | -- with this name exists on the table. |
| | | set @sql = if ( |
| | | exists ( |
| | | select 1 |
| | | from information_schema.statistics |
| | | where table_schema = @db_name |
| | | and table_name = 'sys_ai_chat_message' |
| | | and index_name = 'idx_sys_ai_chat_message_session_active_seq' |
| | | ), |
| | | 'select ''skip idx_sys_ai_chat_message_session_active_seq''', |
| | | 'create index idx_sys_ai_chat_message_session_active_seq on sys_ai_chat_message (session_id, deleted, seq_no, id)' |
| | | ); |
| | | prepare stmt from @sql; |
| | | execute stmt; |
| | | deallocate prepare stmt; |
| | | |
| | | -- sys_ai_call_log: index on (tenant_id, deleted, status, id), created only when no index |
| | | -- with this name exists on the table. |
| | | set @sql = if ( |
| | | exists ( |
| | | select 1 |
| | | from information_schema.statistics |
| | | where table_schema = @db_name |
| | | and table_name = 'sys_ai_call_log' |
| | | and index_name = 'idx_sys_ai_call_log_tenant_status' |
| | | ), |
| | | 'select ''skip idx_sys_ai_call_log_tenant_status''', |
| | | 'create index idx_sys_ai_call_log_tenant_status on sys_ai_call_log (tenant_id, deleted, status, id)' |
| | | ); |
| | | prepare stmt from @sql; |
| | | execute stmt; |
| | | deallocate prepare stmt; |
| | | |
| | | -- sys_ai_mcp_call_log: index on (call_log_id, tenant_id, id), created only when no index |
| | | -- with this name exists on the table. |
| | | set @sql = if ( |
| | | exists ( |
| | | select 1 |
| | | from information_schema.statistics |
| | | where table_schema = @db_name |
| | | and table_name = 'sys_ai_mcp_call_log' |
| | | and index_name = 'idx_sys_ai_mcp_call_log_call' |
| | | ), |
| | | 'select ''skip idx_sys_ai_mcp_call_log_call''', |
| | | 'create index idx_sys_ai_mcp_call_log_call on sys_ai_mcp_call_log (call_log_id, tenant_id, id)' |
| | | ); |
| | | prepare stmt from @sql; |
| | | execute stmt; |
| | | deallocate prepare stmt; |
| | | |
| | | -- sys_ai_mcp_call_log: index on (tenant_id, status, id), created only when no index |
| | | -- with this name exists on the table. |
| | | set @sql = if ( |
| | | exists ( |
| | | select 1 |
| | | from information_schema.statistics |
| | | where table_schema = @db_name |
| | | and table_name = 'sys_ai_mcp_call_log' |
| | | and index_name = 'idx_sys_ai_mcp_call_log_tenant_status' |
| | | ), |
| | | 'select ''skip idx_sys_ai_mcp_call_log_tenant_status''', |
| | | 'create index idx_sys_ai_mcp_call_log_tenant_status on sys_ai_mcp_call_log (tenant_id, status, id)' |
| | | ); |
| | | prepare stmt from @sql; |
| | | execute stmt; |
| | | deallocate prepare stmt; |
| | | |
| | | -- Deployment note: |
| | | -- Clear legacy AI cache keys before the new version takes traffic to avoid stale payload compatibility issues. |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.mapper; |
| | | |
| | | import com.baomidou.mybatisplus.core.MybatisConfiguration; |
| | | import com.baomidou.mybatisplus.extension.spring.MybatisSqlSessionFactoryBean; |
| | | import com.vincent.rsf.server.ai.entity.AiChatMessage; |
| | | import org.junit.jupiter.api.BeforeEach; |
| | | import org.junit.jupiter.api.Test; |
| | | import org.junit.jupiter.api.extension.ExtendWith; |
| | | import org.springframework.beans.factory.annotation.Autowired; |
| | | import org.springframework.context.annotation.Bean; |
| | | import org.springframework.context.annotation.Configuration; |
| | | import org.springframework.core.io.support.PathMatchingResourcePatternResolver; |
| | | import org.springframework.jdbc.core.JdbcTemplate; |
| | | import org.springframework.jdbc.datasource.DriverManagerDataSource; |
| | | import org.springframework.test.context.ContextConfiguration; |
| | | import org.springframework.test.context.junit.jupiter.SpringExtension; |
| | | import org.mybatis.spring.annotation.MapperScan; |
| | | |
| | | import javax.sql.DataSource; |
| | | import java.util.Date; |
| | | import java.util.List; |
| | | import java.util.stream.Collectors; |
| | | |
| | | import static org.junit.jupiter.api.Assertions.assertEquals; |
| | | |
| | | /** |
| | | * Integration test for {@link AiChatMessageMapper} backed by an in-memory H2 database in MySQL mode. |
| | | * The message table is rebuilt before every test so each run starts from an empty table. |
| | | */ |
| | | @ExtendWith(SpringExtension.class) |
| | | @ContextConfiguration(classes = AiChatMessageMapperIntegrationTest.TestConfig.class) |
| | | class AiChatMessageMapperIntegrationTest { |
| | | |
| | | @Autowired |
| | | private JdbcTemplate jdbcTemplate; |
| | | |
| | | @Autowired |
| | | private AiChatMessageMapper aiChatMessageMapper; |
| | | |
| | | /** Drops and recreates {@code sys_ai_chat_message} so rows never leak between tests. */ |
| | | @BeforeEach |
| | | void setUpSchema() { |
| | | jdbcTemplate.execute("drop table if exists sys_ai_chat_message"); |
| | | jdbcTemplate.execute(""" |
| | | create table sys_ai_chat_message ( |
| | | id bigint auto_increment primary key, |
| | | session_id bigint, |
| | | seq_no integer, |
| | | role varchar(32), |
| | | content varchar(1000), |
| | | content_length integer, |
| | | user_id bigint, |
| | | tenant_id bigint, |
| | | deleted integer, |
| | | create_by bigint, |
| | | create_time timestamp |
| | | ) |
| | | """); |
| | | } |
| | | |
| | | /** |
| | | * Inserts four messages across two sessions, checks the latest message per session, |
| | | * then verifies both soft-delete statements through the count of rows with {@code deleted = 0}. |
| | | */ |
| | | @Test |
| | | void shouldInsertSelectLatestAndSoftDeleteMessages() { |
| | | Date createdAt = new Date(); |
| | | List<AiChatMessage> batch = List.of( |
| | | message(10L, 1, "user", "hello", createdAt), |
| | | message(10L, 2, "assistant", "world", createdAt), |
| | | message(10L, 2, "assistant", "world latest", createdAt), |
| | | message(20L, 1, "user", "other", createdAt) |
| | | ); |
| | | aiChatMessageMapper.insertBatch(batch); |
| | | |
| | | List<AiChatMessage> latest = aiChatMessageMapper.selectLatestMessagesBySessionIds(List.of(10L, 20L)); |
| | | |
| | | List<Long> returnedSessionIds = latest.stream() |
| | | .map(AiChatMessage::getSessionId) |
| | | .sorted() |
| | | .collect(Collectors.toList()); |
| | | assertEquals(List.of(10L, 20L), returnedSessionIds); |
| | | assertEquals("world latest", firstForSession(latest, 10L).getContent()); |
| | | |
| | | // Removing the single message of session 20 leaves the three rows of session 10 active. |
| | | aiChatMessageMapper.softDeleteByIds(List.of(firstForSession(latest, 20L).getId())); |
| | | assertEquals(3, countActiveMessages()); |
| | | |
| | | aiChatMessageMapper.softDeleteBySessionId(10L); |
| | | assertEquals(0, countActiveMessages()); |
| | | } |
| | | |
| | | /** Returns the first entry of {@code candidates} that belongs to {@code sessionId}, failing if there is none. */ |
| | | private AiChatMessage firstForSession(List<AiChatMessage> candidates, Long sessionId) { |
| | | return candidates.stream() |
| | | .filter(item -> item.getSessionId().equals(sessionId)) |
| | | .findFirst() |
| | | .orElseThrow(); |
| | | } |
| | | |
| | | /** Counts the rows whose {@code deleted} flag is still 0. */ |
| | | private Integer countActiveMessages() { |
| | | return jdbcTemplate.queryForObject("select count(*) from sys_ai_chat_message where deleted = 0", Integer.class); |
| | | } |
| | | |
| | | /** Builds a non-deleted message owned by user 1 / tenant 1 with the given session, sequence, role and content. */ |
| | | private AiChatMessage message(Long sessionId, int seqNo, String role, String content, Date createTime) { |
| | | AiChatMessage entity = new AiChatMessage(); |
| | | entity.setSessionId(sessionId); |
| | | entity.setSeqNo(seqNo); |
| | | entity.setRole(role); |
| | | entity.setContent(content); |
| | | entity.setContentLength(content.length()); |
| | | entity.setUserId(1L); |
| | | entity.setTenantId(1L); |
| | | entity.setDeleted(0); |
| | | entity.setCreateBy(1L); |
| | | entity.setCreateTime(createTime); |
| | | return entity; |
| | | } |
| | | |
| | | /** Minimal Spring context: H2 data source, JdbcTemplate and a MyBatis-Plus session factory for the AI mapper XML files. */ |
| | | @Configuration |
| | | @MapperScan(basePackageClasses = AiChatMessageMapper.class) |
| | | static class TestConfig { |
| | | |
| | | @Bean |
| | | DataSource dataSource() { |
| | | // DB_CLOSE_DELAY=-1 appears in the URL so the in-memory database is not dropped between connections. |
| | | DriverManagerDataSource h2 = new DriverManagerDataSource( |
| | | "jdbc:h2:mem:ai-chat-message;MODE=MySQL;DB_CLOSE_DELAY=-1", "sa", ""); |
| | | h2.setDriverClassName("org.h2.Driver"); |
| | | return h2; |
| | | } |
| | | |
| | | @Bean |
| | | JdbcTemplate jdbcTemplate(DataSource dataSource) { |
| | | return new JdbcTemplate(dataSource); |
| | | } |
| | | |
| | | @Bean |
| | | MybatisSqlSessionFactoryBean sqlSessionFactory(DataSource dataSource) throws Exception { |
| | | MybatisConfiguration mybatisConfiguration = new MybatisConfiguration(); |
| | | mybatisConfiguration.setMapUnderscoreToCamelCase(true); |
| | | |
| | | MybatisSqlSessionFactoryBean sessionFactory = new MybatisSqlSessionFactoryBean(); |
| | | sessionFactory.setDataSource(dataSource); |
| | | sessionFactory.setMapperLocations(new PathMatchingResourcePatternResolver() |
| | | .getResources("classpath*:mapper/ai/*.xml")); |
| | | sessionFactory.setConfiguration(mybatisConfiguration); |
| | | return sessionFactory; |
| | | } |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.service.impl.chat; |
| | | |
| | | import com.vincent.rsf.server.ai.dto.AiChatRequest; |
| | | import com.vincent.rsf.server.ai.dto.AiChatMessageDto; |
| | | import com.vincent.rsf.server.ai.dto.AiResolvedConfig; |
| | | import com.vincent.rsf.server.ai.entity.AiParam; |
| | | import com.vincent.rsf.server.ai.entity.AiPrompt; |
| | | import com.vincent.rsf.server.ai.service.AiCallLogService; |
| | | import com.vincent.rsf.server.ai.service.AiChatMemoryService; |
| | | import com.vincent.rsf.server.ai.service.AiConfigResolverService; |
| | | import com.vincent.rsf.server.ai.service.AiParamService; |
| | | import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory; |
| | | import com.vincent.rsf.server.ai.store.AiChatRateLimiter; |
| | | import com.vincent.rsf.server.ai.store.AiStreamStateStore; |
| | | import org.junit.jupiter.api.Test; |
| | | import org.junit.jupiter.api.extension.ExtendWith; |
| | | import org.mockito.Mock; |
| | | import org.mockito.junit.jupiter.MockitoExtension; |
| | | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; |
| | | |
| | | import java.util.List; |
| | | |
| | | import static org.mockito.ArgumentMatchers.any; |
| | | import static org.mockito.ArgumentMatchers.eq; |
| | | import static org.mockito.Mockito.never; |
| | | import static org.mockito.Mockito.verify; |
| | | import static org.mockito.Mockito.when; |
| | | |
| | | /** |
| | | * Unit test for {@link AiChatOrchestrator} with every collaborator mocked. |
| | | */ |
| | | @ExtendWith(MockitoExtension.class) |
| | | class AiChatOrchestratorTest { |
| | | |
| | | @Mock |
| | | private AiConfigResolverService aiConfigResolverService; |
| | | @Mock |
| | | private AiChatMemoryService aiChatMemoryService; |
| | | @Mock |
| | | private AiParamService aiParamService; |
| | | @Mock |
| | | private McpMountRuntimeFactory mcpMountRuntimeFactory; |
| | | @Mock |
| | | private AiCallLogService aiCallLogService; |
| | | @Mock |
| | | private AiChatRateLimiter aiChatRateLimiter; |
| | | @Mock |
| | | private AiStreamStateStore aiStreamStateStore; |
| | | @Mock |
| | | private AiChatRuntimeAssembler aiChatRuntimeAssembler; |
| | | @Mock |
| | | private AiPromptMessageBuilder aiPromptMessageBuilder; |
| | | @Mock |
| | | private AiOpenAiChatModelFactory aiOpenAiChatModelFactory; |
| | | @Mock |
| | | private AiToolObservationService aiToolObservationService; |
| | | @Mock |
| | | private AiSseEventPublisher aiSseEventPublisher; |
| | | @Mock |
| | | private AiChatFailureHandler aiChatFailureHandler; |
| | | |
| | | /** |
| | | * When the rate limiter rejects the request, the failure handler must receive the stream failure |
| | | * and no call log may be started. |
| | | */ |
| | | @Test |
| | | void shouldShortCircuitWhenRateLimited() { |
| | | AiChatOrchestrator orchestrator = newOrchestrator(); |
| | | |
| | | AiChatRequest chatRequest = new AiChatRequest(); |
| | | chatRequest.setRequestId("req-1"); |
| | | chatRequest.setPromptCode("home.default"); |
| | | chatRequest.setMessages(List.of(message("user", "hello"))); |
| | | |
| | | AiParam aiParam = new AiParam().setModel("gpt-test"); |
| | | AiPrompt prompt = new AiPrompt().setName("default"); |
| | | AiResolvedConfig resolved = AiResolvedConfig.builder() |
| | | .promptCode("home.default") |
| | | .aiParam(aiParam) |
| | | .prompt(prompt) |
| | | .mcpMounts(List.of()) |
| | | .build(); |
| | | |
| | | when(aiConfigResolverService.resolve("home.default", 1L, null)).thenReturn(resolved); |
| | | when(aiParamService.listChatModelOptions(1L)).thenReturn(List.of()); |
| | | // The limiter denies the request for tenant 1 / user 2. |
| | | when(aiChatRateLimiter.allowChatRequest(1L, 2L, "home.default")).thenReturn(false); |
| | | when(aiChatFailureHandler.buildAiException(eq("AI_RATE_LIMITED"), any(), eq("RATE_LIMIT"), any(), eq(null))) |
| | | .thenCallRealMethod(); |
| | | |
| | | SseEmitter emitter = new SseEmitter(1000L); |
| | | orchestrator.executeStream(chatRequest, 2L, 1L, emitter); |
| | | |
| | | verify(aiChatFailureHandler).handleStreamFailure(any(), eq("req-1"), eq(null), eq(null), any(Long.class), |
| | | eq(null), any(), eq(null), eq(0L), eq(0L), eq(null), eq(1L), eq(2L), eq("home.default")); |
| | | verify(aiCallLogService, never()).startCallLog(any(), any(), any(), any(), any(), any(), any(), any(), any(), any()); |
| | | } |
| | | |
| | | /** Wires the orchestrator under test from the mocked collaborators, in constructor order. */ |
| | | private AiChatOrchestrator newOrchestrator() { |
| | | return new AiChatOrchestrator( |
| | | aiConfigResolverService, |
| | | aiChatMemoryService, |
| | | aiParamService, |
| | | mcpMountRuntimeFactory, |
| | | aiCallLogService, |
| | | aiChatRateLimiter, |
| | | aiStreamStateStore, |
| | | aiChatRuntimeAssembler, |
| | | aiPromptMessageBuilder, |
| | | aiOpenAiChatModelFactory, |
| | | aiToolObservationService, |
| | | aiSseEventPublisher, |
| | | aiChatFailureHandler |
| | | ); |
| | | } |
| | | |
| | | /** Creates a chat message DTO with the given role and content. */ |
| | | private AiChatMessageDto message(String role, String content) { |
| | | AiChatMessageDto chatMessage = new AiChatMessageDto(); |
| | | chatMessage.setContent(content); |
| | | chatMessage.setRole(role); |
| | | return chatMessage; |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.service.impl.chat; |
| | | |
| | | import com.vincent.rsf.server.ai.dto.AiChatMemoryDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatMessageDto; |
| | | import com.vincent.rsf.server.ai.entity.AiPrompt; |
| | | import org.junit.jupiter.api.Test; |
| | | import org.springframework.ai.chat.messages.Message; |
| | | import org.springframework.ai.chat.messages.SystemMessage; |
| | | import org.springframework.ai.chat.messages.UserMessage; |
| | | |
| | | import java.util.List; |
| | | import java.util.Map; |
| | | |
| | | import static org.junit.jupiter.api.Assertions.assertEquals; |
| | | import static org.junit.jupiter.api.Assertions.assertInstanceOf; |
| | | |
| | | class AiPromptMessageBuilderTest { |
| | | |
| | | private final AiPromptMessageBuilder builder = new AiPromptMessageBuilder(); |
| | | |
| | | @Test |
| | | void shouldBuildPromptMessagesInExpectedOrderAndRenderLastUserPrompt() { |
| | | AiChatMemoryDto memory = AiChatMemoryDto.builder() |
| | | .memorySummary("summary") |
| | | .memoryFacts("facts") |
| | | .build(); |
| | | AiPrompt prompt = new AiPrompt() |
| | | .setSystemPrompt("system") |
| | | .setUserPromptTemplate("用户问题: {{input}} | 仓库: {{warehouse}}"); |
| | | List<AiChatMessageDto> messages = List.of( |
| | | message("user", "old question"), |
| | | message("assistant", "old answer"), |
| | | message("user", "latest question") |
| | | ); |
| | | |
| | | List<Message> built = builder.buildPromptMessages(memory, messages, prompt, Map.of("warehouse", "WH1")); |
| | | |
| | | assertEquals(6, built.size()); |
| | | assertInstanceOf(SystemMessage.class, built.get(0)); |
| | | assertEquals("system", built.get(0).getText()); |
| | | assertEquals("历史摘要:\nsummary", built.get(1).getText()); |
| | | assertEquals("关键事实:\nfacts", built.get(2).getText()); |
| | | assertInstanceOf(UserMessage.class, built.get(3)); |
| | | assertEquals("old question", built.get(3).getText()); |
| | | assertEquals("old answer", built.get(4).getText()); |
| | | assertInstanceOf(UserMessage.class, built.get(5)); |
| | | assertEquals("用户问题: latest question | 仓库: WH1", built.get(5).getText()); |
| | | } |
| | | |
| | | @Test |
| | | void shouldMergePersistedAndMemoryMessages() { |
| | | List<AiChatMessageDto> merged = builder.mergeMessages( |
| | | List.of(message("user", "persisted")), |
| | | List.of(message("assistant", "memory")) |
| | | ); |
| | | |
| | | assertEquals(2, merged.size()); |
| | | assertEquals("persisted", merged.get(0).getContent()); |
| | | assertEquals("memory", merged.get(1).getContent()); |
| | | } |
| | | |
| | | private AiChatMessageDto message(String role, String content) { |
| | | AiChatMessageDto dto = new AiChatMessageDto(); |
| | | dto.setRole(role); |
| | | dto.setContent(content); |
| | | return dto; |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.service.impl.conversation; |
| | | |
| | | import com.vincent.rsf.server.ai.dto.AiChatMessageDto; |
| | | import com.vincent.rsf.server.ai.entity.AiChatMessage; |
| | | import com.vincent.rsf.server.ai.entity.AiChatSession; |
| | | import com.vincent.rsf.server.ai.mapper.AiChatMessageMapper; |
| | | import com.vincent.rsf.server.ai.mapper.AiChatSessionMapper; |
| | | import org.junit.jupiter.api.Test; |
| | | import org.junit.jupiter.api.extension.ExtendWith; |
| | | import org.mockito.ArgumentCaptor; |
| | | import org.mockito.InjectMocks; |
| | | import org.mockito.Mock; |
| | | import org.mockito.junit.jupiter.MockitoExtension; |
| | | import org.springframework.transaction.support.TransactionSynchronization; |
| | | import org.springframework.transaction.support.TransactionSynchronizationManager; |
| | | |
| | | import java.util.List; |
| | | |
| | | import static org.junit.jupiter.api.Assertions.assertEquals; |
| | | import static org.mockito.Mockito.never; |
| | | import static org.mockito.ArgumentMatchers.any; |
| | | import static org.mockito.ArgumentMatchers.eq; |
| | | import static org.mockito.Mockito.verify; |
| | | import static org.mockito.Mockito.when; |
| | | |
| | | /** |
| | | * Unit test for {@link AiConversationCommandService} with its mappers and collaborating services mocked. |
| | | */ |
| | | @ExtendWith(MockitoExtension.class) |
| | | class AiConversationCommandServiceTest { |
| | | |
| | | @Mock |
| | | private AiChatSessionMapper aiChatSessionMapper; |
| | | @Mock |
| | | private AiChatMessageMapper aiChatMessageMapper; |
| | | @Mock |
| | | private AiConversationQueryService aiConversationQueryService; |
| | | @Mock |
| | | private AiMemoryProfileService aiMemoryProfileService; |
| | | |
| | | @InjectMocks |
| | | private AiConversationCommandService aiConversationCommandService; |
| | | |
| | | /** |
| | | * Saving a round with two request messages plus the final answer must produce one batch insert of |
| | | * three rows numbered from the next sequence number, then refresh the memory profile and evict caches. |
| | | */ |
| | | @Test |
| | | void shouldBatchInsertMessagesAndScheduleRefreshWhenSavingRound() { |
| | | AiChatSession session = new AiChatSession().setId(10L).setTitle(null); |
| | | when(aiConversationQueryService.findNextSeqNo(10L)).thenReturn(3); |
| | | when(aiConversationQueryService.resolveUpdatedTitle(eq(null), any())).thenReturn("new title"); |
| | | |
| | | aiConversationCommandService.saveRound(session, 7L, 8L, |
| | | List.of(message("user", "hello"), message("assistant", "draft answer")), |
| | | "final answer"); |
| | | |
| | | // forClass(List.class) cannot carry the element type; the captor only ever receives the |
| | | // List<AiChatMessage> argument of insertBatch, so the unchecked conversion is safe here. |
| | | @SuppressWarnings("unchecked") |
| | | ArgumentCaptor<List<AiChatMessage>> captor = ArgumentCaptor.forClass(List.class); |
| | | verify(aiChatMessageMapper).insertBatch(captor.capture()); |
| | | assertEquals(3, captor.getValue().size()); |
| | | assertEquals(3, captor.getValue().get(0).getSeqNo()); |
| | | assertEquals(5, captor.getValue().get(2).getSeqNo()); |
| | | verify(aiMemoryProfileService).scheduleMemoryProfileRefresh(10L, 7L, 8L); |
| | | verify(aiConversationQueryService).evictConversationCaches(8L, 7L); |
| | | } |
| | | |
| | | /** |
| | | * With transaction synchronization active, cache eviction and the memory refresh must not run during |
| | | * {@code saveRound}; they run only once the registered synchronizations receive {@code afterCommit()}. |
| | | */ |
| | | @Test |
| | | void shouldDeferConversationSideEffectsUntilAfterCommitWhenSynchronizationActive() { |
| | | AiChatSession session = new AiChatSession().setId(10L).setTitle(null); |
| | | when(aiConversationQueryService.findNextSeqNo(10L)).thenReturn(1); |
| | | when(aiConversationQueryService.resolveUpdatedTitle(eq(null), any())).thenReturn("new title"); |
| | | |
| | | TransactionSynchronizationManager.initSynchronization(); |
| | | try { |
| | | aiConversationCommandService.saveRound(session, 7L, 8L, List.of(message("user", "hello")), "final answer"); |
| | | |
| | | verify(aiConversationQueryService, never()).evictConversationCaches(8L, 7L); |
| | | verify(aiMemoryProfileService, never()).scheduleMemoryProfileRefresh(10L, 7L, 8L); |
| | | |
| | | // Simulate the commit callback by hand; no real transaction manager is involved in this test. |
| | | for (TransactionSynchronization synchronization : TransactionSynchronizationManager.getSynchronizations()) { |
| | | synchronization.afterCommit(); |
| | | } |
| | | |
| | | verify(aiConversationQueryService).evictConversationCaches(8L, 7L); |
| | | verify(aiMemoryProfileService).scheduleMemoryProfileRefresh(10L, 7L, 8L); |
| | | } finally { |
| | | // Always clear the thread-bound synchronization state so it cannot affect other tests. |
| | | TransactionSynchronizationManager.clearSynchronization(); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * Of four stored records, the two returned as the latest round are kept and the remaining |
| | | * ids (1 and 2) are soft-deleted, followed by a memory profile refresh. |
| | | */ |
| | | @Test |
| | | void shouldSoftDeleteOnlyOlderMessagesWhenRetainingLatestRound() { |
| | | when(aiConversationQueryService.listMessageRecords(10L)).thenReturn(List.of( |
| | | record(1L, "user"), |
| | | record(2L, "assistant"), |
| | | record(3L, "user"), |
| | | record(4L, "assistant") |
| | | )); |
| | | when(aiConversationQueryService.tailMessageRecordsByRounds(any(), eq(1))).thenReturn(List.of( |
| | | record(3L, "user"), |
| | | record(4L, "assistant") |
| | | )); |
| | | |
| | | aiConversationCommandService.retainLatestRound(7L, 8L, 10L); |
| | | |
| | | verify(aiChatMessageMapper).softDeleteByIds(List.of(1L, 2L)); |
| | | verify(aiMemoryProfileService).scheduleMemoryProfileRefresh(10L, 7L, 8L); |
| | | } |
| | | |
| | | /** Creates a chat message DTO with the given role and content. */ |
| | | private AiChatMessageDto message(String role, String content) { |
| | | AiChatMessageDto dto = new AiChatMessageDto(); |
| | | dto.setRole(role); |
| | | dto.setContent(content); |
| | | return dto; |
| | | } |
| | | |
| | | /** Creates a persisted message record carrying only an id and a role. */ |
| | | private AiChatMessage record(Long id, String role) { |
| | | return new AiChatMessage().setId(id).setRole(role); |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.service.impl.mcp; |
| | | |
| | | import com.fasterxml.jackson.databind.ObjectMapper; |
| | | import com.vincent.rsf.server.ai.config.AiDefaults; |
| | | import com.vincent.rsf.server.ai.dto.AiMcpToolPreviewDto; |
| | | import com.vincent.rsf.server.ai.entity.AiMcpMount; |
| | | import com.vincent.rsf.server.ai.mapper.AiMcpMountMapper; |
| | | import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry; |
| | | import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory; |
| | | import org.junit.jupiter.api.Test; |
| | | import org.junit.jupiter.api.extension.ExtendWith; |
| | | import org.mockito.Mock; |
| | | import org.mockito.junit.jupiter.MockitoExtension; |
| | | import org.springframework.ai.tool.ToolCallback; |
| | | import org.springframework.ai.tool.definition.ToolDefinition; |
| | | import org.springframework.ai.tool.metadata.ToolMetadata; |
| | | |
| | | import java.util.List; |
| | | |
| | | import static org.junit.jupiter.api.Assertions.assertEquals; |
| | | import static org.mockito.Mockito.mock; |
| | | import static org.mockito.Mockito.when; |
| | | |
| | | /** |
| | | * Unit test for {@link AiMcpAdminService#previewTools}: runtime tool callbacks are combined with the |
| | | * builtin tool catalog entry of the same name. |
| | | */ |
| | | @ExtendWith(MockitoExtension.class) |
| | | class AiMcpAdminServiceTest { |
| | | |
| | | @Mock |
| | | private AiMcpMountMapper aiMcpMountMapper; |
| | | @Mock |
| | | private BuiltinMcpToolRegistry builtinMcpToolRegistry; |
| | | @Mock |
| | | private McpMountRuntimeFactory mcpMountRuntimeFactory; |
| | | |
| | | /** |
| | | * The preview entry must carry the group and purpose from the catalog and the input schema |
| | | * from the runtime tool definition. |
| | | */ |
| | | @Test |
| | | void shouldMergeGovernedCatalogIntoPreviewResult() { |
| | | AiMcpAdminService service = new AiMcpAdminService( |
| | | aiMcpMountMapper, |
| | | builtinMcpToolRegistry, |
| | | mcpMountRuntimeFactory, |
| | | new ObjectMapper() |
| | | ); |
| | | AiMcpMount builtinMount = new AiMcpMount() |
| | | .setTransportType(AiDefaults.MCP_TRANSPORT_BUILTIN) |
| | | .setBuiltinCode(AiDefaults.MCP_BUILTIN_RSF_WMS) |
| | | .setName("builtin"); |
| | | |
| | | // Runtime side: a single tool callback exposing name, description, schema and metadata. |
| | | ToolDefinition definition = mock(ToolDefinition.class); |
| | | when(definition.name()).thenReturn("rsf_query_task"); |
| | | when(definition.description()).thenReturn("desc"); |
| | | when(definition.inputSchema()).thenReturn("{schema}"); |
| | | ToolMetadata metadata = mock(ToolMetadata.class); |
| | | when(metadata.returnDirect()).thenReturn(Boolean.FALSE); |
| | | ToolCallback toolCallback = mock(ToolCallback.class); |
| | | when(toolCallback.getToolDefinition()).thenReturn(definition); |
| | | when(toolCallback.getToolMetadata()).thenReturn(metadata); |
| | | McpMountRuntimeFactory.McpMountRuntime mountRuntime = mock(McpMountRuntimeFactory.McpMountRuntime.class); |
| | | when(mountRuntime.getToolCallbacks()).thenReturn(new ToolCallback[]{toolCallback}); |
| | | when(mountRuntime.getErrors()).thenReturn(List.of()); |
| | | when(mcpMountRuntimeFactory.create(List.of(builtinMount), 1L)).thenReturn(mountRuntime); |
| | | |
| | | // Catalog side: governed entry for the same tool name. |
| | | AiMcpToolPreviewDto catalogEntry = AiMcpToolPreviewDto.builder() |
| | | .name("rsf_query_task") |
| | | .toolGroup("任务查询") |
| | | .toolPurpose("purpose") |
| | | .queryBoundary("boundary") |
| | | .exampleQuestions(List.of("q1")) |
| | | .build(); |
| | | when(builtinMcpToolRegistry.listBuiltinToolCatalog(AiDefaults.MCP_BUILTIN_RSF_WMS)) |
| | | .thenReturn(List.of(catalogEntry)); |
| | | |
| | | List<AiMcpToolPreviewDto> previews = service.previewTools(builtinMount, 1L); |
| | | |
| | | assertEquals(1, previews.size()); |
| | | AiMcpToolPreviewDto preview = previews.get(0); |
| | | assertEquals("任务查询", preview.getToolGroup()); |
| | | assertEquals("purpose", preview.getToolPurpose()); |
| | | assertEquals("{schema}", preview.getInputSchema()); |
| | | } |
| | | } |
| New file |
| | |
| | | package com.vincent.rsf.server.ai.store.support; |
| | | |
| | | import org.junit.jupiter.api.Test; |
| | | |
| | | import static org.junit.jupiter.api.Assertions.assertEquals; |
| | | import static org.junit.jupiter.api.Assertions.assertNotEquals; |
| | | import static org.junit.jupiter.api.Assertions.assertTrue; |
| | | |
| | | /** |
| | | * Unit test for the key and token formats produced by {@link AiRedisKeys}. |
| | | */ |
| | | class AiRedisKeysTest { |
| | | |
| | | private final AiRedisKeys keys = new AiRedisKeys(); |
| | | |
| | | /** Runtime keys are fully literal; tool result keys share a prefix but differ when the tool input differs. */ |
| | | @Test |
| | | void shouldBuildStableRuntimeAndToolKeys() { |
| | | assertEquals("AI:RUNTIME:1:2:home.default:3:4", keys.buildRuntimeKey(1L, 2L, "home.default", 3L, 4L)); |
| | | |
| | | String firstToolKey = keys.buildToolResultKey(1L, "req-1", "tool", "{\"a\":1}"); |
| | | String secondToolKey = keys.buildToolResultKey(1L, "req-1", "tool", "{\"a\":2}"); |
| | | assertTrue(firstToolKey.startsWith("AI:TOOL:RESULT:1:req-1:tool:")); |
| | | assertNotEquals(firstToolKey, secondToolKey); |
| | | } |
| | | |
| | | /** A null argument maps to a fixed placeholder token for each token helper. */ |
| | | @Test |
| | | void shouldUseSentinelTokensForNullableValues() { |
| | | String sessionToken = keys.sessionToken(null); |
| | | String aiParamToken = keys.aiParamToken(null); |
| | | String safeToken = keys.safeToken(null); |
| | | |
| | | assertEquals("LATEST", sessionToken); |
| | | assertEquals("DEFAULT", aiParamToken); |
| | | assertEquals("_", safeToken); |
| | | } |
| | | } |