From 82624affb0251b75b62b35567d3eb260c06efe78 Mon Sep 17 00:00:00 2001
From: zhou zhou <3272660260@qq.com>
Date: 星期一, 23 三月 2026 12:48:07 +0800
Subject: [PATCH] #ai 代码优化
---
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/conversation/AiConversationCommandService.java | 231 ++
rsf-server/src/test/java/com/vincent/rsf/server/AI/service/impl/mcp/AiMcpAdminServiceTest.java | 76
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiChatOrchestrator.java | 291 ++
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/McpMountRuntimeFactoryImpl.java | 94
rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiChatRateLimiter.java | 36
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiConfigResolverServiceImpl.java | 7
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiToolObservationService.java | 205 +
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatMemoryServiceImpl.java | 633 -----
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/BuiltinMcpToolRegistryImpl.java | 81
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/mcp/McpClientFactory.java | 104 +
rsf-server/src/main/resources/mapper/ai/AiChatMessageMapper.xml | 66
rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiCachedToolResult.java | 19
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiPromptServiceImpl.java | 16
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiMcpMountServiceImpl.java | 239 --
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiSseEventPublisher.java | 112 +
rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiObserveStatsStore.java | 163 +
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiChatFailureHandler.java | 104 +
rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiMcpCacheStore.java | 45
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java | 995 ---------
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiOpenAiApiSupport.java | 8
rsf-server/src/test/java/com/vincent/rsf/server/AI/mapper/AiChatMessageMapperIntegrationTest.java | 123 +
rsf-server/pom.xml | 10
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/mcp/BuiltinMcpToolCatalogProvider.java | 79
rsf-server/src/main/java/com/vincent/rsf/server/ai/store/support/AiRedisKeys.java | 117 +
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/conversation/AiConversationQueryService.java | 318 +++
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiParamServiceImpl.java | 19
rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiConfigCacheStore.java | 32
rsf-server/src/main/resources/sql/ai/20260323_ai_indexes.sql | 83
rsf-server/src/main/java/com/vincent/rsf/server/ai/store/support/AiRedisExecutor.java | 94
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiOpenAiChatModelFactory.java | 89
rsf-server/src/test/java/com/vincent/rsf/server/AI/service/impl/conversation/AiConversationCommandServiceTest.java | 113 +
rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiConversationCacheStore.java | 67
rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiToolResultStore.java | 30
rsf-server/src/test/java/com/vincent/rsf/server/AI/service/impl/chat/AiPromptMessageBuilderTest.java | 68
rsf-server/src/test/java/com/vincent/rsf/server/AI/store/support/AiRedisKeysTest.java | 30
/dev/null | 519 ----
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiThinkingTraceEmitter.java | 129 +
rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiStreamStateStore.java | 62
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiChatRuntimeAssembler.java | 36
rsf-server/src/main/java/com/vincent/rsf/server/ai/store/support/AiRedisIndexSupport.java | 40
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiCallLogServiceImpl.java | 13
rsf-server/src/main/java/com/vincent/rsf/server/ai/mapper/AiChatMessageMapper.java | 11
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/mcp/AiMcpAdminService.java | 243 ++
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/conversation/AiMemoryProfileService.java | 130 +
rsf-server/src/test/java/com/vincent/rsf/server/AI/service/impl/chat/AiChatOrchestratorTest.java | 105 +
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiPromptMessageBuilder.java | 116 +
46 files changed, 3,686 insertions(+), 2,515 deletions(-)
diff --git a/rsf-server/pom.xml b/rsf-server/pom.xml
index c67254f..d7b7edb 100644
--- a/rsf-server/pom.xml
+++ b/rsf-server/pom.xml
@@ -63,6 +63,16 @@
<scope>system</scope>
<systemPath>${project.basedir}/src/main/resources/lib/RouteUtils.jar</systemPath>
</dependency>
+ <dependency>
+ <groupId>org.springframework.boot</groupId>
+ <artifactId>spring-boot-starter-test</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.h2database</groupId>
+ <artifactId>h2</artifactId>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<build>
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/mapper/AiChatMessageMapper.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/mapper/AiChatMessageMapper.java
index 81a0c64..c8dd0ba 100644
--- a/rsf-server/src/main/java/com/vincent/rsf/server/ai/mapper/AiChatMessageMapper.java
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/mapper/AiChatMessageMapper.java
@@ -3,7 +3,18 @@
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.vincent.rsf.server.ai.entity.AiChatMessage;
import org.apache.ibatis.annotations.Mapper;
+import org.apache.ibatis.annotations.Param;
+
+import java.util.List;
@Mapper
public interface AiChatMessageMapper extends BaseMapper<AiChatMessage> {
+
+ int insertBatch(@Param("list") List<AiChatMessage> messages);
+
+ int softDeleteByIds(@Param("ids") List<Long> ids);
+
+ int softDeleteBySessionId(@Param("sessionId") Long sessionId);
+
+ List<AiChatMessage> selectLatestMessagesBySessionIds(@Param("sessionIds") List<Long> sessionIds);
}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiCallLogServiceImpl.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiCallLogServiceImpl.java
index 1a3d1c4..60e3964 100644
--- a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiCallLogServiceImpl.java
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiCallLogServiceImpl.java
@@ -8,6 +8,7 @@
import com.vincent.rsf.server.ai.entity.AiMcpCallLog;
import com.vincent.rsf.server.ai.mapper.AiCallLogMapper;
import com.vincent.rsf.server.ai.mapper.AiMcpCallLogMapper;
+import com.vincent.rsf.server.ai.store.AiObserveStatsStore;
import com.vincent.rsf.server.ai.service.AiCallLogService;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
@@ -26,7 +27,7 @@
private static final Pattern BEARER_PATTERN = Pattern.compile("(?i)(bearer\\s+)([a-z0-9._-]+)");
private final AiMcpCallLogMapper aiMcpCallLogMapper;
- private final AiRedisSupport aiRedisSupport;
+ private final AiObserveStatsStore aiObserveStatsStore;
@Override
public AiCallLog startCallLog(String requestId, Long sessionId, Long userId, Long tenantId, String promptCode,
@@ -52,7 +53,7 @@
.setCreateTime(now)
.setUpdateTime(now);
this.save(callLog);
- aiRedisSupport.recordObserveCallStarted(tenantId);
+ aiObserveStatsStore.recordObserveCallStarted(tenantId);
return callLog;
}
@@ -77,7 +78,7 @@
.set(AiCallLog::getUpdateTime, new Date()));
AiCallLog latest = this.getById(callLogId);
if (latest != null) {
- aiRedisSupport.recordObserveCallFinished(latest.getTenantId(), status, elapsedMs, firstTokenLatencyMs, totalTokens);
+ aiObserveStatsStore.recordObserveCallFinished(latest.getTenantId(), status, elapsedMs, firstTokenLatencyMs, totalTokens);
}
}
@@ -101,7 +102,7 @@
.set(AiCallLog::getUpdateTime, new Date()));
AiCallLog latest = this.getById(callLogId);
if (latest != null) {
- aiRedisSupport.recordObserveCallFinished(latest.getTenantId(), status, elapsedMs, firstTokenLatencyMs, null);
+ aiObserveStatsStore.recordObserveCallFinished(latest.getTenantId(), status, elapsedMs, firstTokenLatencyMs, null);
}
}
@@ -127,12 +128,12 @@
.setUserId(userId)
.setTenantId(tenantId)
.setCreateTime(new Date()));
- aiRedisSupport.recordObserveToolCall(tenantId, toolName, status);
+ aiObserveStatsStore.recordObserveToolCall(tenantId, toolName, status);
}
@Override
public AiObserveStatsDto getObserveStats(Long tenantId) {
- return aiRedisSupport.getObserveStats(tenantId, () -> loadObserveStatsFromDatabase(tenantId));
+ return aiObserveStatsStore.getObserveStats(tenantId, () -> loadObserveStatsFromDatabase(tenantId));
}
private AiObserveStatsDto loadObserveStatsFromDatabase(Long tenantId) {
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatMemoryServiceImpl.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatMemoryServiceImpl.java
index a7210c3..3caec94 100644
--- a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatMemoryServiceImpl.java
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatMemoryServiceImpl.java
@@ -1,673 +1,68 @@
package com.vincent.rsf.server.ai.service.impl;
-import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
-import com.vincent.rsf.framework.common.Cools;
-import com.vincent.rsf.framework.exception.CoolException;
-import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.dto.AiChatMemoryDto;
import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
+import com.vincent.rsf.server.ai.dto.AiChatSessionDto;
import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest;
import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest;
-import com.vincent.rsf.server.ai.dto.AiChatSessionDto;
-import com.vincent.rsf.server.ai.entity.AiChatMessage;
import com.vincent.rsf.server.ai.entity.AiChatSession;
-import com.vincent.rsf.server.ai.mapper.AiChatMessageMapper;
-import com.vincent.rsf.server.ai.mapper.AiChatSessionMapper;
import com.vincent.rsf.server.ai.service.AiChatMemoryService;
-import com.vincent.rsf.server.system.enums.StatusType;
+import com.vincent.rsf.server.ai.service.impl.conversation.AiConversationCommandService;
+import com.vincent.rsf.server.ai.service.impl.conversation.AiConversationQueryService;
import lombok.RequiredArgsConstructor;
-import lombok.extern.slf4j.Slf4j;
-import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
-import org.springframework.util.StringUtils;
-import java.util.ArrayList;
-import java.util.Date;
import java.util.List;
-import java.util.Set;
-import java.util.concurrent.Executor;
-import java.util.concurrent.ConcurrentHashMap;
@Service
-@Slf4j
@RequiredArgsConstructor
public class AiChatMemoryServiceImpl implements AiChatMemoryService {
- private final AiChatSessionMapper aiChatSessionMapper;
- private final AiChatMessageMapper aiChatMessageMapper;
- private final AiRedisSupport aiRedisSupport;
- @Qualifier("aiMemoryTaskExecutor")
- private final Executor aiMemoryTaskExecutor;
+ private final AiConversationQueryService aiConversationQueryService;
+ private final AiConversationCommandService aiConversationCommandService;
- /**
- * 鐢ㄤ袱涓湰鍦伴泦鍚堟妸鈥滃悓涓�涓細璇濈殑鎽樿鍒锋柊鈥濆悎骞舵垚涓茶浠诲姟锛岄伩鍏嶈繛缁秷鎭妸閲嶅浠诲姟濉炴弧绾跨▼姹犮��
- */
- private final Set<Long> refreshingSessionIds = ConcurrentHashMap.newKeySet();
- private final Set<Long> pendingRefreshSessionIds = ConcurrentHashMap.newKeySet();
-
- /**
- * 璇诲彇浼氳瘽璁板繂蹇収銆�
- * 杩斿洖缁撴灉鍚屾椂鍖呭惈瀹屾暣钀藉簱鍘嗗彶銆佺煭鏈熻蹇嗙獥鍙d互鍙婃憳瑕�/浜嬪疄璁板繂锛�
- * 渚夸簬璋冪敤鏂规寜涓嶅悓鐢ㄩ�旈�夋嫨鏁版嵁绮掑害銆�
- */
@Override
public AiChatMemoryDto getMemory(Long userId, Long tenantId, String promptCode, Long sessionId) {
- ensureIdentity(userId, tenantId);
- String resolvedPromptCode = requirePromptCode(promptCode);
- // 浼氳瘽璁板繂灞炰簬鍏稿瀷鈥滆澶氬啓灏戔�濇暟鎹紝鍏堣蛋鐭� TTL 缂撳瓨鑳芥槑鏄惧噺杞绘娊灞夊垵濮嬪寲鍜屽垏浼氳瘽鍘嬪姏銆�
- AiChatMemoryDto cached = aiRedisSupport.getMemory(tenantId, userId, resolvedPromptCode, sessionId);
- if (cached != null) {
- return cached;
- }
- AiChatSession session = sessionId == null
- ? findLatestSession(userId, tenantId, resolvedPromptCode)
- : getSession(sessionId, userId, tenantId, resolvedPromptCode);
- AiChatMemoryDto memory;
- if (session == null) {
- memory = AiChatMemoryDto.builder()
- .sessionId(null)
- .memorySummary(null)
- .memoryFacts(null)
- .recentMessageCount(0)
- .persistedMessages(List.of())
- .shortMemoryMessages(List.of())
- .build();
- aiRedisSupport.cacheMemory(tenantId, userId, resolvedPromptCode, sessionId, memory);
- return memory;
- }
- List<AiChatMessageDto> persistedMessages = listMessages(session.getId());
- List<AiChatMessageDto> shortMemoryMessages = tailMessagesByRounds(persistedMessages, AiDefaults.MEMORY_RECENT_ROUNDS);
- memory = AiChatMemoryDto.builder()
- .sessionId(session.getId())
- .memorySummary(session.getMemorySummary())
- .memoryFacts(session.getMemoryFacts())
- .recentMessageCount(shortMemoryMessages.size())
- .persistedMessages(persistedMessages)
- .shortMemoryMessages(shortMemoryMessages)
- .build();
- aiRedisSupport.cacheMemory(tenantId, userId, resolvedPromptCode, session.getId(), memory);
- if (sessionId == null || !session.getId().equals(sessionId)) {
- aiRedisSupport.cacheMemory(tenantId, userId, resolvedPromptCode, null, memory);
- }
- return memory;
+ return aiConversationQueryService.getMemory(userId, tenantId, promptCode, sessionId);
}
- /**
- * 鏌ヨ褰撳墠鐢ㄦ埛鍦ㄦ煇涓� Prompt 涓嬬殑浼氳瘽鍒楄〃銆�
- * 鍒楄〃鍙繑鍥炵敤浜庝晶杈规爮灞曠ず鐨勬憳瑕佷俊鎭紝涓嶈繑鍥炲畬鏁村璇濆唴瀹广��
- */
@Override
public List<AiChatSessionDto> listSessions(Long userId, Long tenantId, String promptCode, String keyword) {
- ensureIdentity(userId, tenantId);
- String resolvedPromptCode = requirePromptCode(promptCode);
- List<AiChatSessionDto> cached = aiRedisSupport.getSessionList(tenantId, userId, resolvedPromptCode, keyword);
- if (cached != null) {
- return cached;
- }
- List<AiChatSession> sessions = aiChatSessionMapper.selectList(new LambdaQueryWrapper<AiChatSession>()
- .eq(AiChatSession::getUserId, userId)
- .eq(AiChatSession::getTenantId, tenantId)
- .eq(AiChatSession::getPromptCode, resolvedPromptCode)
- .eq(AiChatSession::getDeleted, 0)
- .eq(AiChatSession::getStatus, StatusType.ENABLE.val)
- .like(StringUtils.hasText(keyword), AiChatSession::getTitle, keyword == null ? null : keyword.trim())
- .orderByDesc(AiChatSession::getPinned)
- .orderByDesc(AiChatSession::getLastMessageTime)
- .orderByDesc(AiChatSession::getId));
- if (Cools.isEmpty(sessions)) {
- aiRedisSupport.cacheSessionList(tenantId, userId, resolvedPromptCode, keyword, List.of());
- return List.of();
- }
- List<AiChatSessionDto> result = new ArrayList<>();
- for (AiChatSession session : sessions) {
- result.add(buildSessionDto(session));
- }
- aiRedisSupport.cacheSessionList(tenantId, userId, resolvedPromptCode, keyword, result);
- return result;
+ return aiConversationQueryService.listSessions(userId, tenantId, promptCode, keyword);
}
- /**
- * 瑙f瀽鏈疆璇锋眰搴旇钀藉埌鍝釜浼氳瘽銆�
- * 濡傛灉鍓嶇甯︿簡 sessionId 鍒欏仛褰掑睘鏍¢獙骞跺鐢紱鍚﹀垯鑷姩鍒涘缓鏂颁細璇濄��
- */
@Override
public AiChatSession resolveSession(Long userId, Long tenantId, String promptCode, Long sessionId, String titleSeed) {
- ensureIdentity(userId, tenantId);
- String resolvedPromptCode = requirePromptCode(promptCode);
- if (sessionId != null) {
- return getSession(sessionId, userId, tenantId, resolvedPromptCode);
- }
- Date now = new Date();
- AiChatSession session = new AiChatSession()
- .setTitle(buildSessionTitle(titleSeed))
- .setPromptCode(resolvedPromptCode)
- .setUserId(userId)
- .setTenantId(tenantId)
- .setLastMessageTime(now)
- .setPinned(0)
- .setStatus(StatusType.ENABLE.val)
- .setDeleted(0)
- .setCreateBy(userId)
- .setCreateTime(now)
- .setUpdateBy(userId)
- .setUpdateTime(now);
- aiChatSessionMapper.insert(session);
- evictConversationCaches(tenantId, userId);
- return session;
+ return aiConversationCommandService.resolveSession(userId, tenantId, promptCode, sessionId, titleSeed);
}
- /**
- * 钀藉簱淇濆瓨涓�鏁磋疆瀵硅瘽銆�
- * 杩欓噷浼氶『搴忓啓鍏ユ湰杞敤鎴锋秷鎭拰妯″瀷鍥炲锛屽苟鍦ㄦ渶鍚庡埛鏂颁細璇濇爣棰樺拰娲昏穬鏃堕棿銆�
- * 璁板繂鐢诲儚鏀逛负鍚庡彴寮傛鍒锋柊锛岄伩鍏嶆妸鎽樿閲嶇畻鑰楁椂鍘嬪湪鐢ㄦ埛鏈疆鍝嶅簲灏鹃儴銆�
- */
@Override
public void saveRound(AiChatSession session, Long userId, Long tenantId, List<AiChatMessageDto> memoryMessages, String assistantContent) {
- if (session == null || session.getId() == null) {
- throw new CoolException("AI 浼氳瘽涓嶅瓨鍦�");
- }
- ensureIdentity(userId, tenantId);
- List<AiChatMessageDto> normalizedMessages = normalizeMessages(memoryMessages);
- if (normalizedMessages.isEmpty()) {
- throw new CoolException("鏈疆娌℃湁鍙繚瀛樼殑瀵硅瘽娑堟伅");
- }
- int nextSeqNo = findNextSeqNo(session.getId());
- Date now = new Date();
- for (AiChatMessageDto message : normalizedMessages) {
- aiChatMessageMapper.insert(buildMessageEntity(session.getId(), nextSeqNo++, message.getRole(), message.getContent(), userId, tenantId, now));
- }
- if (StringUtils.hasText(assistantContent)) {
- aiChatMessageMapper.insert(buildMessageEntity(session.getId(), nextSeqNo, "assistant", assistantContent, userId, tenantId, now));
- }
- AiChatSession update = new AiChatSession()
- .setId(session.getId())
- .setTitle(resolveUpdatedTitle(session.getTitle(), normalizedMessages))
- .setLastMessageTime(now)
- .setUpdateBy(userId)
- .setUpdateTime(now);
- aiChatSessionMapper.updateById(update);
- evictConversationCaches(tenantId, userId);
- scheduleMemoryProfileRefresh(session.getId(), userId, tenantId);
+ aiConversationCommandService.saveRound(session, userId, tenantId, memoryMessages, assistantContent);
}
- /** 鍒犻櫎鏁翠釜浼氳瘽鍙婂叾娑堟伅銆� */
@Override
public void removeSession(Long userId, Long tenantId, Long sessionId) {
- ensureIdentity(userId, tenantId);
- if (sessionId == null) {
- throw new CoolException("AI 浼氳瘽 ID 涓嶈兘涓虹┖");
- }
- AiChatSession session = aiChatSessionMapper.selectOne(new LambdaQueryWrapper<AiChatSession>()
- .eq(AiChatSession::getId, sessionId)
- .eq(AiChatSession::getUserId, userId)
- .eq(AiChatSession::getTenantId, tenantId)
- .eq(AiChatSession::getDeleted, 0)
- .last("limit 1"));
- if (session == null) {
- throw new CoolException("AI 浼氳瘽涓嶅瓨鍦ㄦ垨鏃犳潈鍒犻櫎");
- }
- Date now = new Date();
- AiChatSession updateSession = new AiChatSession()
- .setId(sessionId)
- .setDeleted(1)
- .setUpdateBy(userId)
- .setUpdateTime(now);
- aiChatSessionMapper.updateById(updateSession);
- List<AiChatMessage> messages = aiChatMessageMapper.selectList(new LambdaQueryWrapper<AiChatMessage>()
- .eq(AiChatMessage::getSessionId, sessionId)
- .eq(AiChatMessage::getDeleted, 0));
- for (AiChatMessage message : messages) {
- AiChatMessage updateMessage = new AiChatMessage()
- .setId(message.getId())
- .setDeleted(1);
- aiChatMessageMapper.updateById(updateMessage);
- }
- evictConversationCaches(tenantId, userId);
+ aiConversationCommandService.removeSession(userId, tenantId, sessionId);
}
- /** 鏇存柊浼氳瘽鏍囬骞惰繑鍥炴渶鏂颁細璇濇憳瑕併�� */
@Override
public AiChatSessionDto renameSession(Long userId, Long tenantId, Long sessionId, AiChatSessionRenameRequest request) {
- ensureIdentity(userId, tenantId);
- if (request == null || !StringUtils.hasText(request.getTitle())) {
- throw new CoolException("浼氳瘽鏍囬涓嶈兘涓虹┖");
- }
- AiChatSession session = requireOwnedSession(sessionId, userId, tenantId);
- Date now = new Date();
- AiChatSession update = new AiChatSession()
- .setId(sessionId)
- .setTitle(buildSessionTitle(request.getTitle()))
- .setUpdateBy(userId)
- .setUpdateTime(now);
- aiChatSessionMapper.updateById(update);
- AiChatSessionDto sessionDto = buildSessionDto(requireOwnedSession(sessionId, userId, tenantId));
- evictConversationCaches(tenantId, userId);
- return sessionDto;
+ return aiConversationCommandService.renameSession(userId, tenantId, sessionId, request);
}
- /** 鏇存柊浼氳瘽缃《鐘舵�併�� */
@Override
public AiChatSessionDto pinSession(Long userId, Long tenantId, Long sessionId, AiChatSessionPinRequest request) {
- ensureIdentity(userId, tenantId);
- if (request == null || request.getPinned() == null) {
- throw new CoolException("缃《鐘舵�佷笉鑳戒负绌�");
- }
- AiChatSession session = requireOwnedSession(sessionId, userId, tenantId);
- Date now = new Date();
- AiChatSession update = new AiChatSession()
- .setId(sessionId)
- .setPinned(Boolean.TRUE.equals(request.getPinned()) ? 1 : 0)
- .setUpdateBy(userId)
- .setUpdateTime(now);
- aiChatSessionMapper.updateById(update);
- AiChatSessionDto sessionDto = buildSessionDto(requireOwnedSession(sessionId, userId, tenantId));
- evictConversationCaches(tenantId, userId);
- return sessionDto;
+ return aiConversationCommandService.pinSession(userId, tenantId, sessionId, request);
}
- /** 娓呯┖鏌愪釜浼氳瘽鐨勫叏閮ㄦ秷鎭拰娲剧敓璁板繂瀛楁銆� */
@Override
public void clearSessionMemory(Long userId, Long tenantId, Long sessionId) {
- ensureIdentity(userId, tenantId);
- AiChatSession session = requireOwnedSession(sessionId, userId, tenantId);
- List<AiChatMessage> messages = aiChatMessageMapper.selectList(new LambdaQueryWrapper<AiChatMessage>()
- .eq(AiChatMessage::getSessionId, sessionId)
- .eq(AiChatMessage::getDeleted, 0));
- for (AiChatMessage message : messages) {
- aiChatMessageMapper.updateById(new AiChatMessage()
- .setId(message.getId())
- .setDeleted(1));
- }
- aiChatSessionMapper.updateById(new AiChatSession()
- .setId(sessionId)
- .setMemorySummary(null)
- .setMemoryFacts(null)
- .setUpdateBy(userId)
- .setUpdateTime(new Date())
- .setLastMessageTime(session.getCreateTime()));
- evictConversationCaches(tenantId, userId);
+ aiConversationCommandService.clearSessionMemory(userId, tenantId, sessionId);
}
- /** 鍙繚鐣欐渶杩戜竴杞棶绛旓紝鐢ㄤ簬鎵嬪姩瑁佸壀闀夸細璇濄�� */
@Override
public void retainLatestRound(Long userId, Long tenantId, Long sessionId) {
- ensureIdentity(userId, tenantId);
- requireOwnedSession(sessionId, userId, tenantId);
- List<AiChatMessage> records = listMessageRecords(sessionId);
- if (records.isEmpty()) {
- return;
- }
- List<AiChatMessage> retained = tailMessageRecordsByRounds(records, 1);
- for (AiChatMessage message : records) {
- boolean shouldKeep = retained.stream().anyMatch(item -> item.getId().equals(message.getId()));
- if (!shouldKeep) {
- aiChatMessageMapper.updateById(new AiChatMessage()
- .setId(message.getId())
- .setDeleted(1));
- }
- }
- evictConversationCaches(tenantId, userId);
- scheduleMemoryProfileRefresh(sessionId, userId, tenantId);
- }
-
- private void evictConversationCaches(Long tenantId, Long userId) {
- // 浼氳瘽鏍囬銆佹憳瑕併�佹渶杩戞秷鎭拰 runtime 閮戒細浜掔浉褰卞搷锛岀粺涓�鎸夌敤鎴风淮搴︿竴璧峰け鏁堟洿绋冲Ε銆�
- aiRedisSupport.evictUserConversationCaches(tenantId, userId);
- }
-
- private void scheduleMemoryProfileRefresh(Long sessionId, Long userId, Long tenantId) {
- if (sessionId == null) {
- return;
- }
- if (!refreshingSessionIds.add(sessionId)) {
- pendingRefreshSessionIds.add(sessionId);
- return;
- }
- aiMemoryTaskExecutor.execute(() -> runMemoryProfileRefreshLoop(sessionId, userId, tenantId));
- }
-
- private void runMemoryProfileRefreshLoop(Long sessionId, Long userId, Long tenantId) {
- try {
- boolean shouldContinue;
- do {
- pendingRefreshSessionIds.remove(sessionId);
- try {
- refreshMemoryProfile(sessionId, userId);
- evictConversationCaches(tenantId, userId);
- } catch (Exception e) {
- log.warn("AI memory profile refresh failed, sessionId={}, userId={}, tenantId={}, message={}",
- sessionId, userId, tenantId, e.getMessage(), e);
- }
- shouldContinue = pendingRefreshSessionIds.remove(sessionId);
- } while (shouldContinue);
- } finally {
- refreshingSessionIds.remove(sessionId);
- if (pendingRefreshSessionIds.remove(sessionId) && refreshingSessionIds.add(sessionId)) {
- aiMemoryTaskExecutor.execute(() -> runMemoryProfileRefreshLoop(sessionId, userId, tenantId));
- }
- }
- }
-
- private AiChatSession findLatestSession(Long userId, Long tenantId, String promptCode) {
- return aiChatSessionMapper.selectOne(new LambdaQueryWrapper<AiChatSession>()
- .eq(AiChatSession::getUserId, userId)
- .eq(AiChatSession::getTenantId, tenantId)
- .eq(AiChatSession::getPromptCode, promptCode)
- .eq(AiChatSession::getDeleted, 0)
- .eq(AiChatSession::getStatus, StatusType.ENABLE.val)
- .orderByDesc(AiChatSession::getLastMessageTime)
- .orderByDesc(AiChatSession::getId)
- .last("limit 1"));
- }
-
- private AiChatSession getSession(Long sessionId, Long userId, Long tenantId, String promptCode) {
- AiChatSession session = aiChatSessionMapper.selectOne(new LambdaQueryWrapper<AiChatSession>()
- .eq(AiChatSession::getId, sessionId)
- .eq(AiChatSession::getUserId, userId)
- .eq(AiChatSession::getTenantId, tenantId)
- .eq(AiChatSession::getPromptCode, promptCode)
- .eq(AiChatSession::getDeleted, 0)
- .eq(AiChatSession::getStatus, StatusType.ENABLE.val)
- .last("limit 1"));
- if (session == null) {
- throw new CoolException("AI 浼氳瘽涓嶅瓨鍦ㄦ垨鏃犳潈璁块棶");
- }
- return session;
- }
-
- private AiChatSession requireOwnedSession(Long sessionId, Long userId, Long tenantId) {
- if (sessionId == null) {
- throw new CoolException("AI 浼氳瘽 ID 涓嶈兘涓虹┖");
- }
- AiChatSession session = aiChatSessionMapper.selectOne(new LambdaQueryWrapper<AiChatSession>()
- .eq(AiChatSession::getId, sessionId)
- .eq(AiChatSession::getUserId, userId)
- .eq(AiChatSession::getTenantId, tenantId)
- .eq(AiChatSession::getDeleted, 0)
- .eq(AiChatSession::getStatus, StatusType.ENABLE.val)
- .last("limit 1"));
- if (session == null) {
- throw new CoolException("AI 浼氳瘽涓嶅瓨鍦ㄦ垨鏃犳潈璁块棶");
- }
- return session;
- }
-
- private List<AiChatMessageDto> listMessages(Long sessionId) {
- List<AiChatMessage> records = listMessageRecords(sessionId);
- if (Cools.isEmpty(records)) {
- return List.of();
- }
- List<AiChatMessageDto> messages = new ArrayList<>();
- for (AiChatMessage record : records) {
- if (!StringUtils.hasText(record.getContent())) {
- continue;
- }
- AiChatMessageDto item = new AiChatMessageDto();
- item.setRole(record.getRole());
- item.setContent(record.getContent());
- messages.add(item);
- }
- return messages;
- }
-
- private List<AiChatMessageDto> normalizeMessages(List<AiChatMessageDto> memoryMessages) {
- /** 娓呮礂鍓嶇涓婁紶鐨勫唴瀛樻秷鎭紝鍙厑璁� user/assistant 涓ょ被瑙掕壊钀藉簱銆� */
- List<AiChatMessageDto> normalized = new ArrayList<>();
- if (Cools.isEmpty(memoryMessages)) {
- return normalized;
- }
- for (AiChatMessageDto item : memoryMessages) {
- if (item == null || !StringUtils.hasText(item.getContent())) {
- continue;
- }
- String role = item.getRole() == null ? "user" : item.getRole().toLowerCase();
- if ("system".equals(role)) {
- continue;
- }
- AiChatMessageDto normalizedItem = new AiChatMessageDto();
- normalizedItem.setRole("assistant".equals(role) ? "assistant" : "user");
- normalizedItem.setContent(item.getContent().trim());
- normalized.add(normalizedItem);
- }
- return normalized;
- }
-
- private List<AiChatMessage> listMessageRecords(Long sessionId) {
- return aiChatMessageMapper.selectList(new LambdaQueryWrapper<AiChatMessage>()
- .eq(AiChatMessage::getSessionId, sessionId)
- .eq(AiChatMessage::getDeleted, 0)
- .orderByAsc(AiChatMessage::getSeqNo)
- .orderByAsc(AiChatMessage::getId));
- }
-
- private int findNextSeqNo(Long sessionId) {
- AiChatMessage lastMessage = aiChatMessageMapper.selectOne(new LambdaQueryWrapper<AiChatMessage>()
- .eq(AiChatMessage::getSessionId, sessionId)
- .eq(AiChatMessage::getDeleted, 0)
- .orderByDesc(AiChatMessage::getSeqNo)
- .orderByDesc(AiChatMessage::getId)
- .last("limit 1"));
- return lastMessage == null || lastMessage.getSeqNo() == null ? 1 : lastMessage.getSeqNo() + 1;
- }
-
- private AiChatMessage buildMessageEntity(Long sessionId, int seqNo, String role, String content, Long userId, Long tenantId, Date createTime) {
- return new AiChatMessage()
- .setSessionId(sessionId)
- .setSeqNo(seqNo)
- .setRole(role)
- .setContent(content)
- .setContentLength(content == null ? 0 : content.length())
- .setUserId(userId)
- .setTenantId(tenantId)
- .setDeleted(0)
- .setCreateBy(userId)
- .setCreateTime(createTime);
- }
-
- private String resolveUpdatedTitle(String currentTitle, List<AiChatMessageDto> memoryMessages) {
- if (StringUtils.hasText(currentTitle)) {
- return currentTitle;
- }
- for (AiChatMessageDto item : memoryMessages) {
- if ("user".equals(item.getRole()) && StringUtils.hasText(item.getContent())) {
- return buildSessionTitle(item.getContent());
- }
- }
- return null;
- }
-
- private String buildSessionTitle(String titleSeed) {
- /**
- * 鎶婇杞敤鎴烽棶棰樺帇缂╂垚閫傚悎浣滀负浼氳瘽鏍囬鐨勭煭鎽樿銆�
- * 杩欓噷浼氬幓鎺夋崲琛屻�佽繛缁┖鐧斤紝骞朵紭鍏堝湪鑷劧璇箟鏂偣澶勬埅鏂��
- */
- if (!StringUtils.hasText(titleSeed)) {
- throw new CoolException("AI 浼氳瘽鏍囬涓嶈兘涓虹┖");
- }
- String title = titleSeed.trim()
- .replace("\r", " ")
- .replace("\n", " ")
- .replaceAll("\\s+", " ");
- int punctuationIndex = findSummaryBreakIndex(title);
- if (punctuationIndex > 0) {
- title = title.substring(0, punctuationIndex).trim();
- }
- return title.length() > 48 ? title.substring(0, 48) : title;
- }
-
- private int findSummaryBreakIndex(String title) {
- String[] separators = {"銆�", "锛�", "锛�", ".", "!", "?"};
- int result = -1;
- for (String separator : separators) {
- int index = title.indexOf(separator);
- if (index > 0 && (result < 0 || index < result)) {
- result = index;
- }
- }
- return result;
- }
-
- private AiChatSessionDto buildSessionDto(AiChatSession session) {
- AiChatMessage lastMessage = aiChatMessageMapper.selectOne(new LambdaQueryWrapper<AiChatMessage>()
- .eq(AiChatMessage::getSessionId, session.getId())
- .eq(AiChatMessage::getDeleted, 0)
- .orderByDesc(AiChatMessage::getSeqNo)
- .orderByDesc(AiChatMessage::getId)
- .last("limit 1"));
- return AiChatSessionDto.builder()
- .sessionId(session.getId())
- .title(session.getTitle())
- .promptCode(session.getPromptCode())
- .pinned(session.getPinned() != null && session.getPinned() == 1)
- .lastMessagePreview(buildLastMessagePreview(lastMessage))
- .lastMessageTime(session.getLastMessageTime())
- .build();
- }
-
- private String buildLastMessagePreview(AiChatMessage message) {
- if (message == null || !StringUtils.hasText(message.getContent())) {
- return null;
- }
- String preview = message.getContent().trim()
- .replace("\r", " ")
- .replace("\n", " ")
- .replaceAll("\\s+", " ");
- String prefix = "assistant".equalsIgnoreCase(message.getRole()) ? "AI: " : "浣�: ";
- String normalized = prefix + preview;
- return normalized.length() > 80 ? normalized.substring(0, 80) : normalized;
- }
-
- private void refreshMemoryProfile(Long sessionId, Long userId) {
- /**
- * 閲嶆柊璁$畻浼氳瘽鐨勬憳瑕佽蹇嗗拰鍏抽敭浜嬪疄銆�
- * 杩欐槸鈥滄寔涔呭寲娑堟伅鈥濆拰鈥滄ā鍨嬩笂涓嬫枃娌荤悊鈥濅箣闂寸殑妗ユ鏂规硶銆�
- * 鐜板湪瀹冭繍琛屽湪鍚庡彴绾跨▼閲岋紝鍥犳鍏佽鐭椂闂存渶缁堜竴鑷达紝鑰屼笉鏄己鍒舵湰杞悓姝ュ畬鎴愩��
- */
- List<AiChatMessageDto> messages = listMessages(sessionId);
- List<AiChatMessageDto> shortMemoryMessages = tailMessagesByRounds(messages, AiDefaults.MEMORY_RECENT_ROUNDS);
- List<AiChatMessageDto> historyMessages = messages.size() > shortMemoryMessages.size()
- ? messages.subList(0, messages.size() - shortMemoryMessages.size())
- : List.of();
- String memorySummary = historyMessages.size() >= AiDefaults.MEMORY_SUMMARY_TRIGGER_MESSAGES
- ? buildMemorySummary(historyMessages)
- : null;
- String memoryFacts = buildMemoryFacts(messages);
- AiChatMessage lastMessage = aiChatMessageMapper.selectOne(new LambdaQueryWrapper<AiChatMessage>()
- .eq(AiChatMessage::getSessionId, sessionId)
- .eq(AiChatMessage::getDeleted, 0)
- .orderByDesc(AiChatMessage::getSeqNo)
- .orderByDesc(AiChatMessage::getId)
- .last("limit 1"));
- aiChatSessionMapper.updateById(new AiChatSession()
- .setId(sessionId)
- .setMemorySummary(memorySummary)
- .setMemoryFacts(memoryFacts)
- .setLastMessageTime(lastMessage == null ? null : lastMessage.getCreateTime())
- .setUpdateBy(userId)
- .setUpdateTime(new Date()));
- }
-
- private List<AiChatMessageDto> tailMessagesByRounds(List<AiChatMessageDto> source, int rounds) {
- /** 鎸夆�滅敤鎴峰彂瑷�杞鈥濊鍓渶杩戞秷鎭紝鑰屼笉鏄畝鍗曟寜鏉℃暟鎴柇銆� */
- if (Cools.isEmpty(source) || rounds <= 0) {
- return List.of();
- }
- int userCount = 0;
- int startIndex = source.size();
- for (int i = source.size() - 1; i >= 0; i--) {
- AiChatMessageDto item = source.get(i);
- startIndex = i;
- if (item != null && "user".equalsIgnoreCase(item.getRole())) {
- userCount++;
- if (userCount >= rounds) {
- break;
- }
- }
- }
- return new ArrayList<>(source.subList(Math.max(0, startIndex), source.size()));
- }
-
- private List<AiChatMessage> tailMessageRecordsByRounds(List<AiChatMessage> source, int rounds) {
- if (Cools.isEmpty(source) || rounds <= 0) {
- return List.of();
- }
- int userCount = 0;
- int startIndex = source.size();
- for (int i = source.size() - 1; i >= 0; i--) {
- AiChatMessage item = source.get(i);
- startIndex = i;
- if (item != null && "user".equalsIgnoreCase(item.getRole())) {
- userCount++;
- if (userCount >= rounds) {
- break;
- }
- }
- }
- return new ArrayList<>(source.subList(Math.max(0, startIndex), source.size()));
- }
-
- private String buildMemorySummary(List<AiChatMessageDto> historyMessages) {
- /** 涓鸿緝鏃╁巻鍙茬敓鎴愬彲鐩存帴鎻掑叆绯荤粺娑堟伅鐨勬枃鏈憳瑕併�� */
- StringBuilder builder = new StringBuilder("杈冩棭瀵硅瘽鎽樿:\n");
- for (AiChatMessageDto item : historyMessages) {
- if (item == null || !StringUtils.hasText(item.getContent())) {
- continue;
- }
- String prefix = "assistant".equalsIgnoreCase(item.getRole()) ? "- AI: " : "- 鐢ㄦ埛: ";
- String content = compactText(item.getContent(), 120);
- if (!StringUtils.hasText(content)) {
- continue;
- }
- builder.append(prefix).append(content).append("\n");
- if (builder.length() >= AiDefaults.MEMORY_SUMMARY_MAX_LENGTH) {
- break;
- }
- }
- return compactText(builder.toString(), AiDefaults.MEMORY_SUMMARY_MAX_LENGTH);
- }
-
- private String buildMemoryFacts(List<AiChatMessageDto> messages) {
- /** 浠庢渶杩戠敤鎴峰叧娉ㄧ偣涓彁鐐煎叧閿簨瀹烇紝浣滀负杞婚噺鎸佷箙璁板繂銆� */
- if (Cools.isEmpty(messages)) {
- return null;
- }
- StringBuilder builder = new StringBuilder("鍏抽敭浜嬪疄:\n");
- int userFacts = 0;
- for (int i = messages.size() - 1; i >= 0 && userFacts < 4; i--) {
- AiChatMessageDto item = messages.get(i);
- if (item == null || !"user".equalsIgnoreCase(item.getRole()) || !StringUtils.hasText(item.getContent())) {
- continue;
- }
- builder.append("- 鐢ㄦ埛鍏虫敞: ").append(compactText(item.getContent(), 100)).append("\n");
- userFacts++;
- }
- return userFacts == 0 ? null : compactText(builder.toString(), AiDefaults.MEMORY_FACTS_MAX_LENGTH);
- }
-
- private String compactText(String content, int maxLength) {
- if (!StringUtils.hasText(content)) {
- return null;
- }
- String normalized = content.trim()
- .replace("\r", " ")
- .replace("\n", " ")
- .replaceAll("\\s+", " ");
- return normalized.length() > maxLength ? normalized.substring(0, maxLength) : normalized;
- }
-
- private void ensureIdentity(Long userId, Long tenantId) {
- if (userId == null) {
- throw new CoolException("褰撳墠鐧诲綍鐢ㄦ埛涓嶅瓨鍦�");
- }
- if (tenantId == null) {
- throw new CoolException("褰撳墠绉熸埛涓嶅瓨鍦�");
- }
- }
-
- private String requirePromptCode(String promptCode) {
- if (!StringUtils.hasText(promptCode)) {
- throw new CoolException("Prompt 缂栫爜涓嶈兘涓虹┖");
- }
- return promptCode;
+ aiConversationCommandService.retainLatestRound(userId, tenantId, sessionId);
}
}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
index 320421e..1d6da64 100644
--- a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
@@ -1,79 +1,31 @@
package com.vincent.rsf.server.ai.service.impl;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.vincent.rsf.framework.common.Cools;
-import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.config.AiDefaults;
-import com.vincent.rsf.server.ai.dto.AiChatDoneDto;
-import com.vincent.rsf.server.ai.dto.AiChatErrorDto;
import com.vincent.rsf.server.ai.dto.AiChatMemoryDto;
-import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
import com.vincent.rsf.server.ai.dto.AiChatModelOptionDto;
import com.vincent.rsf.server.ai.dto.AiChatRequest;
import com.vincent.rsf.server.ai.dto.AiChatRuntimeDto;
-import com.vincent.rsf.server.ai.dto.AiChatStatusDto;
import com.vincent.rsf.server.ai.dto.AiChatSessionDto;
import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest;
import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest;
-import com.vincent.rsf.server.ai.dto.AiChatThinkingEventDto;
-import com.vincent.rsf.server.ai.dto.AiChatToolEventDto;
import com.vincent.rsf.server.ai.dto.AiResolvedConfig;
-import com.vincent.rsf.server.ai.entity.AiCallLog;
-import com.vincent.rsf.server.ai.entity.AiParam;
-import com.vincent.rsf.server.ai.entity.AiPrompt;
-import com.vincent.rsf.server.ai.entity.AiChatSession;
-import com.vincent.rsf.server.ai.enums.AiErrorCategory;
-import com.vincent.rsf.server.ai.exception.AiChatException;
-import com.vincent.rsf.server.ai.service.AiCallLogService;
-import com.vincent.rsf.server.ai.service.AiChatService;
import com.vincent.rsf.server.ai.service.AiChatMemoryService;
+import com.vincent.rsf.server.ai.service.AiChatService;
import com.vincent.rsf.server.ai.service.AiConfigResolverService;
import com.vincent.rsf.server.ai.service.AiParamService;
-import com.vincent.rsf.server.ai.service.MountedToolCallback;
-import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
-import io.micrometer.observation.ObservationRegistry;
+import com.vincent.rsf.server.ai.service.impl.chat.AiChatOrchestrator;
+import com.vincent.rsf.server.ai.service.impl.chat.AiChatRuntimeAssembler;
+import com.vincent.rsf.server.ai.store.AiConversationCacheStore;
import lombok.RequiredArgsConstructor;
-import lombok.extern.slf4j.Slf4j;
-import org.springframework.ai.chat.messages.AssistantMessage;
-import org.springframework.ai.chat.messages.Message;
-import org.springframework.ai.chat.messages.SystemMessage;
-import org.springframework.ai.chat.messages.UserMessage;
-import org.springframework.ai.chat.metadata.ChatResponseMetadata;
-import org.springframework.ai.chat.metadata.Usage;
-import org.springframework.ai.chat.model.ChatResponse;
-import org.springframework.ai.chat.model.ToolContext;
-import org.springframework.ai.chat.prompt.Prompt;
-import org.springframework.ai.model.tool.DefaultToolCallingManager;
-import org.springframework.ai.model.tool.ToolCallingManager;
-import org.springframework.ai.openai.OpenAiChatModel;
-import org.springframework.ai.openai.OpenAiChatOptions;
-import org.springframework.ai.openai.api.OpenAiApi;
-import org.springframework.ai.tool.ToolCallback;
-import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
-import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver;
-import org.springframework.ai.util.json.schema.SchemaType;
-import org.springframework.context.support.GenericApplicationContext;
-import org.springframework.http.MediaType;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
-import org.springframework.util.StringUtils;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
-import reactor.core.publisher.Flux;
-import java.io.IOException;
-import java.time.Instant;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.LinkedHashMap;
import java.util.List;
-import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
-import java.util.concurrent.atomic.AtomicReference;
-import java.util.concurrent.atomic.AtomicLong;
-@Slf4j
@Service
@RequiredArgsConstructor
public class AiChatServiceImpl implements AiChatService {
@@ -81,31 +33,22 @@
private final AiConfigResolverService aiConfigResolverService;
private final AiChatMemoryService aiChatMemoryService;
private final AiParamService aiParamService;
- private final McpMountRuntimeFactory mcpMountRuntimeFactory;
- private final AiCallLogService aiCallLogService;
- private final AiRedisSupport aiRedisSupport;
- private final GenericApplicationContext applicationContext;
- private final ObservationRegistry observationRegistry;
- private final ObjectMapper objectMapper;
+ private final AiConversationCacheStore aiConversationCacheStore;
+ private final AiChatRuntimeAssembler aiChatRuntimeAssembler;
+ private final AiChatOrchestrator aiChatOrchestrator;
@Qualifier("aiChatTaskExecutor")
private final Executor aiChatTaskExecutor;
- /**
- * 鑾峰彇褰撳墠瀵硅瘽鎶藉眽鍒濆鍖栨墍闇�鐨勮繍琛屾椂鏁版嵁銆�
- * 璇ユ柟娉曚笉浼氳Е鍙戞ā鍨嬭皟鐢紝鑰屾槸鎶婇厤缃В鏋愮粨鏋滃拰浼氳瘽璁板繂鑱氬悎鎴愬墠绔竴娆℃覆鏌撴墍闇�鐨勫揩鐓с��
- */
@Override
public AiChatRuntimeDto getRuntime(String promptCode, Long sessionId, Long aiParamId, Long userId, Long tenantId) {
AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId, aiParamId);
- Long runtimeCacheAiParamId = aiParamId;
- // runtime 鏄厤缃揩鐓у拰浼氳瘽璁板繂鐨勮仛鍚堣鍥撅紝鍗曠嫭缂撳瓨鑳藉噺灏戜竴娆¢〉闈㈣繘鍏ユ椂鐨勯噸澶嶆嫾瑁呫��
- AiChatRuntimeDto cached = aiRedisSupport.getRuntime(tenantId, userId, config.getPromptCode(), sessionId, runtimeCacheAiParamId);
+ AiChatRuntimeDto cached = aiConversationCacheStore.getRuntime(tenantId, userId, config.getPromptCode(), sessionId, aiParamId);
if (cached != null) {
return cached;
}
AiChatMemoryDto memory = aiChatMemoryService.getMemory(userId, tenantId, config.getPromptCode(), sessionId);
List<AiChatModelOptionDto> modelOptions = aiParamService.listChatModelOptions(tenantId);
- AiChatRuntimeDto runtime = buildRuntimeSnapshot(
+ AiChatRuntimeDto runtime = aiChatRuntimeAssembler.buildRuntimeSnapshot(
null,
memory.getSessionId(),
config,
@@ -115,20 +58,24 @@
List.of(),
memory
);
- aiRedisSupport.cacheRuntime(tenantId, userId, config.getPromptCode(), sessionId, runtimeCacheAiParamId, runtime);
+ aiConversationCacheStore.cacheRuntime(tenantId, userId, config.getPromptCode(), sessionId, aiParamId, runtime);
if (memory.getSessionId() != null && !Objects.equals(memory.getSessionId(), sessionId)) {
- aiRedisSupport.cacheRuntime(tenantId, userId, config.getPromptCode(), memory.getSessionId(), runtimeCacheAiParamId, runtime);
+ aiConversationCacheStore.cacheRuntime(tenantId, userId, config.getPromptCode(), memory.getSessionId(), aiParamId, runtime);
}
return runtime;
}
- /**
- * 鏌ヨ鎸囧畾 Prompt 鍦烘櫙涓嬬殑鍘嗗彶浼氳瘽鎽樿鍒楄〃銆�
- */
@Override
public List<AiChatSessionDto> listSessions(String promptCode, String keyword, Long userId, Long tenantId) {
AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId);
return aiChatMemoryService.listSessions(userId, tenantId, config.getPromptCode(), keyword);
+ }
+
+ @Override
+ public SseEmitter stream(AiChatRequest request, Long userId, Long tenantId) {
+ SseEmitter emitter = new SseEmitter(AiDefaults.SSE_TIMEOUT_MS);
+ CompletableFuture.runAsync(() -> aiChatOrchestrator.executeStream(request, userId, tenantId, emitter), aiChatTaskExecutor);
+ return emitter;
}
@Override
@@ -154,911 +101,5 @@
@Override
public void retainLatestRound(Long sessionId, Long userId, Long tenantId) {
aiChatMemoryService.retainLatestRound(userId, tenantId, sessionId);
- }
-
- /**
- * 鍚姩涓�娆℃柊鐨� SSE 瀵硅瘽娴併��
- * 鎺у埗绾跨▼绔嬪嵆杩斿洖 emitter锛岀湡姝g殑妯″瀷璋冪敤涓庡伐鍏锋墽琛屼氦缁� AI 涓撶敤绾跨▼姹犲紓姝ュ鐞嗐��
- */
- @Override
- public SseEmitter stream(AiChatRequest request, Long userId, Long tenantId) {
- SseEmitter emitter = new SseEmitter(AiDefaults.SSE_TIMEOUT_MS);
- CompletableFuture.runAsync(() -> doStream(request, userId, tenantId, emitter), aiChatTaskExecutor);
- return emitter;
- }
-
- private void doStream(AiChatRequest request, Long userId, Long tenantId, SseEmitter emitter) {
- /**
- * AI 瀵硅瘽鐨勬牳蹇冩墽琛岄摼璺細
- * 1. 鏍¢獙韬唤鍜岃В鏋愮鎴烽厤缃�
- * 2. 瑙f瀽鎴栧垱寤轰細璇濓紝鍔犺浇璁板繂
- * 3. 鍔ㄦ�佹寕杞� MCP 宸ュ叿
- * 4. 鍙戣捣妯″瀷娴佸紡/闈炴祦寮忚皟鐢�
- * 5. 鎸佷箙鍖栨湰杞秷鎭紝杈撳嚭 SSE 浜嬩欢骞惰褰曞璁℃棩蹇�
- */
- String requestId = request.getRequestId();
- long startedAt = System.currentTimeMillis();
- AtomicReference<Long> firstTokenAtRef = new AtomicReference<>();
- AtomicLong toolCallSequence = new AtomicLong(0);
- AtomicLong toolSuccessCount = new AtomicLong(0);
- AtomicLong toolFailureCount = new AtomicLong(0);
- Long sessionId = request.getSessionId();
- Long callLogId = null;
- String model = null;
- String resolvedPromptCode = request.getPromptCode();
- ThinkingTraceEmitter thinkingTraceEmitter = null;
- try {
- ensureIdentity(userId, tenantId);
- AiResolvedConfig config = resolveConfig(request, tenantId);
- List<AiChatModelOptionDto> modelOptions = aiParamService.listChatModelOptions(tenantId);
- resolvedPromptCode = config.getPromptCode();
- if (!aiRedisSupport.allowChatRequest(tenantId, userId, config.getPromptCode())) {
- throw buildAiException("AI_RATE_LIMITED", AiErrorCategory.REQUEST, "RATE_LIMIT",
- "褰撳墠鎻愰棶杩囦簬棰戠箒锛岃绋嶅悗鍐嶈瘯", null);
- }
- final String resolvedModel = config.getAiParam().getModel();
- model = resolvedModel;
- AiChatSession session = resolveSession(request, userId, tenantId, config.getPromptCode());
- sessionId = session.getId();
- // 娴佺姸鎬佽惤 Redis 鐨勭洰鏍囨槸缁欏瀹炰緥鍜屽悗缁繍缁存煡璇㈢暀缁熶竴鍏ュ彛锛屼笉鏇夸唬鏁版嵁搴撴棩蹇椼��
- aiRedisSupport.markStreamState(requestId, tenantId, userId, sessionId, config.getPromptCode(), "RUNNING", null);
- AiChatMemoryDto memory = loadMemory(userId, tenantId, config.getPromptCode(), session.getId());
- List<AiChatMessageDto> mergedMessages = mergeMessages(memory.getShortMemoryMessages(), request.getMessages());
- AiCallLog callLog = aiCallLogService.startCallLog(
- requestId,
- session.getId(),
- userId,
- tenantId,
- config.getPromptCode(),
- config.getPrompt().getName(),
- config.getAiParam().getModel(),
- config.getMcpMounts().size(),
- config.getMcpMounts().size(),
- config.getMcpMounts().stream().map(item -> item.getName()).toList()
- );
- callLogId = callLog.getId();
- try (McpMountRuntimeFactory.McpMountRuntime runtime = createRuntime(config, userId)) {
- emitStrict(emitter, "start", buildRuntimeSnapshot(
- requestId,
- session.getId(),
- config,
- modelOptions,
- runtime.getMountedCount(),
- runtime.getMountedNames(),
- runtime.getErrors(),
- memory
- ));
- emitSafely(emitter, "status", AiChatStatusDto.builder()
- .requestId(requestId)
- .sessionId(session.getId())
- .status("STARTED")
- .model(resolvedModel)
- .timestamp(Instant.now().toEpochMilli())
- .elapsedMs(0L)
- .build());
- log.info("AI chat started, requestId={}, userId={}, tenantId={}, sessionId={}, model={}",
- requestId, userId, tenantId, session.getId(), resolvedModel);
- thinkingTraceEmitter = new ThinkingTraceEmitter(emitter, requestId, session.getId());
- thinkingTraceEmitter.startAnalyze();
-
- ThinkingTraceEmitter activeThinkingTraceEmitter = thinkingTraceEmitter;
- ToolCallback[] observableToolCallbacks = wrapToolCallbacks(
- runtime.getToolCallbacks(), emitter, requestId, session.getId(), toolCallSequence,
- toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, activeThinkingTraceEmitter
- );
- Prompt prompt = new Prompt(
- buildPromptMessages(memory, mergedMessages, config.getPrompt(), request.getMetadata()),
- buildChatOptions(config.getAiParam(), observableToolCallbacks, userId, tenantId,
- requestId, session.getId(), request.getMetadata())
- );
- OpenAiChatModel chatModel = createChatModel(config.getAiParam());
- if (Boolean.FALSE.equals(config.getAiParam().getStreamingEnabled())) {
- ChatResponse response = invokeChatCall(chatModel, prompt);
- String content = extractContent(response);
- aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), content);
- if (StringUtils.hasText(content)) {
- markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt, activeThinkingTraceEmitter);
- emitStrict(emitter, "delta", buildMessagePayload("requestId", requestId, "content", content));
- }
- activeThinkingTraceEmitter.completeCurrentPhase();
- emitDone(emitter, requestId, response.getMetadata(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get());
- emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get()));
- aiCallLogService.completeCallLog(
- callLogId,
- "COMPLETED",
- System.currentTimeMillis() - startedAt,
- resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()),
- response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getPromptTokens(),
- response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getCompletionTokens(),
- response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getTotalTokens(),
- toolSuccessCount.get(),
- toolFailureCount.get()
- );
- aiRedisSupport.markStreamState(requestId, tenantId, userId, session.getId(), config.getPromptCode(), "COMPLETED", null);
- log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}",
- requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()));
- emitter.complete();
- return;
- }
-
- Flux<ChatResponse> responseFlux = invokeChatStream(chatModel, prompt);
- AtomicReference<ChatResponseMetadata> lastMetadata = new AtomicReference<>();
- StringBuilder assistantContent = new StringBuilder();
- try {
- responseFlux.doOnNext(response -> {
- lastMetadata.set(response.getMetadata());
- String content = extractContent(response);
- if (StringUtils.hasText(content)) {
- markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt, activeThinkingTraceEmitter);
- assistantContent.append(content);
- emitStrict(emitter, "delta", buildMessagePayload("requestId", requestId, "content", content));
- }
- })
- .blockLast();
- } catch (Exception e) {
- throw buildAiException("AI_MODEL_STREAM_ERROR", AiErrorCategory.MODEL, "MODEL_STREAM",
- e == null ? "AI 妯″瀷娴佸紡璋冪敤澶辫触" : e.getMessage(), e);
- }
- aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), assistantContent.toString());
- activeThinkingTraceEmitter.completeCurrentPhase();
- emitDone(emitter, requestId, lastMetadata.get(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get());
- emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get()));
- aiCallLogService.completeCallLog(
- callLogId,
- "COMPLETED",
- System.currentTimeMillis() - startedAt,
- resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()),
- lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getPromptTokens(),
- lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getCompletionTokens(),
- lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getTotalTokens(),
- toolSuccessCount.get(),
- toolFailureCount.get()
- );
- aiRedisSupport.markStreamState(requestId, tenantId, userId, session.getId(), config.getPromptCode(), "COMPLETED", null);
- log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}",
- requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()));
- emitter.complete();
- }
- } catch (AiChatException e) {
- handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e,
- callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter,
- tenantId, userId, resolvedPromptCode);
- } catch (Exception e) {
- handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(),
- buildAiException("AI_INTERNAL_ERROR", AiErrorCategory.INTERNAL, "INTERNAL",
- e == null ? "AI 瀵硅瘽澶辫触" : e.getMessage(), e),
- callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter,
- tenantId, userId, resolvedPromptCode);
- } finally {
- log.debug("AI chat stream finished, requestId={}", requestId);
- }
- }
-
- private void ensureIdentity(Long userId, Long tenantId) {
- if (userId == null) {
- throw buildAiException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "AUTH_VALIDATE", "褰撳墠鐧诲綍鐢ㄦ埛涓嶅瓨鍦�", null);
- }
- if (tenantId == null) {
- throw buildAiException("AI_AUTH_TENANT_MISSING", AiErrorCategory.AUTH, "AUTH_VALIDATE", "褰撳墠绉熸埛涓嶅瓨鍦�", null);
- }
- }
-
- private AiResolvedConfig resolveConfig(AiChatRequest request, Long tenantId) {
- /** 鎶婅姹傞噷鐨� Prompt 鍦烘櫙瑙f瀽鎴愪竴浠藉彲鐩存帴鎵ц鐨� AI 閰嶇疆銆� */
- try {
- return aiConfigResolverService.resolve(request.getPromptCode(), tenantId, request.getAiParamId());
- } catch (Exception e) {
- throw buildAiException("AI_CONFIG_RESOLVE_ERROR", AiErrorCategory.CONFIG, "CONFIG_RESOLVE",
- e == null ? "AI 閰嶇疆瑙f瀽澶辫触" : e.getMessage(), e);
- }
- }
-
- private AiChatRuntimeDto buildRuntimeSnapshot(String requestId, Long sessionId, AiResolvedConfig config,
- List<AiChatModelOptionDto> modelOptions, Integer mountedMcpCount,
- List<String> mountedMcpNames, List<String> mountErrors,
- AiChatMemoryDto memory) {
- return AiChatRuntimeDto.builder()
- .requestId(requestId)
- .sessionId(sessionId)
- .aiParamId(config.getAiParam().getId())
- .promptCode(config.getPromptCode())
- .promptName(config.getPrompt().getName())
- .model(config.getAiParam().getModel())
- .modelOptions(modelOptions)
- .configuredMcpCount(config.getMcpMounts().size())
- .mountedMcpCount(mountedMcpCount)
- .mountedMcpNames(mountedMcpNames)
- .mountErrors(mountErrors)
- .memorySummary(memory.getMemorySummary())
- .memoryFacts(memory.getMemoryFacts())
- .recentMessageCount(memory.getRecentMessageCount())
- .persistedMessages(memory.getPersistedMessages())
- .build();
- }
-
- private AiChatSession resolveSession(AiChatRequest request, Long userId, Long tenantId, String promptCode) {
- /** 鏍规嵁 sessionId 澶嶇敤鍘嗗彶浼氳瘽锛屾垨鍦ㄩ娆℃彁闂椂鍒涘缓鏂颁細璇濄�� */
- try {
- return aiChatMemoryService.resolveSession(userId, tenantId, promptCode, request.getSessionId(), resolveTitleSeed(request.getMessages()));
- } catch (Exception e) {
- throw buildAiException("AI_SESSION_RESOLVE_ERROR", AiErrorCategory.REQUEST, "SESSION_RESOLVE",
- e == null ? "AI 浼氳瘽瑙f瀽澶辫触" : e.getMessage(), e);
- }
- }
-
- private AiChatMemoryDto loadMemory(Long userId, Long tenantId, String promptCode, Long sessionId) {
- /** 璇诲彇浼氳瘽鐨勭煭鏈熻蹇嗐�佹憳瑕佽蹇嗗拰浜嬪疄璁板繂锛屼緵妯″瀷缁勮涓婁笅鏂囥�� */
- try {
- return aiChatMemoryService.getMemory(userId, tenantId, promptCode, sessionId);
- } catch (Exception e) {
- throw buildAiException("AI_MEMORY_LOAD_ERROR", AiErrorCategory.REQUEST, "MEMORY_LOAD",
- e == null ? "AI 浼氳瘽璁板繂鍔犺浇澶辫触" : e.getMessage(), e);
- }
- }
-
- private McpMountRuntimeFactory.McpMountRuntime createRuntime(AiResolvedConfig config, Long userId) {
- /** 鎸夐厤缃腑鐨� MCP 鎸傝浇璁板綍鏋勯�犳湰杞璇濅笓灞炵殑宸ュ叿杩愯鏃躲�� */
- try {
- return mcpMountRuntimeFactory.create(config.getMcpMounts(), userId);
- } catch (Exception e) {
- throw buildAiException("AI_MCP_MOUNT_ERROR", AiErrorCategory.MCP, "MCP_MOUNT",
- e == null ? "MCP 鎸傝浇澶辫触" : e.getMessage(), e);
- }
- }
-
- private ChatResponse invokeChatCall(OpenAiChatModel chatModel, Prompt prompt) {
- try {
- return chatModel.call(prompt);
- } catch (Exception e) {
- throw buildAiException("AI_MODEL_CALL_ERROR", AiErrorCategory.MODEL, "MODEL_CALL",
- e == null ? "AI 妯″瀷璋冪敤澶辫触" : e.getMessage(), e);
- }
- }
-
- private Flux<ChatResponse> invokeChatStream(OpenAiChatModel chatModel, Prompt prompt) {
- try {
- return chatModel.stream(prompt);
- } catch (Exception e) {
- throw buildAiException("AI_MODEL_STREAM_ERROR", AiErrorCategory.MODEL, "MODEL_STREAM_INIT",
- e == null ? "AI 妯″瀷娴佸紡璋冪敤澶辫触" : e.getMessage(), e);
- }
- }
-
- private void markFirstToken(AtomicReference<Long> firstTokenAtRef, SseEmitter emitter, String requestId,
- Long sessionId, String model, long startedAt, ThinkingTraceEmitter thinkingTraceEmitter) {
- if (!firstTokenAtRef.compareAndSet(null, System.currentTimeMillis())) {
- return;
- }
- if (thinkingTraceEmitter != null) {
- thinkingTraceEmitter.startAnswer();
- }
- emitSafely(emitter, "status", AiChatStatusDto.builder()
- .requestId(requestId)
- .sessionId(sessionId)
- .status("FIRST_TOKEN")
- .model(model)
- .timestamp(Instant.now().toEpochMilli())
- .elapsedMs(System.currentTimeMillis() - startedAt)
- .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()))
- .build());
- }
-
- private AiChatStatusDto buildTerminalStatus(String requestId, Long sessionId, String status, String model, long startedAt, Long firstTokenAt) {
- return AiChatStatusDto.builder()
- .requestId(requestId)
- .sessionId(sessionId)
- .status(status)
- .model(model)
- .timestamp(Instant.now().toEpochMilli())
- .elapsedMs(System.currentTimeMillis() - startedAt)
- .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAt))
- .build();
- }
-
- private Long resolveFirstTokenLatency(long startedAt, Long firstTokenAt) {
- return firstTokenAt == null ? null : Math.max(0L, firstTokenAt - startedAt);
- }
-
- private void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt,
- Long firstTokenAt, AiChatException exception, Long callLogId,
- long toolSuccessCount, long toolFailureCount,
- ThinkingTraceEmitter thinkingTraceEmitter,
- Long tenantId, Long userId, String promptCode) {
- if (isClientAbortException(exception)) {
- log.warn("AI chat aborted by client, requestId={}, sessionId={}, stage={}, message={}",
- requestId, sessionId, exception.getStage(), exception.getMessage());
- if (thinkingTraceEmitter != null) {
- thinkingTraceEmitter.markTerminated("ABORTED");
- }
- emitSafely(emitter, "status", buildTerminalStatus(requestId, sessionId, "ABORTED", model, startedAt, firstTokenAt));
- aiCallLogService.failCallLog(
- callLogId,
- "ABORTED",
- exception.getCategory().name(),
- exception.getStage(),
- exception.getMessage(),
- System.currentTimeMillis() - startedAt,
- resolveFirstTokenLatency(startedAt, firstTokenAt),
- toolSuccessCount,
- toolFailureCount
- );
- aiRedisSupport.markStreamState(requestId, tenantId, userId, sessionId, promptCode, "ABORTED", exception.getMessage());
- emitter.completeWithError(exception);
- return;
- }
- log.error("AI chat failed, requestId={}, sessionId={}, category={}, stage={}, message={}",
- requestId, sessionId, exception.getCategory(), exception.getStage(), exception.getMessage(), exception);
- if (thinkingTraceEmitter != null) {
- thinkingTraceEmitter.markTerminated("FAILED");
- }
- emitSafely(emitter, "status", buildTerminalStatus(requestId, sessionId, "FAILED", model, startedAt, firstTokenAt));
- emitSafely(emitter, "error", AiChatErrorDto.builder()
- .requestId(requestId)
- .sessionId(sessionId)
- .code(exception.getCode())
- .category(exception.getCategory().name())
- .stage(exception.getStage())
- .message(exception.getMessage())
- .timestamp(Instant.now().toEpochMilli())
- .build());
- aiCallLogService.failCallLog(
- callLogId,
- "FAILED",
- exception.getCategory().name(),
- exception.getStage(),
- exception.getMessage(),
- System.currentTimeMillis() - startedAt,
- resolveFirstTokenLatency(startedAt, firstTokenAt),
- toolSuccessCount,
- toolFailureCount
- );
- aiRedisSupport.markStreamState(requestId, tenantId, userId, sessionId, promptCode, "FAILED", exception.getMessage());
- emitter.completeWithError(exception);
- }
-
- private OpenAiChatModel createChatModel(AiParam aiParam) {
- OpenAiApi openAiApi = buildOpenAiApi(aiParam);
- ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder()
- .observationRegistry(observationRegistry)
- .toolCallbackResolver(new SpringBeanToolCallbackResolver(applicationContext, SchemaType.OPEN_API_SCHEMA))
- .toolExecutionExceptionProcessor(new DefaultToolExecutionExceptionProcessor(false))
- .build();
- return new OpenAiChatModel(
- openAiApi,
- OpenAiChatOptions.builder()
- .model(aiParam.getModel())
- .temperature(aiParam.getTemperature())
- .topP(aiParam.getTopP())
- .maxTokens(aiParam.getMaxTokens())
- .streamUsage(true)
- .build(),
- toolCallingManager,
- org.springframework.retry.support.RetryTemplate.builder().maxAttempts(1).build(),
- observationRegistry
- );
- }
-
- private OpenAiApi buildOpenAiApi(AiParam aiParam) {
- return AiOpenAiApiSupport.buildOpenAiApi(aiParam);
- }
-
- private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Long tenantId,
- String requestId, Long sessionId, Map<String, Object> metadata) {
- /**
- * 缁勮涓�娆¤亰澶╄皟鐢ㄧ殑鍏ㄩ儴妯″瀷閫夐」鍜� Tool Context銆�
- * Tool Context 浼氶�忎紶缁欏唴缃伐鍏峰拰澶栭儴 MCP锛屼繚璇佸伐鍏峰湪绉熸埛鍜屼細璇濊寖鍥村唴鎵ц銆�
- */
- if (userId == null) {
- throw buildAiException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "OPTIONS_BUILD", "褰撳墠鐧诲綍鐢ㄦ埛涓嶅瓨鍦�", null);
- }
- OpenAiChatOptions.Builder builder = OpenAiChatOptions.builder()
- .model(aiParam.getModel())
- .temperature(aiParam.getTemperature())
- .topP(aiParam.getTopP())
- .maxTokens(aiParam.getMaxTokens())
- .streamUsage(true)
- .user(String.valueOf(userId));
- if (!Cools.isEmpty(toolCallbacks)) {
- builder.toolCallbacks(Arrays.stream(toolCallbacks).toList());
- }
- Map<String, Object> toolContext = new LinkedHashMap<>();
- toolContext.put("userId", userId);
- toolContext.put("tenantId", tenantId);
- toolContext.put("requestId", requestId);
- toolContext.put("sessionId", sessionId);
- Map<String, String> metadataMap = new LinkedHashMap<>();
- if (metadata != null) {
- metadata.forEach((key, value) -> {
- String normalized = value == null ? "" : String.valueOf(value);
- metadataMap.put(key, normalized);
- toolContext.put(key, normalized);
- });
- }
- builder.toolContext(toolContext);
- if (!metadataMap.isEmpty()) {
- builder.metadata(metadataMap);
- }
- return builder.build();
- }
-
- private ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, SseEmitter emitter, String requestId,
- Long sessionId, AtomicLong toolCallSequence,
- AtomicLong toolSuccessCount, AtomicLong toolFailureCount,
- Long callLogId, Long userId, Long tenantId,
- ThinkingTraceEmitter thinkingTraceEmitter) {
- /** 缁欐墍鏈夊伐鍏峰洖璋冨涓婁竴灞傚彲瑙傛祴鍖呰锛岀敤浜庡疄鏃� SSE 杞ㄨ抗鍜屽璁℃棩蹇楄惤搴撱�� */
- if (Cools.isEmpty(toolCallbacks)) {
- return toolCallbacks;
- }
- List<ToolCallback> wrappedCallbacks = new ArrayList<>();
- for (ToolCallback callback : toolCallbacks) {
- if (callback == null) {
- continue;
- }
- wrappedCallbacks.add(new ObservableToolCallback(callback, emitter, requestId, sessionId, toolCallSequence,
- toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, thinkingTraceEmitter));
- }
- return wrappedCallbacks.toArray(new ToolCallback[0]);
- }
-
- private List<Message> buildPromptMessages(AiChatMemoryDto memory, List<AiChatMessageDto> sourceMessages, AiPrompt aiPrompt, Map<String, Object> metadata) {
- /**
- * 缁勮鏈�缁堟彁浜ょ粰妯″瀷鐨勬秷鎭垪琛ㄣ��
- * 椤哄簭涓婂缁堟槸锛氱郴缁� Prompt -> 鍘嗗彶鎽樿 -> 鍏抽敭浜嬪疄 -> 鏈�杩戝璇� -> 褰撳墠鐢ㄦ埛杈撳叆銆�
- */
- if (Cools.isEmpty(sourceMessages)) {
- throw new CoolException("瀵硅瘽娑堟伅涓嶈兘涓虹┖");
- }
- List<Message> messages = new ArrayList<>();
- if (StringUtils.hasText(aiPrompt.getSystemPrompt())) {
- messages.add(new SystemMessage(aiPrompt.getSystemPrompt()));
- }
- if (memory != null && StringUtils.hasText(memory.getMemorySummary())) {
- messages.add(new SystemMessage("鍘嗗彶鎽樿:\n" + memory.getMemorySummary()));
- }
- if (memory != null && StringUtils.hasText(memory.getMemoryFacts())) {
- messages.add(new SystemMessage("鍏抽敭浜嬪疄:\n" + memory.getMemoryFacts()));
- }
- int lastUserIndex = -1;
- for (int i = 0; i < sourceMessages.size(); i++) {
- AiChatMessageDto item = sourceMessages.get(i);
- if (item != null && "user".equalsIgnoreCase(item.getRole())) {
- lastUserIndex = i;
- }
- }
- for (int i = 0; i < sourceMessages.size(); i++) {
- AiChatMessageDto item = sourceMessages.get(i);
- if (item == null || !StringUtils.hasText(item.getContent())) {
- continue;
- }
- String role = item.getRole() == null ? "user" : item.getRole().toLowerCase();
- if ("system".equals(role)) {
- continue;
- }
- String content = item.getContent();
- if ("user".equals(role) && i == lastUserIndex) {
- content = renderUserPrompt(aiPrompt.getUserPromptTemplate(), content, metadata);
- }
- if ("assistant".equals(role)) {
- messages.add(new AssistantMessage(content));
- } else {
- messages.add(new UserMessage(content));
- }
- }
- if (messages.stream().noneMatch(item -> item instanceof UserMessage)) {
- throw new CoolException("鑷冲皯闇�瑕佷竴鏉$敤鎴锋秷鎭�");
- }
- return messages;
- }
-
- private List<AiChatMessageDto> mergeMessages(List<AiChatMessageDto> persistedMessages, List<AiChatMessageDto> memoryMessages) {
- /** 鎶婅惤搴撳巻鍙蹭笌鏈疆鍓嶇鍐呭瓨澧為噺鍚堝苟鎴愭ā鍨嬪彲娑堣垂鐨勫畬鏁翠笂涓嬫枃銆� */
- List<AiChatMessageDto> merged = new ArrayList<>();
- if (!Cools.isEmpty(persistedMessages)) {
- merged.addAll(persistedMessages);
- }
- if (!Cools.isEmpty(memoryMessages)) {
- merged.addAll(memoryMessages);
- }
- if (merged.isEmpty()) {
- throw new CoolException("瀵硅瘽娑堟伅涓嶈兘涓虹┖");
- }
- return merged;
- }
-
- private String resolveTitleSeed(List<AiChatMessageDto> messages) {
- if (Cools.isEmpty(messages)) {
- throw new CoolException("瀵硅瘽娑堟伅涓嶈兘涓虹┖");
- }
- for (int i = messages.size() - 1; i >= 0; i--) {
- AiChatMessageDto item = messages.get(i);
- if (item != null && "user".equalsIgnoreCase(item.getRole()) && StringUtils.hasText(item.getContent())) {
- return item.getContent();
- }
- }
- throw new CoolException("鑷冲皯闇�瑕佷竴鏉$敤鎴锋秷鎭�");
- }
-
- private String renderUserPrompt(String userPromptTemplate, String content, Map<String, Object> metadata) {
- if (!StringUtils.hasText(userPromptTemplate)) {
- return content;
- }
- String rendered = userPromptTemplate
- .replace("{{input}}", content)
- .replace("{input}", content);
- if (metadata != null) {
- for (Map.Entry<String, Object> entry : metadata.entrySet()) {
- String value = entry.getValue() == null ? "" : String.valueOf(entry.getValue());
- rendered = rendered.replace("{{" + entry.getKey() + "}}", value);
- rendered = rendered.replace("{" + entry.getKey() + "}", value);
- }
- }
- if (Objects.equals(rendered, userPromptTemplate)) {
- return userPromptTemplate + "\n\n" + content;
- }
- return rendered;
- }
-
- private String extractContent(ChatResponse response) {
- if (response == null || response.getResult() == null || response.getResult().getOutput() == null) {
- return null;
- }
- return response.getResult().getOutput().getText();
- }
-
- private String summarizeToolPayload(String content, int maxLength) {
- if (!StringUtils.hasText(content)) {
- return null;
- }
- String normalized = content.trim()
- .replace("\r", " ")
- .replace("\n", " ")
- .replaceAll("\\s+", " ");
- return normalized.length() > maxLength ? normalized.substring(0, maxLength) : normalized;
- }
-
- private void emitDone(SseEmitter emitter, String requestId, ChatResponseMetadata metadata, String fallbackModel, Long sessionId, long startedAt, Long firstTokenAt) {
- /** 杈撳嚭瀵硅瘽瀹屾垚浜嬩欢锛岀粺涓�灏佽鑰楁椂銆侀鍖呭欢杩熷拰 token 鐢ㄩ噺銆� */
- Usage usage = metadata == null ? null : metadata.getUsage();
- emitStrict(emitter, "done", AiChatDoneDto.builder()
- .requestId(requestId)
- .sessionId(sessionId)
- .model(metadata != null && StringUtils.hasText(metadata.getModel()) ? metadata.getModel() : fallbackModel)
- .elapsedMs(System.currentTimeMillis() - startedAt)
- .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAt))
- .promptTokens(usage == null ? null : usage.getPromptTokens())
- .completionTokens(usage == null ? null : usage.getCompletionTokens())
- .totalTokens(usage == null ? null : usage.getTotalTokens())
- .build());
- }
-
- private Map<String, String> buildMessagePayload(String... keyValues) {
- Map<String, String> payload = new LinkedHashMap<>();
- if (keyValues == null || keyValues.length == 0) {
- return payload;
- }
- if (keyValues.length % 2 != 0) {
- throw new CoolException("娑堟伅杞借嵎鍙傛暟蹇呴』鎴愬鍑虹幇");
- }
- for (int i = 0; i < keyValues.length; i += 2) {
- payload.put(keyValues[i], keyValues[i + 1] == null ? "" : keyValues[i + 1]);
- }
- return payload;
- }
-
- private void emitStrict(SseEmitter emitter, String eventName, Object payload) {
- /** 涓ユ牸鍙戦�� SSE 浜嬩欢锛涗竴鏃﹀彂閫佸け璐ワ紝鐩存帴涓婃姏涓烘祦寮忚緭鍑哄紓甯搞�� */
- try {
- String data = objectMapper.writeValueAsString(payload);
- emitter.send(SseEmitter.event()
- .name(eventName)
- .data(data, MediaType.APPLICATION_JSON));
- } catch (IOException e) {
- throw buildAiException("AI_SSE_EMIT_ERROR", AiErrorCategory.STREAM, "SSE_EMIT", "SSE 杈撳嚭澶辫触: " + e.getMessage(), e);
- }
- }
-
- private void emitSafely(SseEmitter emitter, String eventName, Object payload) {
- /** 灏濊瘯鍙戦�侀潪鍏抽敭浜嬩欢锛屽彂閫佸け璐ュ彧璁板綍鏃ュ織锛屼笉鎵撴柇涓诲璇濇祦绋嬨�� */
- try {
- emitStrict(emitter, eventName, payload);
- } catch (Exception e) {
- log.warn("AI SSE event emit skipped, eventName={}, message={}", eventName, e.getMessage());
- }
- }
-
- private AiChatException buildAiException(String code, AiErrorCategory category, String stage, String message, Throwable cause) {
- return new AiChatException(code, category, stage, message, cause);
- }
-
- private boolean isClientAbortException(Throwable throwable) {
- Throwable current = throwable;
- while (current != null) {
- String message = current.getMessage();
- if (message != null) {
- String normalized = message.toLowerCase();
- if (normalized.contains("broken pipe")
- || normalized.contains("connection reset")
- || normalized.contains("forcibly closed")
- || normalized.contains("abort")) {
- return true;
- }
- }
- current = current.getCause();
- }
- return false;
- }
-
- private class ThinkingTraceEmitter {
-
- private final SseEmitter emitter;
- private final String requestId;
- private final Long sessionId;
- private String currentPhase;
- private String currentStatus;
-
- private ThinkingTraceEmitter(SseEmitter emitter, String requestId, Long sessionId) {
- this.emitter = emitter;
- this.requestId = requestId;
- this.sessionId = sessionId;
- }
-
- private void startAnalyze() {
- if (currentPhase != null) {
- return;
- }
- currentPhase = "ANALYZE";
- currentStatus = "STARTED";
- emitThinkingEvent("ANALYZE", "STARTED", "姝e湪鍒嗘瀽闂",
- "宸叉帴鏀朵綘鐨勯棶棰橈紝姝e湪鐞嗚В鎰忓浘骞跺垽鏂槸鍚﹂渶瑕佽皟鐢ㄥ伐鍏枫��", null);
- }
-
- private void onToolStart(String toolName, String toolCallId) {
- switchPhase("TOOL_CALL", "STARTED", "姝e湪璋冪敤宸ュ叿", "宸插垽鏂渶瑕佽皟鐢ㄥ伐鍏凤紝姝e湪鏌ヨ鐩稿叧淇℃伅銆�", null);
- currentStatus = "UPDATED";
- emitThinkingEvent("TOOL_CALL", "UPDATED", "姝e湪璋冪敤宸ュ叿",
- "姝e湪璋冪敤宸ュ叿 " + safeLabel(toolName, "鏈煡宸ュ叿") + " 鑾峰彇鎵�闇�淇℃伅銆�", toolCallId);
- }
-
- private void onToolResult(String toolName, String toolCallId, boolean failed) {
- currentPhase = "TOOL_CALL";
- currentStatus = failed ? "FAILED" : "UPDATED";
- emitThinkingEvent("TOOL_CALL", failed ? "FAILED" : "UPDATED",
- failed ? "宸ュ叿璋冪敤澶辫触" : "宸ュ叿璋冪敤瀹屾垚",
- failed
- ? "宸ュ叿 " + safeLabel(toolName, "鏈煡宸ュ叿") + " 璋冪敤澶辫触锛屾鍦ㄨ瘎浼板け璐ュ奖鍝嶅苟鏁寸悊鍙敤淇℃伅銆�"
- : "宸ュ叿 " + safeLabel(toolName, "鏈煡宸ュ叿") + " 宸茶繑鍥炵粨鏋滐紝姝e湪缁х画鍒嗘瀽骞舵彁鐐煎叧閿俊鎭��",
- toolCallId);
- }
-
- private void startAnswer() {
- switchPhase("ANSWER", "STARTED", "姝e湪鏁寸悊绛旀", "宸插畬鎴愬垎鏋愶紝姝e湪缁勭粐鏈�缁堝洖澶嶅唴瀹广��", null);
- }
-
- private void completeCurrentPhase() {
- if (!StringUtils.hasText(currentPhase) || isTerminalStatus(currentStatus)) {
- return;
- }
- currentStatus = "COMPLETED";
- emitThinkingEvent(currentPhase, "COMPLETED", resolveCompleteTitle(currentPhase),
- resolveCompleteContent(currentPhase), null);
- }
-
- private void markTerminated(String terminalStatus) {
- if (!StringUtils.hasText(currentPhase) || isTerminalStatus(currentStatus)) {
- return;
- }
- currentStatus = terminalStatus;
- emitThinkingEvent(currentPhase, terminalStatus,
- "ABORTED".equals(terminalStatus) ? "鎬濊�冨凡涓" : "鎬濊�冨け璐�",
- "ABORTED".equals(terminalStatus)
- ? "鏈疆瀵硅瘽宸茶涓锛屾�濊�冭繃绋嬫彁鍓嶇粨鏉熴��"
- : "鏈疆瀵硅瘽鍦ㄧ敓鎴愮瓟妗堝墠澶辫触锛屽綋鍓嶆�濊�冭繃绋嬪凡鍋滄銆�",
- null);
- }
-
- private void switchPhase(String nextPhase, String nextStatus, String title, String content, String toolCallId) {
- if (!Objects.equals(currentPhase, nextPhase)) {
- completeCurrentPhase();
- }
- currentPhase = nextPhase;
- currentStatus = nextStatus;
- emitThinkingEvent(nextPhase, nextStatus, title, content, toolCallId);
- }
-
- private void emitThinkingEvent(String phase, String status, String title, String content, String toolCallId) {
- emitSafely(emitter, "thinking", AiChatThinkingEventDto.builder()
- .requestId(requestId)
- .sessionId(sessionId)
- .phase(phase)
- .status(status)
- .title(title)
- .content(content)
- .toolCallId(toolCallId)
- .timestamp(Instant.now().toEpochMilli())
- .build());
- }
-
- private boolean isTerminalStatus(String status) {
- return "COMPLETED".equals(status) || "FAILED".equals(status) || "ABORTED".equals(status);
- }
-
- private String resolveCompleteTitle(String phase) {
- if ("ANSWER".equals(phase)) {
- return "绛旀鏁寸悊瀹屾垚";
- }
- if ("TOOL_CALL".equals(phase)) {
- return "宸ュ叿鍒嗘瀽瀹屾垚";
- }
- return "闂鍒嗘瀽瀹屾垚";
- }
-
- private String resolveCompleteContent(String phase) {
- if ("ANSWER".equals(phase)) {
- return "鏈�缁堢瓟澶嶅凡鐢熸垚瀹屾垚銆�";
- }
- if ("TOOL_CALL".equals(phase)) {
- return "宸ュ叿璋冪敤闃舵宸茬粨鏉燂紝鐩稿叧淇℃伅宸叉暣鐞嗗畬姣曘��";
- }
- return "闂鎰忓浘鍜屽鐞嗘柟鍚戝凡鍒嗘瀽瀹屾垚銆�";
- }
-
- private String safeLabel(String value, String fallback) {
- return StringUtils.hasText(value) ? value : fallback;
- }
- }
-
- private class ObservableToolCallback implements ToolCallback {
-
- private final ToolCallback delegate;
- private final SseEmitter emitter;
- private final String requestId;
- private final Long sessionId;
- private final AtomicLong toolCallSequence;
- private final AtomicLong toolSuccessCount;
- private final AtomicLong toolFailureCount;
- private final Long callLogId;
- private final Long userId;
- private final Long tenantId;
- private final ThinkingTraceEmitter thinkingTraceEmitter;
-
- private ObservableToolCallback(ToolCallback delegate, SseEmitter emitter, String requestId,
- Long sessionId, AtomicLong toolCallSequence,
- AtomicLong toolSuccessCount, AtomicLong toolFailureCount,
- Long callLogId, Long userId, Long tenantId,
- ThinkingTraceEmitter thinkingTraceEmitter) {
- this.delegate = delegate;
- this.emitter = emitter;
- this.requestId = requestId;
- this.sessionId = sessionId;
- this.toolCallSequence = toolCallSequence;
- this.toolSuccessCount = toolSuccessCount;
- this.toolFailureCount = toolFailureCount;
- this.callLogId = callLogId;
- this.userId = userId;
- this.tenantId = tenantId;
- this.thinkingTraceEmitter = thinkingTraceEmitter;
- }
-
- @Override
- public org.springframework.ai.tool.definition.ToolDefinition getToolDefinition() {
- return delegate.getToolDefinition();
- }
-
- @Override
- public org.springframework.ai.tool.metadata.ToolMetadata getToolMetadata() {
- return delegate.getToolMetadata();
- }
-
- @Override
- public String call(String toolInput) {
- return call(toolInput, null);
- }
-
- @Override
- public String call(String toolInput, ToolContext toolContext) {
- /**
- * 宸ュ叿鎵ц瑙傛祴鍖呰鍣ㄣ��
- * 鍦ㄧ湡瀹炶皟鐢ㄥ墠鍚庡垎鍒彂閫� tool_start / tool_result / tool_error锛�
- * 鍚屾椂鎶婅皟鐢ㄦ憳瑕佸啓鍏� MCP 璋冪敤鏃ュ織琛ㄣ��
- */
- String toolName = delegate.getToolDefinition() == null ? "unknown" : delegate.getToolDefinition().name();
- String mountName = delegate instanceof MountedToolCallback ? ((MountedToolCallback) delegate).getMountName() : null;
- String toolCallId = requestId + "-tool-" + toolCallSequence.incrementAndGet();
- long startedAt = System.currentTimeMillis();
- // 杩欓噷鍙鍚屼竴 request 鍐呯殑閲嶅宸ュ叿璋冪敤鍋氱煭鏈熷鐢紝閬垮厤鎶婅法璇锋眰缁撴灉璇綋鎴愰�氱敤缂撳瓨銆�
- AiRedisSupport.CachedToolResult cachedToolResult = aiRedisSupport.getToolResult(tenantId, requestId, toolName, toolInput);
- if (cachedToolResult != null) {
- emitSafely(emitter, "tool_result", AiChatToolEventDto.builder()
- .requestId(requestId)
- .sessionId(sessionId)
- .toolCallId(toolCallId)
- .toolName(toolName)
- .mountName(mountName)
- .status(cachedToolResult.isSuccess() ? "COMPLETED" : "FAILED")
- .inputSummary(summarizeToolPayload(toolInput, 400))
- .outputSummary(summarizeToolPayload(cachedToolResult.getOutput(), 600))
- .errorMessage(cachedToolResult.getErrorMessage())
- .durationMs(0L)
- .timestamp(System.currentTimeMillis())
- .build());
- if (thinkingTraceEmitter != null) {
- thinkingTraceEmitter.onToolResult(toolName, toolCallId, !cachedToolResult.isSuccess());
- }
- if (cachedToolResult.isSuccess()) {
- toolSuccessCount.incrementAndGet();
- aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
- "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(cachedToolResult.getOutput(), 600),
- null, 0L, userId, tenantId);
- return cachedToolResult.getOutput();
- }
- toolFailureCount.incrementAndGet();
- aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
- "FAILED", summarizeToolPayload(toolInput, 400), null, cachedToolResult.getErrorMessage(),
- 0L, userId, tenantId);
- throw new CoolException(cachedToolResult.getErrorMessage());
- }
- if (thinkingTraceEmitter != null) {
- thinkingTraceEmitter.onToolStart(toolName, toolCallId);
- }
- emitSafely(emitter, "tool_start", AiChatToolEventDto.builder()
- .requestId(requestId)
- .sessionId(sessionId)
- .toolCallId(toolCallId)
- .toolName(toolName)
- .mountName(mountName)
- .status("STARTED")
- .inputSummary(summarizeToolPayload(toolInput, 400))
- .timestamp(startedAt)
- .build());
- try {
- String output = toolContext == null ? delegate.call(toolInput) : delegate.call(toolInput, toolContext);
- long durationMs = System.currentTimeMillis() - startedAt;
- emitSafely(emitter, "tool_result", AiChatToolEventDto.builder()
- .requestId(requestId)
- .sessionId(sessionId)
- .toolCallId(toolCallId)
- .toolName(toolName)
- .mountName(mountName)
- .status("COMPLETED")
- .inputSummary(summarizeToolPayload(toolInput, 400))
- .outputSummary(summarizeToolPayload(output, 600))
- .durationMs(durationMs)
- .timestamp(System.currentTimeMillis())
- .build());
- if (thinkingTraceEmitter != null) {
- thinkingTraceEmitter.onToolResult(toolName, toolCallId, false);
- }
- aiRedisSupport.cacheToolResult(tenantId, requestId, toolName, toolInput, true, output, null);
- toolSuccessCount.incrementAndGet();
- aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
- "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(output, 600),
- null, durationMs, userId, tenantId);
- return output;
- } catch (RuntimeException e) {
- long durationMs = System.currentTimeMillis() - startedAt;
- emitSafely(emitter, "tool_error", AiChatToolEventDto.builder()
- .requestId(requestId)
- .sessionId(sessionId)
- .toolCallId(toolCallId)
- .toolName(toolName)
- .mountName(mountName)
- .status("FAILED")
- .inputSummary(summarizeToolPayload(toolInput, 400))
- .errorMessage(e.getMessage())
- .durationMs(durationMs)
- .timestamp(System.currentTimeMillis())
- .build());
- if (thinkingTraceEmitter != null) {
- thinkingTraceEmitter.onToolResult(toolName, toolCallId, true);
- }
- aiRedisSupport.cacheToolResult(tenantId, requestId, toolName, toolInput, false, null, e.getMessage());
- toolFailureCount.incrementAndGet();
- aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
- "FAILED", summarizeToolPayload(toolInput, 400), null, e.getMessage(),
- durationMs, userId, tenantId);
- throw e;
- }
- }
}
}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiConfigResolverServiceImpl.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiConfigResolverServiceImpl.java
index 7aa47b0..8173b21 100644
--- a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiConfigResolverServiceImpl.java
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiConfigResolverServiceImpl.java
@@ -2,6 +2,7 @@
import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.dto.AiResolvedConfig;
+import com.vincent.rsf.server.ai.store.AiConfigCacheStore;
import com.vincent.rsf.server.ai.service.AiConfigResolverService;
import com.vincent.rsf.server.ai.service.AiMcpMountService;
import com.vincent.rsf.server.ai.service.AiParamService;
@@ -18,7 +19,7 @@
private final AiParamService aiParamService;
private final AiPromptService aiPromptService;
private final AiMcpMountService aiMcpMountService;
- private final AiRedisSupport aiRedisSupport;
+ private final AiConfigCacheStore aiConfigCacheStore;
/**
* 鎸夌鎴疯В鏋愪竴娆″畬鏁寸殑 AI 杩愯閰嶇疆銆�
@@ -37,7 +38,7 @@
}
String finalPromptCode = StringUtils.hasText(promptCode) ? promptCode : AiDefaults.DEFAULT_PROMPT_CODE;
// 閰嶇疆瑙f瀽鏄涓叆鍙e叡浜殑鐑偣璺緞锛屽懡涓紦瀛樻椂鍙互閬垮厤涓夊紶閰嶇疆琛ㄧ殑閲嶅鏌ヨ銆�
- AiResolvedConfig cached = aiRedisSupport.getResolvedConfig(tenantId, finalPromptCode, aiParamId);
+ AiResolvedConfig cached = aiConfigCacheStore.getResolvedConfig(tenantId, finalPromptCode, aiParamId);
if (cached != null) {
return cached;
}
@@ -47,7 +48,7 @@
.prompt(aiPromptService.getActivePrompt(finalPromptCode, tenantId))
.mcpMounts(aiMcpMountService.listActiveMounts(tenantId))
.build();
- aiRedisSupport.cacheResolvedConfig(tenantId, finalPromptCode, aiParamId, resolvedConfig);
+ aiConfigCacheStore.cacheResolvedConfig(tenantId, finalPromptCode, aiParamId, resolvedConfig);
return resolvedConfig;
}
}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiMcpMountServiceImpl.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiMcpMountServiceImpl.java
index 409c20f..8c96030 100644
--- a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiMcpMountServiceImpl.java
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiMcpMountServiceImpl.java
@@ -1,9 +1,7 @@
package com.vincent.rsf.server.ai.service.impl;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
-import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
-import com.fasterxml.jackson.databind.ObjectMapper;
import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.dto.AiMcpConnectivityTestDto;
@@ -12,30 +10,28 @@
import com.vincent.rsf.server.ai.dto.AiMcpToolTestRequest;
import com.vincent.rsf.server.ai.entity.AiMcpMount;
import com.vincent.rsf.server.ai.mapper.AiMcpMountMapper;
+import com.vincent.rsf.server.ai.service.impl.mcp.AiMcpAdminService;
+import com.vincent.rsf.server.ai.store.AiConfigCacheStore;
+import com.vincent.rsf.server.ai.store.AiConversationCacheStore;
+import com.vincent.rsf.server.ai.store.AiMcpCacheStore;
import com.vincent.rsf.server.ai.service.AiMcpMountService;
import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry;
-import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
import com.vincent.rsf.server.system.enums.StatusType;
import lombok.RequiredArgsConstructor;
-import org.springframework.ai.chat.model.ToolContext;
-import org.springframework.ai.tool.ToolCallback;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Date;
import java.util.List;
-import java.util.Map;
@Service("aiMcpMountService")
@RequiredArgsConstructor
public class AiMcpMountServiceImpl extends ServiceImpl<AiMcpMountMapper, AiMcpMount> implements AiMcpMountService {
private final BuiltinMcpToolRegistry builtinMcpToolRegistry;
- private final McpMountRuntimeFactory mcpMountRuntimeFactory;
- private final ObjectMapper objectMapper;
- private final AiRedisSupport aiRedisSupport;
+ private final AiMcpAdminService aiMcpAdminService;
+ private final AiMcpCacheStore aiMcpCacheStore;
+ private final AiConfigCacheStore aiConfigCacheStore;
+ private final AiConversationCacheStore aiConversationCacheStore;
/** 鏌ヨ鏌愪釜绉熸埛涓嬪綋鍓嶅惎鐢ㄧ殑 MCP 鎸傝浇鍒楄〃銆� */
@Override
@@ -66,7 +62,7 @@
if (aiMcpMount.getId() == null) {
throw new CoolException("MCP 鎸傝浇 ID 涓嶈兘涓虹┖");
}
- AiMcpMount current = requireMount(aiMcpMount.getId(), tenantId);
+ AiMcpMount current = aiMcpAdminService.requireMount(aiMcpMount.getId(), tenantId);
aiMcpMount.setTenantId(current.getTenantId());
ensureRequiredFields(aiMcpMount, tenantId);
}
@@ -77,68 +73,23 @@
*/
@Override
public List<AiMcpToolPreviewDto> previewTools(Long mountId, Long userId, Long tenantId) {
- AiMcpMount mount = requireMount(mountId, tenantId);
- // 宸ュ叿鐩綍棰勮鍒濆鍖栨垚鏈珮锛屼絾鍙樺寲棰戠巼浣庯紝閫傚悎鍋氱鐞嗙鐭紦瀛樸��
- List<AiMcpToolPreviewDto> cached = aiRedisSupport.getToolPreview(tenantId, mountId);
+ AiMcpMount mount = aiMcpAdminService.requireMount(mountId, tenantId);
+ List<AiMcpToolPreviewDto> cached = aiMcpCacheStore.getToolPreview(tenantId, mountId);
if (cached != null) {
return cached;
}
- long startedAt = System.currentTimeMillis();
- try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(List.of(mount), userId)) {
- List<AiMcpToolPreviewDto> tools = buildToolPreviewDtos(runtime.getToolCallbacks(),
- AiDefaults.MCP_TRANSPORT_BUILTIN.equals(mount.getTransportType())
- ? builtinMcpToolRegistry.listBuiltinToolCatalog(mount.getBuiltinCode())
- : List.of());
- if (!runtime.getErrors().isEmpty()) {
- String message = String.join("锛�", runtime.getErrors());
- updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY, message, System.currentTimeMillis() - startedAt);
- throw new CoolException(message);
- }
- updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_HEALTHY,
- "宸ュ叿瑙f瀽鎴愬姛锛屽叡 " + tools.size() + " 涓伐鍏�", System.currentTimeMillis() - startedAt);
- aiRedisSupport.cacheToolPreview(tenantId, mountId, tools);
- return tools;
- } catch (CoolException e) {
- throw e;
- } catch (Exception e) {
- updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY,
- "宸ュ叿瑙f瀽澶辫触: " + e.getMessage(), System.currentTimeMillis() - startedAt);
- throw new CoolException("鑾峰彇宸ュ叿鍒楄〃澶辫触: " + e.getMessage());
- }
+ List<AiMcpToolPreviewDto> tools = aiMcpAdminService.previewTools(mount, userId);
+ aiMcpCacheStore.cacheToolPreview(tenantId, mountId, tools);
+ return tools;
}
/** 瀵瑰凡淇濆瓨鐨勬寕杞藉仛鐪熷疄杩為�氭�ф祴璇曪紝骞舵妸缁撴灉鍥炲啓鍒拌繍琛屾�佸瓧娈点�� */
@Override
public AiMcpConnectivityTestDto testConnectivity(Long mountId, Long userId, Long tenantId) {
- AiMcpMount mount = requireMount(mountId, tenantId);
- long startedAt = System.currentTimeMillis();
- try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(List.of(mount), userId)) {
- long elapsedMs = System.currentTimeMillis() - startedAt;
- if (!runtime.getErrors().isEmpty()) {
- String message = String.join("锛�", runtime.getErrors());
- updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY, message, elapsedMs);
- AiMcpMount latest = requireMount(mount.getId(), tenantId);
- AiMcpConnectivityTestDto connectivity = buildConnectivityDto(latest, message, elapsedMs, runtime.getToolCallbacks().length);
- aiRedisSupport.cacheConnectivity(tenantId, mountId, connectivity);
- return connectivity;
- }
- String message = "杩為�氭�ф祴璇曟垚鍔燂紝瑙f瀽鍑� " + runtime.getToolCallbacks().length + " 涓伐鍏�";
- updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_HEALTHY, message, elapsedMs);
- AiMcpMount latest = requireMount(mount.getId(), tenantId);
- AiMcpConnectivityTestDto connectivity = buildConnectivityDto(latest, message, elapsedMs, runtime.getToolCallbacks().length);
- aiRedisSupport.cacheConnectivity(tenantId, mountId, connectivity);
- return connectivity;
- } catch (CoolException e) {
- throw e;
- } catch (Exception e) {
- long elapsedMs = System.currentTimeMillis() - startedAt;
- String message = "杩為�氭�ф祴璇曞け璐�: " + e.getMessage();
- updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY, message, elapsedMs);
- AiMcpMount latest = requireMount(mount.getId(), tenantId);
- AiMcpConnectivityTestDto connectivity = buildConnectivityDto(latest, message, elapsedMs, 0);
- aiRedisSupport.cacheConnectivity(tenantId, mountId, connectivity);
- return connectivity;
- }
+ AiMcpMount mount = aiMcpAdminService.requireMount(mountId, tenantId);
+ AiMcpConnectivityTestDto connectivity = aiMcpAdminService.testConnectivity(mount, userId, true);
+ aiMcpCacheStore.cacheConnectivity(tenantId, mountId, connectivity);
+ return connectivity;
}
/** 瀵硅〃鍗曢噷鐨勮崏绋块厤缃仛涓存椂杩為�氭�ф祴璇曪紝涓嶈惤搴撱�� */
@@ -154,34 +105,7 @@
mount.setTenantId(tenantId);
fillDefaults(mount);
ensureRequiredFields(mount, tenantId);
- long startedAt = System.currentTimeMillis();
- try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(List.of(mount), userId)) {
- long elapsedMs = System.currentTimeMillis() - startedAt;
- if (!runtime.getErrors().isEmpty()) {
- return AiMcpConnectivityTestDto.builder()
- .mountId(mount.getId())
- .mountName(mount.getName())
- .healthStatus(AiDefaults.MCP_HEALTH_UNHEALTHY)
- .message(String.join("锛�", runtime.getErrors()))
- .initElapsedMs(elapsedMs)
- .toolCount(runtime.getToolCallbacks().length)
- .testedAt(new java.text.SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date()))
- .build();
- }
- return AiMcpConnectivityTestDto.builder()
- .mountId(mount.getId())
- .mountName(mount.getName())
- .healthStatus(AiDefaults.MCP_HEALTH_HEALTHY)
- .message("鑽夌杩為�氭�ф祴璇曟垚鍔燂紝瑙f瀽鍑� " + runtime.getToolCallbacks().length + " 涓伐鍏�")
- .initElapsedMs(elapsedMs)
- .toolCount(runtime.getToolCallbacks().length)
- .testedAt(new java.text.SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date()))
- .build();
- } catch (CoolException e) {
- throw e;
- } catch (Exception e) {
- throw new CoolException("鑽夌杩為�氭�ф祴璇曞け璐�: " + e.getMessage());
- }
+ return aiMcpAdminService.testConnectivity(mount, userId, false);
}
/**
@@ -190,61 +114,15 @@
*/
@Override
public AiMcpToolTestDto testTool(Long mountId, Long userId, Long tenantId, AiMcpToolTestRequest request) {
- if (userId == null) {
- throw new CoolException("褰撳墠鐧诲綍鐢ㄦ埛涓嶅瓨鍦�");
- }
- if (tenantId == null) {
- throw new CoolException("褰撳墠绉熸埛涓嶅瓨鍦�");
- }
- if (request == null) {
- throw new CoolException("宸ュ叿娴嬭瘯鍙傛暟涓嶈兘涓虹┖");
- }
- if (!StringUtils.hasText(request.getToolName())) {
- throw new CoolException("宸ュ叿鍚嶇О涓嶈兘涓虹┖");
- }
- if (!StringUtils.hasText(request.getInputJson())) {
- throw new CoolException("宸ュ叿杈撳叆 JSON 涓嶈兘涓虹┖");
- }
- try {
- objectMapper.readTree(request.getInputJson());
- } catch (Exception e) {
- throw new CoolException("宸ュ叿杈撳叆 JSON 鏍煎紡閿欒: " + e.getMessage());
- }
- AiMcpMount mount = requireMount(mountId, tenantId);
- long startedAt = System.currentTimeMillis();
- try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(List.of(mount), userId)) {
- ToolCallback callback = Arrays.stream(runtime.getToolCallbacks())
- .filter(item -> item != null && item.getToolDefinition() != null)
- .filter(item -> request.getToolName().equals(item.getToolDefinition().name()))
- .findFirst()
- .orElseThrow(() -> new CoolException("鏈壘鍒拌娴嬭瘯鐨勫伐鍏�: " + request.getToolName()));
- String output = callback.call(
- request.getInputJson(),
- new ToolContext(Map.of("userId", userId, "tenantId", tenantId, "mountId", mountId))
- );
- updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_HEALTHY,
- "宸ュ叿娴嬭瘯鎴愬姛: " + request.getToolName(), System.currentTimeMillis() - startedAt);
- return AiMcpToolTestDto.builder()
- .toolName(request.getToolName())
- .inputJson(request.getInputJson())
- .output(output)
- .build();
- } catch (CoolException e) {
- updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY,
- "宸ュ叿娴嬭瘯澶辫触: " + e.getMessage(), System.currentTimeMillis() - startedAt);
- throw e;
- } catch (Exception e) {
- updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY,
- "宸ュ叿娴嬭瘯澶辫触: " + e.getMessage(), System.currentTimeMillis() - startedAt);
- throw new CoolException("宸ュ叿娴嬭瘯澶辫触: " + e.getMessage());
- }
+ AiMcpMount mount = aiMcpAdminService.requireMount(mountId, tenantId);
+ return aiMcpAdminService.testTool(mount, userId, tenantId, request);
}
@Override
public boolean save(AiMcpMount entity) {
boolean saved = super.save(entity);
if (saved && entity != null && entity.getTenantId() != null) {
- aiRedisSupport.evictMcpMountCaches(entity.getTenantId(), entity.getId());
+ evictMountRelatedCaches(entity.getTenantId(), entity.getId());
}
return saved;
}
@@ -253,7 +131,7 @@
public boolean updateById(AiMcpMount entity) {
boolean updated = super.updateById(entity);
if (updated && entity != null && entity.getTenantId() != null) {
- aiRedisSupport.evictMcpMountCaches(entity.getTenantId(), entity.getId());
+ evictMountRelatedCaches(entity.getTenantId(), entity.getId());
}
return updated;
}
@@ -269,7 +147,7 @@
if (removed) {
records.stream()
.filter(java.util.Objects::nonNull)
- .forEach(item -> aiRedisSupport.evictMcpMountCaches(item.getTenantId(), item.getId()));
+ .forEach(item -> evictMountRelatedCaches(item.getTenantId(), item.getId()));
}
return removed;
}
@@ -321,23 +199,6 @@
throw new CoolException("涓嶆敮鎸佺殑 MCP 浼犺緭绫诲瀷: " + aiMcpMount.getTransportType());
}
- private AiMcpMount requireMount(Long mountId, Long tenantId) {
- /** 鎸夌鎴峰姞杞芥寕杞借褰曪紝涓嶅瓨鍦ㄧ洿鎺ユ姏閿欍�� */
- ensureTenantId(tenantId);
- if (mountId == null) {
- throw new CoolException("MCP 鎸傝浇 ID 涓嶈兘涓虹┖");
- }
- AiMcpMount mount = this.getOne(new LambdaQueryWrapper<AiMcpMount>()
- .eq(AiMcpMount::getId, mountId)
- .eq(AiMcpMount::getTenantId, tenantId)
- .eq(AiMcpMount::getDeleted, 0)
- .last("limit 1"));
- if (mount == null) {
- throw new CoolException("MCP 鎸傝浇涓嶅瓨鍦�");
- }
- return mount;
- }
-
private void ensureBuiltinConflictFree(AiMcpMount aiMcpMount, Long tenantId) {
/** 鏍¢獙鍚岀鎴蜂笅鏄惁瀛樺湪涓庡綋鍓嶅唴缃紪鐮佷簰鏂ョ殑鍚敤鎸傝浇銆� */
if (aiMcpMount.getStatus() == null || aiMcpMount.getStatus() != StatusType.ENABLE.val) {
@@ -377,53 +238,9 @@
}
}
- private List<AiMcpToolPreviewDto> buildToolPreviewDtos(ToolCallback[] callbacks, List<AiMcpToolPreviewDto> governedCatalog) {
- /** 鎶婂簳灞� ToolCallback 鍜屾不鐞嗙洰褰曚俊鎭嫾鎴愬墠绔渶瑕佺殑缁撴瀯鍖栧伐鍏峰崱鐗囨暟鎹�� */
- List<AiMcpToolPreviewDto> tools = new ArrayList<>();
- Map<String, AiMcpToolPreviewDto> catalogMap = new java.util.LinkedHashMap<>();
- for (AiMcpToolPreviewDto item : governedCatalog) {
- if (item == null || !StringUtils.hasText(item.getName())) {
- continue;
- }
- catalogMap.put(item.getName(), item);
- }
- for (ToolCallback callback : callbacks) {
- if (callback == null || callback.getToolDefinition() == null) {
- continue;
- }
- AiMcpToolPreviewDto governedItem = catalogMap.get(callback.getToolDefinition().name());
- tools.add(AiMcpToolPreviewDto.builder()
- .name(callback.getToolDefinition().name())
- .description(callback.getToolDefinition().description())
- .inputSchema(callback.getToolDefinition().inputSchema())
- .returnDirect(callback.getToolMetadata() == null ? null : callback.getToolMetadata().returnDirect())
- .toolGroup(governedItem == null ? null : governedItem.getToolGroup())
- .toolPurpose(governedItem == null ? null : governedItem.getToolPurpose())
- .queryBoundary(governedItem == null ? null : governedItem.getQueryBoundary())
- .exampleQuestions(governedItem == null ? List.of() : governedItem.getExampleQuestions())
- .build());
- }
- return tools;
- }
-
- private void updateHealthStatus(Long mountId, String healthStatus, String message, Long initElapsedMs) {
- this.update(new LambdaUpdateWrapper<AiMcpMount>()
- .eq(AiMcpMount::getId, mountId)
- .set(AiMcpMount::getHealthStatus, healthStatus)
- .set(AiMcpMount::getLastTestTime, new Date())
- .set(AiMcpMount::getLastTestMessage, message)
- .set(AiMcpMount::getLastInitElapsedMs, initElapsedMs));
- }
-
- private AiMcpConnectivityTestDto buildConnectivityDto(AiMcpMount mount, String message, Long initElapsedMs, Integer toolCount) {
- return AiMcpConnectivityTestDto.builder()
- .mountId(mount.getId())
- .mountName(mount.getName())
- .healthStatus(mount.getHealthStatus())
- .message(message)
- .initElapsedMs(initElapsedMs)
- .toolCount(toolCount)
- .testedAt(mount.getLastTestTime$())
- .build();
+ private void evictMountRelatedCaches(Long tenantId, Long mountId) {
+ aiMcpCacheStore.evictMcpMountCaches(tenantId, mountId);
+ aiConfigCacheStore.evictTenantConfigCaches(tenantId);
+ aiConversationCacheStore.evictTenantRuntimeCaches(tenantId);
}
}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiOpenAiApiSupport.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiOpenAiApiSupport.java
index bb404b9..06ab523 100644
--- a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiOpenAiApiSupport.java
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiOpenAiApiSupport.java
@@ -10,7 +10,7 @@
import java.util.Locale;
-final class AiOpenAiApiSupport {
+public final class AiOpenAiApiSupport {
private static final String DEFAULT_COMPLETIONS_PATH = "/v1/chat/completions";
private static final String DEFAULT_EMBEDDINGS_PATH = "/v1/embeddings";
@@ -21,7 +21,7 @@
private AiOpenAiApiSupport() {
}
- static OpenAiApi buildOpenAiApi(AiParam aiParam) {
+ public static OpenAiApi buildOpenAiApi(AiParam aiParam) {
int timeoutMs = aiParam.getTimeoutMs() == null ? AiDefaults.DEFAULT_TIMEOUT_MS : aiParam.getTimeoutMs();
SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
requestFactory.setConnectTimeout(timeoutMs);
@@ -38,7 +38,7 @@
.build();
}
- static EndpointConfig resolveEndpointConfig(String rawBaseUrl) {
+ public static EndpointConfig resolveEndpointConfig(String rawBaseUrl) {
String normalizedBaseUrl = trimTrailingSlash(rawBaseUrl);
String lowerCaseBaseUrl = normalizedBaseUrl.toLowerCase(Locale.ROOT);
@@ -81,6 +81,6 @@
return normalized;
}
- record EndpointConfig(String baseUrl, String completionsPath, String embeddingsPath) {
+ public record EndpointConfig(String baseUrl, String completionsPath, String embeddingsPath) {
}
}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiParamServiceImpl.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiParamServiceImpl.java
index 4e1e83e..7901060 100644
--- a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiParamServiceImpl.java
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiParamServiceImpl.java
@@ -8,6 +8,8 @@
import com.vincent.rsf.server.ai.dto.AiParamValidateResultDto;
import com.vincent.rsf.server.ai.entity.AiParam;
import com.vincent.rsf.server.ai.mapper.AiParamMapper;
+import com.vincent.rsf.server.ai.store.AiConfigCacheStore;
+import com.vincent.rsf.server.ai.store.AiConversationCacheStore;
import com.vincent.rsf.server.ai.service.AiParamService;
import com.vincent.rsf.server.system.enums.StatusType;
import lombok.RequiredArgsConstructor;
@@ -24,7 +26,8 @@
public class AiParamServiceImpl extends ServiceImpl<AiParamMapper, AiParam> implements AiParamService {
private final AiParamValidationSupport aiParamValidationSupport;
- private final AiRedisSupport aiRedisSupport;
+ private final AiConfigCacheStore aiConfigCacheStore;
+ private final AiConversationCacheStore aiConversationCacheStore;
@Override
public AiParam getActiveParam(Long tenantId) {
@@ -97,7 +100,8 @@
if (!super.updateById(target)) {
throw new CoolException("璁剧疆榛樿 AI 鍙傛暟澶辫触");
}
- aiRedisSupport.evictTenantConfigCaches(tenantId);
+ aiConfigCacheStore.evictTenantConfigCaches(tenantId);
+ aiConversationCacheStore.evictTenantRuntimeCaches(tenantId);
return target;
}
@@ -138,7 +142,8 @@
public boolean save(AiParam entity) {
boolean saved = super.save(entity);
if (saved && entity != null && entity.getTenantId() != null) {
- aiRedisSupport.evictTenantConfigCaches(entity.getTenantId());
+ aiConfigCacheStore.evictTenantConfigCaches(entity.getTenantId());
+ aiConversationCacheStore.evictTenantRuntimeCaches(entity.getTenantId());
}
return saved;
}
@@ -147,7 +152,8 @@
public boolean updateById(AiParam entity) {
boolean updated = super.updateById(entity);
if (updated && entity != null && entity.getTenantId() != null) {
- aiRedisSupport.evictTenantConfigCaches(entity.getTenantId());
+ aiConfigCacheStore.evictTenantConfigCaches(entity.getTenantId());
+ aiConversationCacheStore.evictTenantRuntimeCaches(entity.getTenantId());
}
return updated;
}
@@ -166,7 +172,10 @@
.map(AiParam::getTenantId)
.filter(java.util.Objects::nonNull)
.distinct()
- .forEach(aiRedisSupport::evictTenantConfigCaches);
+ .forEach(tenantId -> {
+ aiConfigCacheStore.evictTenantConfigCaches(tenantId);
+ aiConversationCacheStore.evictTenantRuntimeCaches(tenantId);
+ });
}
return removed;
}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiPromptServiceImpl.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiPromptServiceImpl.java
index c6d688f..02b3398 100644
--- a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiPromptServiceImpl.java
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiPromptServiceImpl.java
@@ -7,6 +7,8 @@
import com.vincent.rsf.server.ai.dto.AiPromptPreviewRequest;
import com.vincent.rsf.server.ai.entity.AiPrompt;
import com.vincent.rsf.server.ai.mapper.AiPromptMapper;
+import com.vincent.rsf.server.ai.store.AiConfigCacheStore;
+import com.vincent.rsf.server.ai.store.AiConversationCacheStore;
import com.vincent.rsf.server.ai.service.AiPromptService;
import com.vincent.rsf.server.system.enums.StatusType;
import lombok.RequiredArgsConstructor;
@@ -18,7 +20,8 @@
public class AiPromptServiceImpl extends ServiceImpl<AiPromptMapper, AiPrompt> implements AiPromptService {
private final AiPromptRenderSupport aiPromptRenderSupport;
- private final AiRedisSupport aiRedisSupport;
+ private final AiConfigCacheStore aiConfigCacheStore;
+ private final AiConversationCacheStore aiConversationCacheStore;
@Override
public AiPrompt getActivePrompt(String code, Long tenantId) {
@@ -76,7 +79,8 @@
public boolean save(AiPrompt entity) {
boolean saved = super.save(entity);
if (saved && entity != null && entity.getTenantId() != null) {
- aiRedisSupport.evictTenantConfigCaches(entity.getTenantId());
+ aiConfigCacheStore.evictTenantConfigCaches(entity.getTenantId());
+ aiConversationCacheStore.evictTenantRuntimeCaches(entity.getTenantId());
}
return saved;
}
@@ -85,7 +89,8 @@
public boolean updateById(AiPrompt entity) {
boolean updated = super.updateById(entity);
if (updated && entity != null && entity.getTenantId() != null) {
- aiRedisSupport.evictTenantConfigCaches(entity.getTenantId());
+ aiConfigCacheStore.evictTenantConfigCaches(entity.getTenantId());
+ aiConversationCacheStore.evictTenantRuntimeCaches(entity.getTenantId());
}
return updated;
}
@@ -103,7 +108,10 @@
.map(AiPrompt::getTenantId)
.filter(java.util.Objects::nonNull)
.distinct()
- .forEach(aiRedisSupport::evictTenantConfigCaches);
+ .forEach(tenantId -> {
+ aiConfigCacheStore.evictTenantConfigCaches(tenantId);
+ aiConversationCacheStore.evictTenantRuntimeCaches(tenantId);
+ });
}
return removed;
}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiRedisSupport.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiRedisSupport.java
deleted file mode 100644
index 1a5c941..0000000
--- a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiRedisSupport.java
+++ /dev/null
@@ -1,519 +0,0 @@
-package com.vincent.rsf.server.ai.service.impl;
-
-import com.fasterxml.jackson.core.type.TypeReference;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.vincent.rsf.server.ai.config.AiDefaults;
-import com.vincent.rsf.server.ai.dto.AiChatMemoryDto;
-import com.vincent.rsf.server.ai.dto.AiChatRuntimeDto;
-import com.vincent.rsf.server.ai.dto.AiChatSessionDto;
-import com.vincent.rsf.server.ai.dto.AiMcpConnectivityTestDto;
-import com.vincent.rsf.server.ai.dto.AiMcpToolPreviewDto;
-import com.vincent.rsf.server.ai.dto.AiObserveStatsDto;
-import com.vincent.rsf.server.ai.dto.AiResolvedConfig;
-import com.vincent.rsf.server.common.service.RedisService;
-import lombok.AllArgsConstructor;
-import lombok.Builder;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-import lombok.RequiredArgsConstructor;
-import lombok.extern.slf4j.Slf4j;
-import org.springframework.stereotype.Service;
-import org.springframework.util.StringUtils;
-import redis.clients.jedis.Jedis;
-
-import java.net.URLEncoder;
-import java.nio.charset.StandardCharsets;
-import java.security.MessageDigest;
-import java.time.Instant;
-import java.util.LinkedHashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.UUID;
-import java.util.function.Consumer;
-import java.util.function.Function;
-import java.util.function.Supplier;
-
-@Slf4j
-@Service
-@RequiredArgsConstructor
-public class AiRedisSupport {
-
- /** 缁熶竴鏀跺彛 AI 妯″潡鐨� Redis key銆乀TL 鍜屽簭鍒楀寲绛栫暐锛岄伩鍏嶄笟鍔$被鐩存帴鏁e啓 Redis銆� */
- private static final String CONFIG_KEY_PREFIX = "AI:CONFIG:";
- private static final String RUNTIME_KEY_PREFIX = "AI:RUNTIME:";
- private static final String MEMORY_KEY_PREFIX = "AI:MEMORY:";
- private static final String SESSIONS_KEY_PREFIX = "AI:SESSIONS:";
- private static final String MCP_PREVIEW_KEY_PREFIX = "AI:MCP:PREVIEW:";
- private static final String MCP_HEALTH_KEY_PREFIX = "AI:MCP:HEALTH:";
- private static final String STREAM_STATE_KEY_PREFIX = "AI:STREAM:";
- private static final String TOOL_RESULT_KEY_PREFIX = "AI:TOOL:RESULT:";
- private static final String RATE_LIMIT_KEY_PREFIX = "AI:RATE:";
- private static final String OBSERVE_STATS_KEY_PREFIX = "AI:OBSERVE:STATS:";
- private static final String OBSERVE_TOOL_RANK_KEY_PREFIX = "AI:OBSERVE:TOOL:RANK:";
- private static final String OBSERVE_TOOL_FAIL_RANK_KEY_PREFIX = "AI:OBSERVE:TOOL:FAIL:RANK:";
-
- private static final String FIELD_CALL_COUNT = "callCount";
- private static final String FIELD_SUCCESS_COUNT = "successCount";
- private static final String FIELD_FAILURE_COUNT = "failureCount";
- private static final String FIELD_ELAPSED_SUM = "elapsedSum";
- private static final String FIELD_ELAPSED_COUNT = "elapsedCount";
- private static final String FIELD_FIRST_TOKEN_SUM = "firstTokenSum";
- private static final String FIELD_FIRST_TOKEN_COUNT = "firstTokenCount";
- private static final String FIELD_TOTAL_TOKENS_SUM = "totalTokensSum";
- private static final String FIELD_TOTAL_TOKENS_COUNT = "totalTokensCount";
- private static final String FIELD_TOOL_CALL_COUNT = "toolCallCount";
- private static final String FIELD_TOOL_SUCCESS_COUNT = "toolSuccessCount";
- private static final String FIELD_TOOL_FAILURE_COUNT = "toolFailureCount";
-
- private final RedisService redisService;
- private final ObjectMapper objectMapper;
-
- public AiResolvedConfig getResolvedConfig(Long tenantId, String promptCode, Long aiParamId) {
- return readJson(buildConfigKey(tenantId, promptCode, aiParamId), AiResolvedConfig.class);
- }
-
- public void cacheResolvedConfig(Long tenantId, String promptCode, Long aiParamId, AiResolvedConfig config) {
- writeJson(buildConfigKey(tenantId, promptCode, aiParamId), config, AiDefaults.CONFIG_CACHE_TTL_SECONDS);
- }
-
- public void evictTenantConfigCaches(Long tenantId) {
- deleteByPrefix(CONFIG_KEY_PREFIX + tenantId + ":");
- deleteByPrefix(RUNTIME_KEY_PREFIX + tenantId + ":");
- }
-
- public AiChatRuntimeDto getRuntime(Long tenantId, Long userId, String promptCode, Long sessionId, Long aiParamId) {
- return readJson(buildRuntimeKey(tenantId, userId, promptCode, sessionId, aiParamId), AiChatRuntimeDto.class);
- }
-
- public void cacheRuntime(Long tenantId, Long userId, String promptCode, Long sessionId, Long aiParamId, AiChatRuntimeDto runtime) {
- writeJson(buildRuntimeKey(tenantId, userId, promptCode, sessionId, aiParamId), runtime, AiDefaults.RUNTIME_CACHE_TTL_SECONDS);
- }
-
- public AiChatMemoryDto getMemory(Long tenantId, Long userId, String promptCode, Long sessionId) {
- return readJson(buildMemoryKey(tenantId, userId, promptCode, sessionId), AiChatMemoryDto.class);
- }
-
- public void cacheMemory(Long tenantId, Long userId, String promptCode, Long sessionId, AiChatMemoryDto memory) {
- writeJson(buildMemoryKey(tenantId, userId, promptCode, sessionId), memory, AiDefaults.MEMORY_CACHE_TTL_SECONDS);
- }
-
- public List<AiChatSessionDto> getSessionList(Long tenantId, Long userId, String promptCode, String keyword) {
- return readJson(buildSessionsKey(tenantId, userId, promptCode, keyword), new TypeReference<List<AiChatSessionDto>>() {
- });
- }
-
- public void cacheSessionList(Long tenantId, Long userId, String promptCode, String keyword, List<AiChatSessionDto> sessions) {
- writeJson(buildSessionsKey(tenantId, userId, promptCode, keyword), sessions, AiDefaults.SESSION_LIST_CACHE_TTL_SECONDS);
- }
-
- public void evictUserConversationCaches(Long tenantId, Long userId) {
- deleteByPrefix(RUNTIME_KEY_PREFIX + tenantId + ":" + userId + ":");
- deleteByPrefix(MEMORY_KEY_PREFIX + tenantId + ":" + userId + ":");
- deleteByPrefix(SESSIONS_KEY_PREFIX + tenantId + ":" + userId + ":");
- }
-
- public List<AiMcpToolPreviewDto> getToolPreview(Long tenantId, Long mountId) {
- return readJson(buildMcpPreviewKey(tenantId, mountId), new TypeReference<List<AiMcpToolPreviewDto>>() {
- });
- }
-
- public void cacheToolPreview(Long tenantId, Long mountId, List<AiMcpToolPreviewDto> tools) {
- writeJson(buildMcpPreviewKey(tenantId, mountId), tools, AiDefaults.MCP_PREVIEW_CACHE_TTL_SECONDS);
- }
-
- public AiMcpConnectivityTestDto getConnectivity(Long tenantId, Long mountId) {
- return readJson(buildMcpHealthKey(tenantId, mountId), AiMcpConnectivityTestDto.class);
- }
-
- public void cacheConnectivity(Long tenantId, Long mountId, AiMcpConnectivityTestDto connectivity) {
- writeJson(buildMcpHealthKey(tenantId, mountId), connectivity, AiDefaults.MCP_HEALTH_CACHE_TTL_SECONDS);
- }
-
- public void evictMcpMountCaches(Long tenantId, Long mountId) {
- if (mountId != null) {
- delete(buildMcpPreviewKey(tenantId, mountId));
- delete(buildMcpHealthKey(tenantId, mountId));
- } else {
- deleteByPrefix(MCP_PREVIEW_KEY_PREFIX + tenantId + ":");
- deleteByPrefix(MCP_HEALTH_KEY_PREFIX + tenantId + ":");
- }
- evictTenantConfigCaches(tenantId);
- }
-
- public boolean allowChatRequest(Long tenantId, Long userId, String promptCode) {
- String key = buildRateLimitKey(tenantId, userId, promptCode);
- long now = Instant.now().toEpochMilli();
- long windowStart = now - (AiDefaults.CHAT_RATE_LIMIT_WINDOW_SECONDS * 1000L);
- Boolean allowed = execute(jedis -> {
- // 鐢� zset 缁存姢婊戝姩绐楀彛锛岃�屼笉鏄畝鍗曡鏁板櫒锛岄伩鍏嶇獥鍙h竟鐣屽嚭鐜扮獊鍒鸿鍒ゃ��
- jedis.zremrangeByScore(key, 0, windowStart);
- long count = jedis.zcard(key);
- if (count >= AiDefaults.CHAT_RATE_LIMIT_MAX_REQUESTS) {
- jedis.expire(key, AiDefaults.CHAT_RATE_LIMIT_WINDOW_SECONDS);
- return Boolean.FALSE;
- }
- jedis.zadd(key, now, now + ":" + UUID.randomUUID());
- jedis.expire(key, AiDefaults.CHAT_RATE_LIMIT_WINDOW_SECONDS);
- return Boolean.TRUE;
- });
- return Boolean.TRUE.equals(allowed);
- }
-
- public void markStreamState(String requestId, Long tenantId, Long userId, Long sessionId, String promptCode,
- String status, String errorMessage) {
- if (!StringUtils.hasText(requestId)) {
- return;
- }
- writeJson(buildStreamStateKey(tenantId, requestId), AiStreamState.builder()
- .requestId(requestId)
- .tenantId(tenantId)
- .userId(userId)
- .sessionId(sessionId)
- .promptCode(promptCode)
- .status(status)
- .errorMessage(errorMessage)
- .timestamp(Instant.now().toEpochMilli())
- .build(), AiDefaults.STREAM_STATE_TTL_SECONDS);
- }
-
- public CachedToolResult getToolResult(Long tenantId, String requestId, String toolName, String toolInput) {
- return readJson(buildToolResultKey(tenantId, requestId, toolName, toolInput), CachedToolResult.class);
- }
-
- public void cacheToolResult(Long tenantId, String requestId, String toolName, String toolInput,
- boolean success, String output, String errorMessage) {
- writeJson(buildToolResultKey(tenantId, requestId, toolName, toolInput), CachedToolResult.builder()
- .success(success)
- .output(output)
- .errorMessage(errorMessage)
- .build(), AiDefaults.TOOL_RESULT_CACHE_TTL_SECONDS);
- }
-
- public void recordObserveCallStarted(Long tenantId) {
- executeVoid(jedis -> jedis.hincrBy(buildObserveStatsKey(tenantId), FIELD_CALL_COUNT, 1));
- }
-
- public void recordObserveCallFinished(Long tenantId, String status, Long elapsedMs, Long firstTokenLatencyMs, Integer totalTokens) {
- executeVoid(jedis -> {
- String key = buildObserveStatsKey(tenantId);
- if ("COMPLETED".equals(status)) {
- jedis.hincrBy(key, FIELD_SUCCESS_COUNT, 1);
- } else if ("FAILED".equals(status)) {
- jedis.hincrBy(key, FIELD_FAILURE_COUNT, 1);
- }
- if (elapsedMs != null) {
- jedis.hincrBy(key, FIELD_ELAPSED_SUM, elapsedMs);
- jedis.hincrBy(key, FIELD_ELAPSED_COUNT, 1);
- }
- if (firstTokenLatencyMs != null) {
- jedis.hincrBy(key, FIELD_FIRST_TOKEN_SUM, firstTokenLatencyMs);
- jedis.hincrBy(key, FIELD_FIRST_TOKEN_COUNT, 1);
- }
- if (totalTokens != null) {
- jedis.hincrBy(key, FIELD_TOTAL_TOKENS_SUM, totalTokens.longValue());
- jedis.hincrBy(key, FIELD_TOTAL_TOKENS_COUNT, 1);
- }
- });
- }
-
- public void recordObserveToolCall(Long tenantId, String toolName, String status) {
- executeVoid(jedis -> {
- String key = buildObserveStatsKey(tenantId);
- jedis.hincrBy(key, FIELD_TOOL_CALL_COUNT, 1);
- if ("COMPLETED".equals(status)) {
- jedis.hincrBy(key, FIELD_TOOL_SUCCESS_COUNT, 1);
- } else if ("FAILED".equals(status)) {
- jedis.hincrBy(key, FIELD_TOOL_FAILURE_COUNT, 1);
- }
- if (StringUtils.hasText(toolName)) {
- jedis.zincrby(buildToolRankKey(tenantId), 1D, toolName);
- if ("FAILED".equals(status)) {
- jedis.zincrby(buildToolFailRankKey(tenantId), 1D, toolName);
- }
- }
- });
- }
-
- public AiObserveStatsDto getObserveStats(Long tenantId, Supplier<AiObserveStatsDto> fallbackLoader) {
- AiObserveStatsDto cached = readObserveStats(tenantId);
- if (cached != null) {
- return cached;
- }
- // Redis 涓虹┖鏃跺啀鍥炴簮鏁版嵁搴擄紝閬垮厤绠$悊绔湅鏉挎瘡娆¢兘鎵叏閲忔棩蹇楄〃銆�
- AiObserveStatsDto snapshot = fallbackLoader.get();
- if (snapshot != null) {
- seedObserveStats(tenantId, snapshot);
- }
- return snapshot;
- }
-
- private AiObserveStatsDto readObserveStats(Long tenantId) {
- Map<String, String> fields = execute(jedis -> {
- String key = buildObserveStatsKey(tenantId);
- if (!jedis.exists(key)) {
- return null;
- }
- return jedis.hgetAll(key);
- });
- if (fields == null || fields.isEmpty()) {
- return null;
- }
- long callCount = parseLong(fields.get(FIELD_CALL_COUNT));
- long successCount = parseLong(fields.get(FIELD_SUCCESS_COUNT));
- long failureCount = parseLong(fields.get(FIELD_FAILURE_COUNT));
- long elapsedSum = parseLong(fields.get(FIELD_ELAPSED_SUM));
- long elapsedCount = parseLong(fields.get(FIELD_ELAPSED_COUNT));
- long firstTokenSum = parseLong(fields.get(FIELD_FIRST_TOKEN_SUM));
- long firstTokenCount = parseLong(fields.get(FIELD_FIRST_TOKEN_COUNT));
- long totalTokensSum = parseLong(fields.get(FIELD_TOTAL_TOKENS_SUM));
- long totalTokensCount = parseLong(fields.get(FIELD_TOTAL_TOKENS_COUNT));
- long toolCallCount = parseLong(fields.get(FIELD_TOOL_CALL_COUNT));
- long toolSuccessCount = parseLong(fields.get(FIELD_TOOL_SUCCESS_COUNT));
- long toolFailureCount = parseLong(fields.get(FIELD_TOOL_FAILURE_COUNT));
- return AiObserveStatsDto.builder()
- .callCount(callCount)
- .successCount(successCount)
- .failureCount(failureCount)
- .avgElapsedMs(elapsedCount == 0 ? 0L : elapsedSum / elapsedCount)
- .avgFirstTokenLatencyMs(firstTokenCount == 0 ? 0L : firstTokenSum / firstTokenCount)
- .totalTokens(totalTokensSum)
- .avgTotalTokens(totalTokensCount == 0 ? 0L : totalTokensSum / totalTokensCount)
- .toolCallCount(toolCallCount)
- .toolSuccessCount(toolSuccessCount)
- .toolFailureCount(toolFailureCount)
- .toolSuccessRate(toolCallCount == 0 ? 0D : (toolSuccessCount * 100D) / toolCallCount)
- .build();
- }
-
- private void seedObserveStats(Long tenantId, AiObserveStatsDto snapshot) {
- executeVoid(jedis -> {
- String key = buildObserveStatsKey(tenantId);
- Map<String, String> values = new LinkedHashMap<>();
- values.put(FIELD_CALL_COUNT, String.valueOf(defaultLong(snapshot.getCallCount())));
- values.put(FIELD_SUCCESS_COUNT, String.valueOf(defaultLong(snapshot.getSuccessCount())));
- values.put(FIELD_FAILURE_COUNT, String.valueOf(defaultLong(snapshot.getFailureCount())));
- values.put(FIELD_ELAPSED_SUM, String.valueOf(defaultLong(snapshot.getAvgElapsedMs()) * defaultLong(snapshot.getCallCount())));
- values.put(FIELD_ELAPSED_COUNT, String.valueOf(defaultLong(snapshot.getCallCount())));
- values.put(FIELD_FIRST_TOKEN_SUM, String.valueOf(defaultLong(snapshot.getAvgFirstTokenLatencyMs()) * defaultLong(snapshot.getCallCount())));
- values.put(FIELD_FIRST_TOKEN_COUNT, String.valueOf(defaultLong(snapshot.getCallCount())));
- values.put(FIELD_TOTAL_TOKENS_SUM, String.valueOf(defaultLong(snapshot.getTotalTokens())));
- values.put(FIELD_TOTAL_TOKENS_COUNT, String.valueOf(defaultLong(snapshot.getCallCount())));
- values.put(FIELD_TOOL_CALL_COUNT, String.valueOf(defaultLong(snapshot.getToolCallCount())));
- values.put(FIELD_TOOL_SUCCESS_COUNT, String.valueOf(defaultLong(snapshot.getToolSuccessCount())));
- values.put(FIELD_TOOL_FAILURE_COUNT, String.valueOf(defaultLong(snapshot.getToolFailureCount())));
- jedis.hset(key, values);
- });
- }
-
- private String buildConfigKey(Long tenantId, String promptCode, Long aiParamId) {
- return CONFIG_KEY_PREFIX + tenantId + ":" + safeToken(promptCode) + ":" + aiParamToken(aiParamId);
- }
-
- private String buildRuntimeKey(Long tenantId, Long userId, String promptCode, Long sessionId, Long aiParamId) {
- return RUNTIME_KEY_PREFIX + tenantId + ":" + userId + ":" + safeToken(promptCode) + ":" + sessionToken(sessionId) + ":" + aiParamToken(aiParamId);
- }
-
- private String buildMemoryKey(Long tenantId, Long userId, String promptCode, Long sessionId) {
- return MEMORY_KEY_PREFIX + tenantId + ":" + userId + ":" + safeToken(promptCode) + ":" + sessionToken(sessionId);
- }
-
- private String buildSessionsKey(Long tenantId, Long userId, String promptCode, String keyword) {
- return SESSIONS_KEY_PREFIX + tenantId + ":" + userId + ":" + safeToken(promptCode) + ":" + safeToken(keyword);
- }
-
- private String buildMcpPreviewKey(Long tenantId, Long mountId) {
- return MCP_PREVIEW_KEY_PREFIX + tenantId + ":" + mountId;
- }
-
- private String buildMcpHealthKey(Long tenantId, Long mountId) {
- return MCP_HEALTH_KEY_PREFIX + tenantId + ":" + mountId;
- }
-
- private String buildStreamStateKey(Long tenantId, String requestId) {
- return STREAM_STATE_KEY_PREFIX + tenantId + ":" + safeToken(requestId);
- }
-
- private String buildToolResultKey(Long tenantId, String requestId, String toolName, String toolInput) {
- return TOOL_RESULT_KEY_PREFIX + tenantId + ":" + safeToken(requestId) + ":" + safeToken(toolName) + ":" + digest(toolInput);
- }
-
- private String buildRateLimitKey(Long tenantId, Long userId, String promptCode) {
- return RATE_LIMIT_KEY_PREFIX + tenantId + ":" + userId + ":" + safeToken(promptCode);
- }
-
- private String buildObserveStatsKey(Long tenantId) {
- return OBSERVE_STATS_KEY_PREFIX + tenantId;
- }
-
- private String buildToolRankKey(Long tenantId) {
- return OBSERVE_TOOL_RANK_KEY_PREFIX + tenantId;
- }
-
- private String buildToolFailRankKey(Long tenantId) {
- return OBSERVE_TOOL_FAIL_RANK_KEY_PREFIX + tenantId;
- }
-
- private String sessionToken(Long sessionId) {
- return sessionId == null ? "LATEST" : String.valueOf(sessionId);
- }
-
- private String aiParamToken(Long aiParamId) {
- return aiParamId == null ? "DEFAULT" : String.valueOf(aiParamId);
- }
-
- private String safeToken(String source) {
- if (!StringUtils.hasText(source)) {
- return "_";
- }
- return URLEncoder.encode(source.trim(), StandardCharsets.UTF_8);
- }
-
- private String digest(String source) {
- try {
- MessageDigest messageDigest = MessageDigest.getInstance("SHA-256");
- byte[] bytes = messageDigest.digest((source == null ? "" : source).getBytes(StandardCharsets.UTF_8));
- StringBuilder builder = new StringBuilder();
- for (byte value : bytes) {
- builder.append(String.format("%02x", value));
- }
- return builder.toString();
- } catch (Exception e) {
- return safeToken(source);
- }
- }
-
- private <T> T readJson(String key, Class<T> type) {
- return readJson(key, value -> objectMapper.readValue(value, type));
- }
-
- private <T> T readJson(String key, TypeReference<T> typeReference) {
- return readJson(key, value -> objectMapper.readValue(value, typeReference));
- }
-
- private <T> T readJson(String key, JsonReader<T> reader) {
- return execute(jedis -> {
- String value = jedis.get(key);
- if (!StringUtils.hasText(value)) {
- return null;
- }
- try {
- return reader.read(value);
- } catch (Exception e) {
- // 鍙嶅簭鍒楀寲澶辫触鏃剁洿鎺ュ垹鍧忕紦瀛橈紝璁╀笅娆¤姹傝嚜鐒跺洖婧愰噸寤恒��
- log.warn("AI redis cache deserialize failed, key={}, message={}", key, e.getMessage());
- jedis.del(key);
- return null;
- }
- });
- }
-
- private void writeJson(String key, Object value, int ttlSeconds) {
- if (value == null) {
- delete(key);
- return;
- }
- executeVoid(jedis -> {
- try {
- jedis.setex(key, ttlSeconds, objectMapper.writeValueAsString(value));
- } catch (Exception e) {
- log.warn("AI redis cache serialize failed, key={}, message={}", key, e.getMessage());
- }
- });
- }
-
- private void delete(String key) {
- executeVoid(jedis -> jedis.del(key));
- }
-
- private void deleteByPrefix(String prefix) {
- executeVoid(jedis -> {
- Set<String> keys = jedis.keys(prefix + "*");
- if (keys == null || keys.isEmpty()) {
- return;
- }
- jedis.del(keys.toArray(new String[0]));
- });
- }
-
- private long parseLong(String source) {
- if (!StringUtils.hasText(source)) {
- return 0L;
- }
- try {
- return Long.parseLong(source);
- } catch (Exception e) {
- return 0L;
- }
- }
-
- private long defaultLong(Long value) {
- return value == null ? 0L : value;
- }
-
- private <T> T execute(Function<Jedis, T> action) {
- Jedis jedis = redisService.getJedis();
- if (jedis == null) {
- return null;
- }
- try (jedis) {
- return action.apply(jedis);
- } catch (Exception e) {
- log.warn("AI redis operation skipped, message={}", e.getMessage());
- return null;
- }
- }
-
- private void executeVoid(Consumer<Jedis> action) {
- Jedis jedis = redisService.getJedis();
- if (jedis == null) {
- return;
- }
- try (jedis) {
- action.accept(jedis);
- } catch (Exception e) {
- log.warn("AI redis operation skipped, message={}", e.getMessage());
- }
- }
-
- @FunctionalInterface
- private interface JsonReader<T> {
- T read(String value) throws Exception;
- }
-
- @Data
- @Builder
- @NoArgsConstructor
- @AllArgsConstructor
- public static class CachedToolResult {
-
- private boolean success;
-
- private String output;
-
- private String errorMessage;
- }
-
- @Data
- @Builder
- @NoArgsConstructor
- @AllArgsConstructor
- private static class AiStreamState {
-
- private String requestId;
-
- private Long tenantId;
-
- private Long userId;
-
- private Long sessionId;
-
- private String promptCode;
-
- private String status;
-
- private String errorMessage;
-
- private Long timestamp;
- }
-}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/BuiltinMcpToolRegistryImpl.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/BuiltinMcpToolRegistryImpl.java
index c04d84b..da7c1d6 100644
--- a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/BuiltinMcpToolRegistryImpl.java
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/BuiltinMcpToolRegistryImpl.java
@@ -4,6 +4,7 @@
import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.dto.AiMcpToolPreviewDto;
import com.vincent.rsf.server.ai.entity.AiMcpMount;
+import com.vincent.rsf.server.ai.service.impl.mcp.BuiltinMcpToolCatalogProvider;
import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry;
import com.vincent.rsf.server.ai.tool.RsfWmsBaseTools;
import com.vincent.rsf.server.ai.tool.RsfWmsStockTools;
@@ -27,6 +28,7 @@
private final RsfWmsStockTools rsfWmsStockTools;
private final RsfWmsTaskTools rsfWmsTaskTools;
private final RsfWmsBaseTools rsfWmsBaseTools;
+ private final BuiltinMcpToolCatalogProvider builtinMcpToolCatalogProvider;
/**
* 鏍¢獙鍐呯疆 MCP 缂栫爜鏄惁鍚堟硶銆�
@@ -38,7 +40,7 @@
if (!StringUtils.hasText(builtinCode)) {
throw new CoolException("鍐呯疆 MCP 缂栫爜涓嶈兘涓虹┖");
}
- if (!supportedBuiltinCodes().contains(builtinCode)) {
+ if (!builtinMcpToolCatalogProvider.supportedBuiltinCodes().contains(builtinCode)) {
throw new CoolException("涓嶆敮鎸佺殑鍐呯疆 MCP 缂栫爜: " + builtinCode);
}
}
@@ -68,10 +70,7 @@
@Override
public List<AiMcpToolPreviewDto> listBuiltinToolCatalog(String builtinCode) {
validateBuiltinCode(builtinCode);
- if (AiDefaults.MCP_BUILTIN_RSF_WMS.equals(builtinCode)) {
- return new ArrayList<>(catalogByBuiltinCode(builtinCode).values());
- }
- return new ArrayList<>(catalogByBuiltinCode(builtinCode).values());
+ return new ArrayList<>(builtinMcpToolCatalogProvider.getCatalog(builtinCode).values());
}
private List<ToolCallback> createValidatedCallbacks(Object toolBean, String builtinCode) {
@@ -81,7 +80,7 @@
* 2. 姣忎釜宸ュ叿閮藉繀椤诲嚭鐜板湪娌荤悊鐩綍閲�
*/
List<ToolCallback> callbacks = Arrays.asList(ToolCallbacks.from(toolBean));
- Map<String, AiMcpToolPreviewDto> catalog = catalogByBuiltinCode(builtinCode);
+ Map<String, AiMcpToolPreviewDto> catalog = builtinMcpToolCatalogProvider.getCatalog(builtinCode);
for (ToolCallback callback : callbacks) {
if (callback == null || callback.getToolDefinition() == null) {
continue;
@@ -97,74 +96,4 @@
return callbacks;
}
- private List<String> supportedBuiltinCodes() {
- /** 褰撳墠鐗堟湰鍏佽鎸傝浇鐨勫叏閮ㄥ唴缃� MCP 缂栫爜銆� */
- return List.of(AiDefaults.MCP_BUILTIN_RSF_WMS);
- }
-
- private Map<String, AiMcpToolPreviewDto> catalogByBuiltinCode(String builtinCode) {
- /**
- * 鏋勯�犲唴缃伐鍏锋不鐞嗙洰褰曘��
- * 杩欓噷鐨勭洰褰曟槸杩愯鏃舵牎楠屽拰绠$悊绔瑙堢殑鍏卞悓浜嬪疄鏉ユ簮锛屼笉鑳戒笌宸ュ叿瀹炵幇鑴辫妭銆�
- */
- if (AiDefaults.MCP_BUILTIN_RSF_WMS.equals(builtinCode)) {
- Map<String, AiMcpToolPreviewDto> catalog = new LinkedHashMap<>();
- catalog.put("rsf_query_available_inventory", buildCatalogItem(
- "rsf_query_available_inventory",
- "搴撳瓨鏌ヨ",
- "鏌ヨ鎸囧畾鐗╂枡褰撳墠鍙敤浜庡嚭搴撶殑搴撳瓨鏄庣粏銆�",
- "蹇呴』鎻愪緵鐗╂枡缂栫爜鎴栫墿鏂欏悕绉帮紝骞朵笖鏈�澶氳繑鍥� 50 鏉″簱瀛樿褰曘��",
- List.of("鏌ヨ鐗╂枡 MAT001 褰撳墠鍙嚭搴撳簱瀛�", "鎸夌墿鏂欏悕绉版煡璇㈡墭鐩樺簱瀛樻槑缁�")
- ));
- catalog.put("rsf_query_station_list", buildCatalogItem(
- "rsf_query_station_list",
- "搴撳瓨鏌ヨ",
- "鏌ヨ鎸囧畾浣滀笟绫诲瀷鍙敤鐨勮澶囩珯鐐广��",
- "蹇呴』鎻愪緵绔欑偣绫诲瀷鍒楄〃锛岀被鍨嬫暟閲忔渶澶� 10 涓紝鏈�澶氳繑鍥� 50 涓珯鐐广��",
- List.of("鏌ヨ鍏ュ簱鍜屽嚭搴撲綔涓氬彲鐢ㄧ珯鐐�", "鍒楀嚭 AGV_PICK 绫诲瀷鐨勪綔涓氱珯鐐�")
- ));
- catalog.put("rsf_query_task", buildCatalogItem(
- "rsf_query_task",
- "浠诲姟鏌ヨ",
- "鎸変换鍔″彿銆佺姸鎬併�佺被鍨嬫垨绔欑偣鏉′欢鏌ヨ浠诲姟锛涙敮鎸佷粠鑷劧璇█涓嚜鍔ㄦ彁鍙栦换鍔″彿锛岀簿纭懡涓椂杩斿洖浠诲姟鏄庣粏銆�",
- "杩囨护鏉′欢鍧囦负鍙�夛紝涓嶄紶杩囨护鏉′欢鏃堕粯璁よ繑鍥炴渶杩戜换鍔★紝鏈�澶氳繑鍥� 50 鏉¤褰曘��",
- List.of("鏌ヨ鏈�杩� 10 鏉′换鍔�", "鏌ヨ浠诲姟鍙� TASK24001 鐨勮鎯�")
- ));
- catalog.put("rsf_query_warehouses", buildCatalogItem(
- "rsf_query_warehouses",
- "鍩虹璧勬枡",
- "鏌ヨ浠撳簱鍩虹淇℃伅銆�",
- "鑷冲皯鎻愪緵浠撳簱缂栫爜鎴栧悕绉帮紝鏈�澶氳繑鍥� 50 鏉′粨搴撹褰曘��",
- List.of("鏌ヨ缂栫爜鍖呭惈 WH 鐨勪粨搴�", "鎸変粨搴撳悕绉版煡璇粨搴撳湴鍧�")
- ));
- catalog.put("rsf_query_bas_stations", buildCatalogItem(
- "rsf_query_bas_stations",
- "鍩虹璧勬枡",
- "鏌ヨ鍩虹绔欑偣淇℃伅銆�",
- "鑷冲皯鎻愪緵绔欑偣缂栧彿銆佺珯鐐瑰悕绉版垨浣跨敤鐘舵�佷箣涓�锛屾渶澶氳繑鍥� 50 鏉$珯鐐硅褰曘��",
- List.of("鏌ヨ浣跨敤涓殑鍩虹绔欑偣", "鎸夌珯鐐圭紪鍙锋煡璇㈠熀纭�绔欑偣")
- ));
- catalog.put("rsf_query_dict_data", buildCatalogItem(
- "rsf_query_dict_data",
- "鍩虹璧勬枡",
- "鏌ヨ鎸囧畾瀛楀吀绫诲瀷涓嬬殑瀛楀吀鏁版嵁銆�",
- "蹇呴』鎻愪緵瀛楀吀绫诲瀷缂栫爜锛屾渶澶氳繑鍥� 100 鏉″瓧鍏歌褰曘��",
- List.of("鏌ヨ task_status 瀛楀吀", "鎸夊瓧鍏告爣绛捐繃婊� task_type 瀛楀吀鏁版嵁")
- ));
- return catalog;
- }
- throw new CoolException("涓嶆敮鎸佺殑鍐呯疆 MCP 缂栫爜: " + builtinCode);
- }
-
- private AiMcpToolPreviewDto buildCatalogItem(String name, String toolGroup, String toolPurpose,
- String queryBoundary, List<String> exampleQuestions) {
- /** 缁熶竴鍒涘缓宸ュ叿鐩綍鏉$洰锛岄伩鍏嶄笉鍚屽伐鍏风粍鍑虹幇瀛楁椋庢牸涓嶄竴鑷淬�� */
- return AiMcpToolPreviewDto.builder()
- .name(name)
- .toolGroup(toolGroup)
- .toolPurpose(toolPurpose)
- .queryBoundary(queryBoundary)
- .exampleQuestions(exampleQuestions)
- .build();
- }
}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/McpMountRuntimeFactoryImpl.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/McpMountRuntimeFactoryImpl.java
index befeb3f..cd8285a 100644
--- a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/McpMountRuntimeFactoryImpl.java
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/McpMountRuntimeFactoryImpl.java
@@ -1,20 +1,13 @@
package com.vincent.rsf.server.ai.service.impl;
-import com.fasterxml.jackson.core.type.TypeReference;
-import com.fasterxml.jackson.databind.ObjectMapper;
import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.entity.AiMcpMount;
import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry;
import com.vincent.rsf.server.ai.service.MountedToolCallback;
import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
-import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.McpSyncClient;
-import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
-import io.modelcontextprotocol.client.transport.ServerParameters;
-import io.modelcontextprotocol.client.transport.StdioClientTransport;
-import io.modelcontextprotocol.json.jackson.JacksonMcpJsonMapper;
-import io.modelcontextprotocol.spec.McpSchema;
+import com.vincent.rsf.server.ai.service.impl.mcp.McpClientFactory;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
@@ -22,22 +15,19 @@
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
-import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
-import java.util.LinkedHashMap;
import java.util.List;
-import java.util.Map;
@Slf4j
@Service
@RequiredArgsConstructor
public class McpMountRuntimeFactoryImpl implements McpMountRuntimeFactory {
- private final ObjectMapper objectMapper;
private final BuiltinMcpToolRegistry builtinMcpToolRegistry;
+ private final McpClientFactory mcpClientFactory;
/**
* 鎶婁竴缁� MCP 鎸傝浇璁板綍瑙f瀽鎴愪竴娆″璇濆彲鐩存帴浣跨敤鐨勮繍琛屾椂瀵硅薄銆�
@@ -60,7 +50,7 @@
mountedNames.add(mount.getName());
continue;
}
- McpSyncClient client = createClient(mount);
+ McpSyncClient client = mcpClientFactory.createClient(mount);
client.initialize();
client.listTools();
clients.add(client);
@@ -109,84 +99,6 @@
}
if (!duplicateNames.isEmpty()) {
throw new CoolException("MCP 宸ュ叿鍚嶇О閲嶅锛岃璋冩暣鎸傝浇閰嶇疆: " + String.join(", ", duplicateNames));
- }
- }
-
- private McpSyncClient createClient(AiMcpMount mount) {
- /**
- * 鎸夋寕杞介厤缃姩鎬佸垱寤� MCP Client銆�
- * 璇ユ柟娉曞彧璐熻矗 transport 灞傚垵濮嬪寲锛屼笉璐熻矗宸ュ叿鍘婚噸鍜岄敊璇仛鍚堛��
- */
- Duration timeout = Duration.ofMillis(mount.getRequestTimeoutMs() == null
- ? AiDefaults.DEFAULT_TIMEOUT_MS
- : mount.getRequestTimeoutMs());
- JacksonMcpJsonMapper jsonMapper = new JacksonMcpJsonMapper(objectMapper);
- if (AiDefaults.MCP_TRANSPORT_STDIO.equals(mount.getTransportType())) {
- ServerParameters.Builder parametersBuilder = ServerParameters.builder(mount.getCommand());
- List<String> args = readStringList(mount.getArgsJson());
- if (!args.isEmpty()) {
- parametersBuilder.args(args);
- }
- Map<String, String> env = readStringMap(mount.getEnvJson());
- if (!env.isEmpty()) {
- parametersBuilder.env(env);
- }
- StdioClientTransport transport = new StdioClientTransport(parametersBuilder.build(), jsonMapper);
- transport.setStdErrorHandler(message -> log.warn("MCP STDIO stderr [{}]: {}", mount.getName(), message));
- return McpClient.sync(transport)
- .requestTimeout(timeout)
- .initializationTimeout(timeout)
- .clientInfo(new McpSchema.Implementation("rsf-ai-client", "RSF AI Client", "1.0.0"))
- .build();
- }
- if (!AiDefaults.MCP_TRANSPORT_SSE_HTTP.equals(mount.getTransportType())) {
- throw new CoolException("涓嶆敮鎸佺殑 MCP 浼犺緭绫诲瀷: " + mount.getTransportType());
- }
-
- if (!StringUtils.hasText(mount.getServerUrl())) {
- throw new CoolException("MCP 鏈嶅姟鍦板潃涓嶈兘涓虹┖");
- }
- HttpClientSseClientTransport.Builder transportBuilder = HttpClientSseClientTransport.builder(mount.getServerUrl())
- .jsonMapper(jsonMapper)
- .connectTimeout(timeout);
- if (StringUtils.hasText(mount.getEndpoint())) {
- transportBuilder.sseEndpoint(mount.getEndpoint());
- }
- Map<String, String> headers = readStringMap(mount.getHeadersJson());
- if (!headers.isEmpty()) {
- transportBuilder.customizeRequest(builder -> headers.forEach(builder::header));
- }
- return McpClient.sync(transportBuilder.build())
- .requestTimeout(timeout)
- .initializationTimeout(timeout)
- .clientInfo(new McpSchema.Implementation("rsf-ai-client", "RSF AI Client", "1.0.0"))
- .build();
- }
-
- private List<String> readStringList(String json) {
- /** 瑙f瀽鎸傝浇琛ㄩ噷鐨� JSON 鏁扮粍閰嶇疆锛屼緥濡� STDIO args銆� */
- if (!StringUtils.hasText(json)) {
- return Collections.emptyList();
- }
- try {
- return objectMapper.readValue(json, new TypeReference<List<String>>() {
- });
- } catch (Exception e) {
- throw new CoolException("瑙f瀽 MCP 鍒楄〃閰嶇疆澶辫触: " + e.getMessage());
- }
- }
-
- private Map<String, String> readStringMap(String json) {
- /** 瑙f瀽鎸傝浇琛ㄩ噷鐨� JSON Map 閰嶇疆锛屼緥濡� headers 鎴栫幆澧冨彉閲忋�� */
- if (!StringUtils.hasText(json)) {
- return Collections.emptyMap();
- }
- try {
- Map<String, String> result = objectMapper.readValue(json, new TypeReference<LinkedHashMap<String, String>>() {
- });
- return result == null ? Collections.emptyMap() : result;
- } catch (Exception e) {
- throw new CoolException("瑙f瀽 MCP Map 閰嶇疆澶辫触: " + e.getMessage());
}
}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiChatFailureHandler.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiChatFailureHandler.java
new file mode 100644
index 0000000..9fa35c5
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiChatFailureHandler.java
@@ -0,0 +1,104 @@
+package com.vincent.rsf.server.ai.service.impl.chat;
+
+import com.vincent.rsf.server.ai.dto.AiChatErrorDto;
+import com.vincent.rsf.server.ai.enums.AiErrorCategory;
+import com.vincent.rsf.server.ai.exception.AiChatException;
+import com.vincent.rsf.server.ai.service.AiCallLogService;
+import com.vincent.rsf.server.ai.store.AiStreamStateStore;
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.stereotype.Component;
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
+
+import java.time.Instant;
+
+@Slf4j
+@Component
+@RequiredArgsConstructor
+public class AiChatFailureHandler {
+
+ private final AiSseEventPublisher aiSseEventPublisher;
+ private final AiCallLogService aiCallLogService;
+ private final AiStreamStateStore aiStreamStateStore;
+
+ public AiChatException buildAiException(String code, AiErrorCategory category, String stage, String message, Throwable cause) {
+ return new AiChatException(code, category, stage, message, cause);
+ }
+
+ public void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt,
+ Long firstTokenAt, AiChatException exception, Long callLogId,
+ long toolSuccessCount, long toolFailureCount,
+ AiThinkingTraceEmitter thinkingTraceEmitter,
+ Long tenantId, Long userId, String promptCode) {
+ if (isClientAbortException(exception)) {
+ log.warn("AI chat aborted by client, requestId={}, sessionId={}, stage={}, message={}",
+ requestId, sessionId, exception.getStage(), exception.getMessage());
+ if (thinkingTraceEmitter != null) {
+ thinkingTraceEmitter.markTerminated("ABORTED");
+ }
+ aiSseEventPublisher.emitSafely(emitter, "status",
+ aiSseEventPublisher.buildTerminalStatus(requestId, sessionId, "ABORTED", model, startedAt, firstTokenAt));
+ aiCallLogService.failCallLog(
+ callLogId,
+ "ABORTED",
+ exception.getCategory().name(),
+ exception.getStage(),
+ exception.getMessage(),
+ System.currentTimeMillis() - startedAt,
+ aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAt),
+ toolSuccessCount,
+ toolFailureCount
+ );
+ aiStreamStateStore.markStreamState(requestId, tenantId, userId, sessionId, promptCode, "ABORTED", exception.getMessage());
+ emitter.completeWithError(exception);
+ return;
+ }
+ log.error("AI chat failed, requestId={}, sessionId={}, category={}, stage={}, message={}",
+ requestId, sessionId, exception.getCategory(), exception.getStage(), exception.getMessage(), exception);
+ if (thinkingTraceEmitter != null) {
+ thinkingTraceEmitter.markTerminated("FAILED");
+ }
+ aiSseEventPublisher.emitSafely(emitter, "status",
+ aiSseEventPublisher.buildTerminalStatus(requestId, sessionId, "FAILED", model, startedAt, firstTokenAt));
+ aiSseEventPublisher.emitSafely(emitter, "error", AiChatErrorDto.builder()
+ .requestId(requestId)
+ .sessionId(sessionId)
+ .code(exception.getCode())
+ .category(exception.getCategory().name())
+ .stage(exception.getStage())
+ .message(exception.getMessage())
+ .timestamp(Instant.now().toEpochMilli())
+ .build());
+ aiCallLogService.failCallLog(
+ callLogId,
+ "FAILED",
+ exception.getCategory().name(),
+ exception.getStage(),
+ exception.getMessage(),
+ System.currentTimeMillis() - startedAt,
+ aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAt),
+ toolSuccessCount,
+ toolFailureCount
+ );
+ aiStreamStateStore.markStreamState(requestId, tenantId, userId, sessionId, promptCode, "FAILED", exception.getMessage());
+ emitter.completeWithError(exception);
+ }
+
+ private boolean isClientAbortException(Throwable throwable) {
+ Throwable current = throwable;
+ while (current != null) {
+ String message = current.getMessage();
+ if (message != null) {
+ String normalized = message.toLowerCase();
+ if (normalized.contains("broken pipe")
+ || normalized.contains("connection reset")
+ || normalized.contains("forcibly closed")
+ || normalized.contains("abort")) {
+ return true;
+ }
+ }
+ current = current.getCause();
+ }
+ return false;
+ }
+}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiChatOrchestrator.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiChatOrchestrator.java
new file mode 100644
index 0000000..055192b
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiChatOrchestrator.java
@@ -0,0 +1,291 @@
+package com.vincent.rsf.server.ai.service.impl.chat;
+
+import com.vincent.rsf.server.ai.dto.AiChatMemoryDto;
+import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
+import com.vincent.rsf.server.ai.dto.AiChatModelOptionDto;
+import com.vincent.rsf.server.ai.dto.AiChatRequest;
+import com.vincent.rsf.server.ai.dto.AiChatStatusDto;
+import com.vincent.rsf.server.ai.dto.AiResolvedConfig;
+import com.vincent.rsf.server.ai.entity.AiCallLog;
+import com.vincent.rsf.server.ai.entity.AiChatSession;
+import com.vincent.rsf.server.ai.enums.AiErrorCategory;
+import com.vincent.rsf.server.ai.exception.AiChatException;
+import com.vincent.rsf.server.ai.service.AiCallLogService;
+import com.vincent.rsf.server.ai.service.AiChatMemoryService;
+import com.vincent.rsf.server.ai.service.AiConfigResolverService;
+import com.vincent.rsf.server.ai.service.AiParamService;
+import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
+import com.vincent.rsf.server.ai.store.AiChatRateLimiter;
+import com.vincent.rsf.server.ai.store.AiStreamStateStore;
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.ai.chat.metadata.ChatResponseMetadata;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.openai.OpenAiChatModel;
+import org.springframework.ai.tool.ToolCallback;
+import org.springframework.stereotype.Component;
+import org.springframework.util.StringUtils;
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
+import reactor.core.publisher.Flux;
+
+import java.time.Instant;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
+
+@Slf4j
+@Component
+@RequiredArgsConstructor
+public class AiChatOrchestrator {
+
+ private final AiConfigResolverService aiConfigResolverService;
+ private final AiChatMemoryService aiChatMemoryService;
+ private final AiParamService aiParamService;
+ private final McpMountRuntimeFactory mcpMountRuntimeFactory;
+ private final AiCallLogService aiCallLogService;
+ private final AiChatRateLimiter aiChatRateLimiter;
+ private final AiStreamStateStore aiStreamStateStore;
+ private final AiChatRuntimeAssembler aiChatRuntimeAssembler;
+ private final AiPromptMessageBuilder aiPromptMessageBuilder;
+ private final AiOpenAiChatModelFactory aiOpenAiChatModelFactory;
+ private final AiToolObservationService aiToolObservationService;
+ private final AiSseEventPublisher aiSseEventPublisher;
+ private final AiChatFailureHandler aiChatFailureHandler;
+
+ public void executeStream(AiChatRequest request, Long userId, Long tenantId, SseEmitter emitter) {
+ String requestId = request.getRequestId();
+ long startedAt = System.currentTimeMillis();
+ AtomicReference<Long> firstTokenAtRef = new AtomicReference<>();
+ AtomicLong toolCallSequence = new AtomicLong(0);
+ AtomicLong toolSuccessCount = new AtomicLong(0);
+ AtomicLong toolFailureCount = new AtomicLong(0);
+ Long sessionId = request.getSessionId();
+ Long callLogId = null;
+ String model = null;
+ String resolvedPromptCode = request.getPromptCode();
+ AiThinkingTraceEmitter thinkingTraceEmitter = null;
+ try {
+ ensureIdentity(userId, tenantId);
+ AiResolvedConfig config = resolveConfig(request, tenantId);
+ List<AiChatModelOptionDto> modelOptions = aiParamService.listChatModelOptions(tenantId);
+ resolvedPromptCode = config.getPromptCode();
+ if (!aiChatRateLimiter.allowChatRequest(tenantId, userId, config.getPromptCode())) {
+ throw aiChatFailureHandler.buildAiException("AI_RATE_LIMITED", AiErrorCategory.REQUEST, "RATE_LIMIT",
+ "褰撳墠鎻愰棶杩囦簬棰戠箒锛岃绋嶅悗鍐嶈瘯", null);
+ }
+ final String resolvedModel = config.getAiParam().getModel();
+ model = resolvedModel;
+ AiChatSession session = resolveSession(request, userId, tenantId, config.getPromptCode());
+ sessionId = session.getId();
+ aiStreamStateStore.markStreamState(requestId, tenantId, userId, sessionId, config.getPromptCode(), "RUNNING", null);
+ AiChatMemoryDto memory = loadMemory(userId, tenantId, config.getPromptCode(), session.getId());
+ List<AiChatMessageDto> mergedMessages = aiPromptMessageBuilder.mergeMessages(memory.getShortMemoryMessages(), request.getMessages());
+ AiCallLog callLog = aiCallLogService.startCallLog(
+ requestId,
+ session.getId(),
+ userId,
+ tenantId,
+ config.getPromptCode(),
+ config.getPrompt().getName(),
+ config.getAiParam().getModel(),
+ config.getMcpMounts().size(),
+ config.getMcpMounts().size(),
+ config.getMcpMounts().stream().map(item -> item.getName()).toList()
+ );
+ callLogId = callLog.getId();
+ try (McpMountRuntimeFactory.McpMountRuntime runtime = createRuntime(config, userId)) {
+ aiSseEventPublisher.emitStrict(emitter, "start", aiChatRuntimeAssembler.buildRuntimeSnapshot(
+ requestId,
+ session.getId(),
+ config,
+ modelOptions,
+ runtime.getMountedCount(),
+ runtime.getMountedNames(),
+ runtime.getErrors(),
+ memory
+ ));
+ aiSseEventPublisher.emitSafely(emitter, "status", AiChatStatusDto.builder()
+ .requestId(requestId)
+ .sessionId(session.getId())
+ .status("STARTED")
+ .model(resolvedModel)
+ .timestamp(Instant.now().toEpochMilli())
+ .elapsedMs(0L)
+ .build());
+ log.info("AI chat started, requestId={}, userId={}, tenantId={}, sessionId={}, model={}",
+ requestId, userId, tenantId, session.getId(), resolvedModel);
+ thinkingTraceEmitter = new AiThinkingTraceEmitter(aiSseEventPublisher, emitter, requestId, session.getId());
+ thinkingTraceEmitter.startAnalyze();
+ AiThinkingTraceEmitter activeThinkingTraceEmitter = thinkingTraceEmitter;
+
+ ToolCallback[] observableToolCallbacks = aiToolObservationService.wrapToolCallbacks(
+ runtime.getToolCallbacks(), emitter, requestId, session.getId(), toolCallSequence,
+ toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, activeThinkingTraceEmitter
+ );
+ Prompt prompt = new Prompt(
+ aiPromptMessageBuilder.buildPromptMessages(memory, mergedMessages, config.getPrompt(), request.getMetadata()),
+ aiOpenAiChatModelFactory.buildChatOptions(config.getAiParam(), observableToolCallbacks, userId, tenantId,
+ requestId, session.getId(), request.getMetadata())
+ );
+ OpenAiChatModel chatModel = aiOpenAiChatModelFactory.createChatModel(config.getAiParam());
+ if (Boolean.FALSE.equals(config.getAiParam().getStreamingEnabled())) {
+ ChatResponse response = invokeChatCall(chatModel, prompt);
+ String content = extractContent(response);
+ aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), content);
+ if (StringUtils.hasText(content)) {
+ aiSseEventPublisher.markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt, activeThinkingTraceEmitter);
+ aiSseEventPublisher.emitStrict(emitter, "delta", aiSseEventPublisher.buildMessagePayload("requestId", requestId, "content", content));
+ }
+ activeThinkingTraceEmitter.completeCurrentPhase();
+ aiSseEventPublisher.emitDone(emitter, requestId, response.getMetadata(), config.getAiParam().getModel(),
+ session.getId(), startedAt, firstTokenAtRef.get());
+ aiSseEventPublisher.emitSafely(emitter, "status",
+ aiSseEventPublisher.buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get()));
+ aiCallLogService.completeCallLog(
+ callLogId,
+ "COMPLETED",
+ System.currentTimeMillis() - startedAt,
+ aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()),
+ response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getPromptTokens(),
+ response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getCompletionTokens(),
+ response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getTotalTokens(),
+ toolSuccessCount.get(),
+ toolFailureCount.get()
+ );
+ aiStreamStateStore.markStreamState(requestId, tenantId, userId, session.getId(), config.getPromptCode(), "COMPLETED", null);
+ log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}",
+ requestId, session.getId(), System.currentTimeMillis() - startedAt,
+ aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()));
+ emitter.complete();
+ return;
+ }
+
+ Flux<ChatResponse> responseFlux = invokeChatStream(chatModel, prompt);
+ AtomicReference<ChatResponseMetadata> lastMetadata = new AtomicReference<>();
+ StringBuilder assistantContent = new StringBuilder();
+ try {
+ responseFlux.doOnNext(response -> {
+ lastMetadata.set(response.getMetadata());
+ String content = extractContent(response);
+ if (StringUtils.hasText(content)) {
+ aiSseEventPublisher.markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt, activeThinkingTraceEmitter);
+ assistantContent.append(content);
+ aiSseEventPublisher.emitStrict(emitter, "delta",
+ aiSseEventPublisher.buildMessagePayload("requestId", requestId, "content", content));
+ }
+ })
+ .blockLast();
+ } catch (Exception e) {
+ throw aiChatFailureHandler.buildAiException("AI_MODEL_STREAM_ERROR", AiErrorCategory.MODEL, "MODEL_STREAM",
+ e == null ? "AI 妯″瀷娴佸紡璋冪敤澶辫触" : e.getMessage(), e);
+ }
+ aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), assistantContent.toString());
+ activeThinkingTraceEmitter.completeCurrentPhase();
+ aiSseEventPublisher.emitDone(emitter, requestId, lastMetadata.get(), config.getAiParam().getModel(),
+ session.getId(), startedAt, firstTokenAtRef.get());
+ aiSseEventPublisher.emitSafely(emitter, "status",
+ aiSseEventPublisher.buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get()));
+ aiCallLogService.completeCallLog(
+ callLogId,
+ "COMPLETED",
+ System.currentTimeMillis() - startedAt,
+ aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()),
+ lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getPromptTokens(),
+ lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getCompletionTokens(),
+ lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getTotalTokens(),
+ toolSuccessCount.get(),
+ toolFailureCount.get()
+ );
+ aiStreamStateStore.markStreamState(requestId, tenantId, userId, session.getId(), config.getPromptCode(), "COMPLETED", null);
+ log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}",
+ requestId, session.getId(), System.currentTimeMillis() - startedAt,
+ aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()));
+ emitter.complete();
+ }
+ } catch (AiChatException e) {
+ aiChatFailureHandler.handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e,
+ callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter,
+ tenantId, userId, resolvedPromptCode);
+ } catch (Exception e) {
+ aiChatFailureHandler.handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(),
+ aiChatFailureHandler.buildAiException("AI_INTERNAL_ERROR", AiErrorCategory.INTERNAL, "INTERNAL",
+ e == null ? "AI 瀵硅瘽澶辫触" : e.getMessage(), e),
+ callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter,
+ tenantId, userId, resolvedPromptCode);
+ } finally {
+ log.debug("AI chat stream finished, requestId={}", requestId);
+ }
+ }
+
+ private void ensureIdentity(Long userId, Long tenantId) {
+ if (userId == null) {
+ throw aiChatFailureHandler.buildAiException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "AUTH_VALIDATE", "褰撳墠鐧诲綍鐢ㄦ埛涓嶅瓨鍦�", null);
+ }
+ if (tenantId == null) {
+ throw aiChatFailureHandler.buildAiException("AI_AUTH_TENANT_MISSING", AiErrorCategory.AUTH, "AUTH_VALIDATE", "褰撳墠绉熸埛涓嶅瓨鍦�", null);
+ }
+ }
+
+ private AiResolvedConfig resolveConfig(AiChatRequest request, Long tenantId) {
+ try {
+ return aiConfigResolverService.resolve(request.getPromptCode(), tenantId, request.getAiParamId());
+ } catch (Exception e) {
+ throw aiChatFailureHandler.buildAiException("AI_CONFIG_RESOLVE_ERROR", AiErrorCategory.CONFIG, "CONFIG_RESOLVE",
+ e == null ? "AI 閰嶇疆瑙f瀽澶辫触" : e.getMessage(), e);
+ }
+ }
+
+ private AiChatSession resolveSession(AiChatRequest request, Long userId, Long tenantId, String promptCode) {
+ try {
+ return aiChatMemoryService.resolveSession(userId, tenantId, promptCode, request.getSessionId(),
+ aiPromptMessageBuilder.resolveTitleSeed(request.getMessages()));
+ } catch (Exception e) {
+ throw aiChatFailureHandler.buildAiException("AI_SESSION_RESOLVE_ERROR", AiErrorCategory.REQUEST, "SESSION_RESOLVE",
+ e == null ? "AI 浼氳瘽瑙f瀽澶辫触" : e.getMessage(), e);
+ }
+ }
+
+ private AiChatMemoryDto loadMemory(Long userId, Long tenantId, String promptCode, Long sessionId) {
+ try {
+ return aiChatMemoryService.getMemory(userId, tenantId, promptCode, sessionId);
+ } catch (Exception e) {
+ throw aiChatFailureHandler.buildAiException("AI_MEMORY_LOAD_ERROR", AiErrorCategory.REQUEST, "MEMORY_LOAD",
+ e == null ? "AI 浼氳瘽璁板繂鍔犺浇澶辫触" : e.getMessage(), e);
+ }
+ }
+
+ private McpMountRuntimeFactory.McpMountRuntime createRuntime(AiResolvedConfig config, Long userId) {
+ try {
+ return mcpMountRuntimeFactory.create(config.getMcpMounts(), userId);
+ } catch (Exception e) {
+ throw aiChatFailureHandler.buildAiException("AI_MCP_MOUNT_ERROR", AiErrorCategory.MCP, "MCP_MOUNT",
+ e == null ? "MCP 鎸傝浇澶辫触" : e.getMessage(), e);
+ }
+ }
+
+ private ChatResponse invokeChatCall(OpenAiChatModel chatModel, Prompt prompt) {
+ try {
+ return chatModel.call(prompt);
+ } catch (Exception e) {
+ throw aiChatFailureHandler.buildAiException("AI_MODEL_CALL_ERROR", AiErrorCategory.MODEL, "MODEL_CALL",
+ e == null ? "AI 妯″瀷璋冪敤澶辫触" : e.getMessage(), e);
+ }
+ }
+
+ private Flux<ChatResponse> invokeChatStream(OpenAiChatModel chatModel, Prompt prompt) {
+ try {
+ return chatModel.stream(prompt);
+ } catch (Exception e) {
+ throw aiChatFailureHandler.buildAiException("AI_MODEL_STREAM_ERROR", AiErrorCategory.MODEL, "MODEL_STREAM_INIT",
+ e == null ? "AI 妯″瀷娴佸紡璋冪敤澶辫触" : e.getMessage(), e);
+ }
+ }
+
+ private String extractContent(ChatResponse response) {
+ if (response == null || response.getResult() == null || response.getResult().getOutput() == null) {
+ return null;
+ }
+ return response.getResult().getOutput().getText();
+ }
+}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiChatRuntimeAssembler.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiChatRuntimeAssembler.java
new file mode 100644
index 0000000..6a673e4
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiChatRuntimeAssembler.java
@@ -0,0 +1,36 @@
+package com.vincent.rsf.server.ai.service.impl.chat;
+
+import com.vincent.rsf.server.ai.dto.AiChatMemoryDto;
+import com.vincent.rsf.server.ai.dto.AiChatModelOptionDto;
+import com.vincent.rsf.server.ai.dto.AiChatRuntimeDto;
+import com.vincent.rsf.server.ai.dto.AiResolvedConfig;
+import org.springframework.stereotype.Component;
+
+import java.util.List;
+
+@Component
+public class AiChatRuntimeAssembler {
+
+ public AiChatRuntimeDto buildRuntimeSnapshot(String requestId, Long sessionId, AiResolvedConfig config,
+ List<AiChatModelOptionDto> modelOptions, Integer mountedMcpCount,
+ List<String> mountedMcpNames, List<String> mountErrors,
+ AiChatMemoryDto memory) {
+ return AiChatRuntimeDto.builder()
+ .requestId(requestId)
+ .sessionId(sessionId)
+ .aiParamId(config.getAiParam().getId())
+ .promptCode(config.getPromptCode())
+ .promptName(config.getPrompt().getName())
+ .model(config.getAiParam().getModel())
+ .modelOptions(modelOptions)
+ .configuredMcpCount(config.getMcpMounts().size())
+ .mountedMcpCount(mountedMcpCount)
+ .mountedMcpNames(mountedMcpNames)
+ .mountErrors(mountErrors)
+ .memorySummary(memory.getMemorySummary())
+ .memoryFacts(memory.getMemoryFacts())
+ .recentMessageCount(memory.getRecentMessageCount())
+ .persistedMessages(memory.getPersistedMessages())
+ .build();
+ }
+}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiOpenAiChatModelFactory.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiOpenAiChatModelFactory.java
new file mode 100644
index 0000000..ed76691
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiOpenAiChatModelFactory.java
@@ -0,0 +1,89 @@
+package com.vincent.rsf.server.ai.service.impl.chat;
+
+import com.vincent.rsf.framework.exception.CoolException;
+import com.vincent.rsf.server.ai.entity.AiParam;
+import com.vincent.rsf.server.ai.enums.AiErrorCategory;
+import com.vincent.rsf.server.ai.exception.AiChatException;
+import com.vincent.rsf.server.ai.service.impl.AiOpenAiApiSupport;
+import io.micrometer.observation.ObservationRegistry;
+import lombok.RequiredArgsConstructor;
+import org.springframework.ai.model.tool.DefaultToolCallingManager;
+import org.springframework.ai.model.tool.ToolCallingManager;
+import org.springframework.ai.openai.OpenAiChatModel;
+import org.springframework.ai.openai.OpenAiChatOptions;
+import org.springframework.ai.openai.api.OpenAiApi;
+import org.springframework.ai.tool.ToolCallback;
+import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
+import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver;
+import org.springframework.ai.util.json.schema.SchemaType;
+import org.springframework.context.support.GenericApplicationContext;
+import org.springframework.stereotype.Component;
+
+import java.util.Arrays;
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+@Component
+@RequiredArgsConstructor
+public class AiOpenAiChatModelFactory {
+
+ private final GenericApplicationContext applicationContext;
+ private final ObservationRegistry observationRegistry;
+
+ public OpenAiChatModel createChatModel(AiParam aiParam) {
+ OpenAiApi openAiApi = AiOpenAiApiSupport.buildOpenAiApi(aiParam);
+ ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder()
+ .observationRegistry(observationRegistry)
+ .toolCallbackResolver(new SpringBeanToolCallbackResolver(applicationContext, SchemaType.OPEN_API_SCHEMA))
+ .toolExecutionExceptionProcessor(new DefaultToolExecutionExceptionProcessor(false))
+ .build();
+ return new OpenAiChatModel(
+ openAiApi,
+ OpenAiChatOptions.builder()
+ .model(aiParam.getModel())
+ .temperature(aiParam.getTemperature())
+ .topP(aiParam.getTopP())
+ .maxTokens(aiParam.getMaxTokens())
+ .streamUsage(true)
+ .build(),
+ toolCallingManager,
+ org.springframework.retry.support.RetryTemplate.builder().maxAttempts(1).build(),
+ observationRegistry
+ );
+ }
+
+ public OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Long tenantId,
+ String requestId, Long sessionId, Map<String, Object> metadata) {
+ if (userId == null) {
+ throw new AiChatException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "OPTIONS_BUILD", "褰撳墠鐧诲綍鐢ㄦ埛涓嶅瓨鍦�", null);
+ }
+ OpenAiChatOptions.Builder builder = OpenAiChatOptions.builder()
+ .model(aiParam.getModel())
+ .temperature(aiParam.getTemperature())
+ .topP(aiParam.getTopP())
+ .maxTokens(aiParam.getMaxTokens())
+ .streamUsage(true)
+ .user(String.valueOf(userId));
+ if (toolCallbacks != null && toolCallbacks.length > 0) {
+ builder.toolCallbacks(Arrays.stream(toolCallbacks).toList());
+ }
+ Map<String, Object> toolContext = new LinkedHashMap<>();
+ toolContext.put("userId", userId);
+ toolContext.put("tenantId", tenantId);
+ toolContext.put("requestId", requestId);
+ toolContext.put("sessionId", sessionId);
+ Map<String, String> metadataMap = new LinkedHashMap<>();
+ if (metadata != null) {
+ metadata.forEach((key, value) -> {
+ String normalized = value == null ? "" : String.valueOf(value);
+ metadataMap.put(key, normalized);
+ toolContext.put(key, normalized);
+ });
+ }
+ builder.toolContext(toolContext);
+ if (!metadataMap.isEmpty()) {
+ builder.metadata(metadataMap);
+ }
+ return builder.build();
+ }
+}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiPromptMessageBuilder.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiPromptMessageBuilder.java
new file mode 100644
index 0000000..0bf45b3
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiPromptMessageBuilder.java
@@ -0,0 +1,116 @@
+package com.vincent.rsf.server.ai.service.impl.chat;
+
+import com.vincent.rsf.framework.common.Cools;
+import com.vincent.rsf.framework.exception.CoolException;
+import com.vincent.rsf.server.ai.dto.AiChatMemoryDto;
+import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
+import com.vincent.rsf.server.ai.entity.AiPrompt;
+import org.springframework.ai.chat.messages.AssistantMessage;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.SystemMessage;
+import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.stereotype.Component;
+import org.springframework.util.StringUtils;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+@Component
+public class AiPromptMessageBuilder {
+
+ public List<Message> buildPromptMessages(AiChatMemoryDto memory, List<AiChatMessageDto> sourceMessages, AiPrompt aiPrompt,
+ Map<String, Object> metadata) {
+ if (Cools.isEmpty(sourceMessages)) {
+ throw new CoolException("瀵硅瘽娑堟伅涓嶈兘涓虹┖");
+ }
+ List<Message> messages = new ArrayList<>();
+ if (StringUtils.hasText(aiPrompt.getSystemPrompt())) {
+ messages.add(new SystemMessage(aiPrompt.getSystemPrompt()));
+ }
+ if (memory != null && StringUtils.hasText(memory.getMemorySummary())) {
+ messages.add(new SystemMessage("鍘嗗彶鎽樿:\n" + memory.getMemorySummary()));
+ }
+ if (memory != null && StringUtils.hasText(memory.getMemoryFacts())) {
+ messages.add(new SystemMessage("鍏抽敭浜嬪疄:\n" + memory.getMemoryFacts()));
+ }
+ int lastUserIndex = -1;
+ for (int i = 0; i < sourceMessages.size(); i++) {
+ AiChatMessageDto item = sourceMessages.get(i);
+ if (item != null && "user".equalsIgnoreCase(item.getRole())) {
+ lastUserIndex = i;
+ }
+ }
+ for (int i = 0; i < sourceMessages.size(); i++) {
+ AiChatMessageDto item = sourceMessages.get(i);
+ if (item == null || !StringUtils.hasText(item.getContent())) {
+ continue;
+ }
+ String role = item.getRole() == null ? "user" : item.getRole().toLowerCase();
+ if ("system".equals(role)) {
+ continue;
+ }
+ String content = item.getContent();
+ if ("user".equals(role) && i == lastUserIndex) {
+ content = renderUserPrompt(aiPrompt.getUserPromptTemplate(), content, metadata);
+ }
+ if ("assistant".equals(role)) {
+ messages.add(new AssistantMessage(content));
+ } else {
+ messages.add(new UserMessage(content));
+ }
+ }
+ if (messages.stream().noneMatch(item -> item instanceof UserMessage)) {
+ throw new CoolException("鑷冲皯闇�瑕佷竴鏉$敤鎴锋秷鎭�");
+ }
+ return messages;
+ }
+
+ public List<AiChatMessageDto> mergeMessages(List<AiChatMessageDto> persistedMessages, List<AiChatMessageDto> memoryMessages) {
+ List<AiChatMessageDto> merged = new ArrayList<>();
+ if (!Cools.isEmpty(persistedMessages)) {
+ merged.addAll(persistedMessages);
+ }
+ if (!Cools.isEmpty(memoryMessages)) {
+ merged.addAll(memoryMessages);
+ }
+ if (merged.isEmpty()) {
+ throw new CoolException("瀵硅瘽娑堟伅涓嶈兘涓虹┖");
+ }
+ return merged;
+ }
+
+ public String resolveTitleSeed(List<AiChatMessageDto> messages) {
+ if (Cools.isEmpty(messages)) {
+ throw new CoolException("瀵硅瘽娑堟伅涓嶈兘涓虹┖");
+ }
+ for (int i = messages.size() - 1; i >= 0; i--) {
+ AiChatMessageDto item = messages.get(i);
+ if (item != null && "user".equalsIgnoreCase(item.getRole()) && StringUtils.hasText(item.getContent())) {
+ return item.getContent();
+ }
+ }
+ throw new CoolException("鑷冲皯闇�瑕佷竴鏉$敤鎴锋秷鎭�");
+ }
+
+ private String renderUserPrompt(String userPromptTemplate, String content, Map<String, Object> metadata) {
+ if (!StringUtils.hasText(userPromptTemplate)) {
+ return content;
+ }
+ String rendered = userPromptTemplate
+ .replace("{{input}}", content)
+ .replace("{input}", content);
+ if (metadata != null) {
+ for (Map.Entry<String, Object> entry : metadata.entrySet()) {
+ String value = entry.getValue() == null ? "" : String.valueOf(entry.getValue());
+ rendered = rendered.replace("{{" + entry.getKey() + "}}", value);
+ rendered = rendered.replace("{" + entry.getKey() + "}", value);
+ }
+ }
+ if (Objects.equals(rendered, userPromptTemplate)) {
+ return userPromptTemplate + "\n\n" + content;
+ }
+ return rendered;
+ }
+}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiSseEventPublisher.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiSseEventPublisher.java
new file mode 100644
index 0000000..aace1a3
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiSseEventPublisher.java
@@ -0,0 +1,112 @@
+package com.vincent.rsf.server.ai.service.impl.chat;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.vincent.rsf.framework.exception.CoolException;
+import com.vincent.rsf.server.ai.dto.AiChatDoneDto;
+import com.vincent.rsf.server.ai.dto.AiChatStatusDto;
+import com.vincent.rsf.server.ai.enums.AiErrorCategory;
+import com.vincent.rsf.server.ai.exception.AiChatException;
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.ai.chat.metadata.ChatResponseMetadata;
+import org.springframework.ai.chat.metadata.Usage;
+import org.springframework.http.MediaType;
+import org.springframework.stereotype.Component;
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
+
+import java.io.IOException;
+import java.time.Instant;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicReference;
+
+@Slf4j
+@Component
+@RequiredArgsConstructor
+public class AiSseEventPublisher {
+
+ private final ObjectMapper objectMapper;
+
+ public void markFirstToken(AtomicReference<Long> firstTokenAtRef, SseEmitter emitter, String requestId,
+ Long sessionId, String model, long startedAt, AiThinkingTraceEmitter thinkingTraceEmitter) {
+ if (!firstTokenAtRef.compareAndSet(null, System.currentTimeMillis())) {
+ return;
+ }
+ if (thinkingTraceEmitter != null) {
+ thinkingTraceEmitter.startAnswer();
+ }
+ emitSafely(emitter, "status", AiChatStatusDto.builder()
+ .requestId(requestId)
+ .sessionId(sessionId)
+ .status("FIRST_TOKEN")
+ .model(model)
+ .timestamp(Instant.now().toEpochMilli())
+ .elapsedMs(System.currentTimeMillis() - startedAt)
+ .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()))
+ .build());
+ }
+
+ public void emitDone(SseEmitter emitter, String requestId, ChatResponseMetadata metadata, String fallbackModel,
+ Long sessionId, long startedAt, Long firstTokenAt) {
+ Usage usage = metadata == null ? null : metadata.getUsage();
+ emitStrict(emitter, "done", AiChatDoneDto.builder()
+ .requestId(requestId)
+ .sessionId(sessionId)
+ .model(metadata != null && metadata.getModel() != null && !metadata.getModel().isBlank() ? metadata.getModel() : fallbackModel)
+ .elapsedMs(System.currentTimeMillis() - startedAt)
+ .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAt))
+ .promptTokens(usage == null ? null : usage.getPromptTokens())
+ .completionTokens(usage == null ? null : usage.getCompletionTokens())
+ .totalTokens(usage == null ? null : usage.getTotalTokens())
+ .build());
+ }
+
+ public AiChatStatusDto buildTerminalStatus(String requestId, Long sessionId, String status, String model, long startedAt, Long firstTokenAt) {
+ return AiChatStatusDto.builder()
+ .requestId(requestId)
+ .sessionId(sessionId)
+ .status(status)
+ .model(model)
+ .timestamp(Instant.now().toEpochMilli())
+ .elapsedMs(System.currentTimeMillis() - startedAt)
+ .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAt))
+ .build();
+ }
+
+ public Long resolveFirstTokenLatency(long startedAt, Long firstTokenAt) {
+ return firstTokenAt == null ? null : Math.max(0L, firstTokenAt - startedAt);
+ }
+
+ public Map<String, String> buildMessagePayload(String... keyValues) {
+ Map<String, String> payload = new LinkedHashMap<>();
+ if (keyValues == null || keyValues.length == 0) {
+ return payload;
+ }
+ if (keyValues.length % 2 != 0) {
+ throw new CoolException("娑堟伅杞借嵎鍙傛暟蹇呴』鎴愬鍑虹幇");
+ }
+ for (int i = 0; i < keyValues.length; i += 2) {
+ payload.put(keyValues[i], keyValues[i + 1] == null ? "" : keyValues[i + 1]);
+ }
+ return payload;
+ }
+
+ public void emitStrict(SseEmitter emitter, String eventName, Object payload) {
+ try {
+ String data = objectMapper.writeValueAsString(payload);
+ emitter.send(SseEmitter.event()
+ .name(eventName)
+ .data(data, MediaType.APPLICATION_JSON));
+ } catch (IOException e) {
+ throw new AiChatException("AI_SSE_EMIT_ERROR", AiErrorCategory.STREAM, "SSE_EMIT", "SSE 杈撳嚭澶辫触: " + e.getMessage(), e);
+ }
+ }
+
+ public void emitSafely(SseEmitter emitter, String eventName, Object payload) {
+ try {
+ emitStrict(emitter, eventName, payload);
+ } catch (Exception e) {
+ log.warn("AI SSE event emit skipped, eventName={}, message={}", eventName, e.getMessage());
+ }
+ }
+}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiThinkingTraceEmitter.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiThinkingTraceEmitter.java
new file mode 100644
index 0000000..206ad51
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiThinkingTraceEmitter.java
@@ -0,0 +1,129 @@
+package com.vincent.rsf.server.ai.service.impl.chat;
+
+import com.vincent.rsf.server.ai.dto.AiChatThinkingEventDto;
+import org.springframework.util.StringUtils;
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
+
+import java.time.Instant;
+import java.util.Objects;
+
+public class AiThinkingTraceEmitter {
+
+ private final AiSseEventPublisher aiSseEventPublisher;
+ private final SseEmitter emitter;
+ private final String requestId;
+ private final Long sessionId;
+ private String currentPhase;
+ private String currentStatus;
+
+ public AiThinkingTraceEmitter(AiSseEventPublisher aiSseEventPublisher, SseEmitter emitter, String requestId, Long sessionId) {
+ this.aiSseEventPublisher = aiSseEventPublisher;
+ this.emitter = emitter;
+ this.requestId = requestId;
+ this.sessionId = sessionId;
+ }
+
+ public void startAnalyze() {
+ if (currentPhase != null) {
+ return;
+ }
+ currentPhase = "ANALYZE";
+ currentStatus = "STARTED";
+ emitThinkingEvent("ANALYZE", "STARTED", "姝e湪鍒嗘瀽闂",
+ "宸叉帴鏀朵綘鐨勯棶棰橈紝姝e湪鐞嗚В鎰忓浘骞跺垽鏂槸鍚﹂渶瑕佽皟鐢ㄥ伐鍏枫��", null);
+ }
+
+ public void onToolStart(String toolName, String toolCallId) {
+ switchPhase("TOOL_CALL", "STARTED", "姝e湪璋冪敤宸ュ叿", "宸插垽鏂渶瑕佽皟鐢ㄥ伐鍏凤紝姝e湪鏌ヨ鐩稿叧淇℃伅銆�", null);
+ currentStatus = "UPDATED";
+ emitThinkingEvent("TOOL_CALL", "UPDATED", "姝e湪璋冪敤宸ュ叿",
+ "姝e湪璋冪敤宸ュ叿 " + safeLabel(toolName, "鏈煡宸ュ叿") + " 鑾峰彇鎵�闇�淇℃伅銆�", toolCallId);
+ }
+
+ public void onToolResult(String toolName, String toolCallId, boolean failed) {
+ currentPhase = "TOOL_CALL";
+ currentStatus = failed ? "FAILED" : "UPDATED";
+ emitThinkingEvent("TOOL_CALL", failed ? "FAILED" : "UPDATED",
+ failed ? "宸ュ叿璋冪敤澶辫触" : "宸ュ叿璋冪敤瀹屾垚",
+ failed
+ ? "宸ュ叿 " + safeLabel(toolName, "鏈煡宸ュ叿") + " 璋冪敤澶辫触锛屾鍦ㄨ瘎浼板け璐ュ奖鍝嶅苟鏁寸悊鍙敤淇℃伅銆�"
+ : "宸ュ叿 " + safeLabel(toolName, "鏈煡宸ュ叿") + " 宸茶繑鍥炵粨鏋滐紝姝e湪缁х画鍒嗘瀽骞舵彁鐐煎叧閿俊鎭��",
+ toolCallId);
+ }
+
+ public void startAnswer() {
+ switchPhase("ANSWER", "STARTED", "姝e湪鏁寸悊绛旀", "宸插畬鎴愬垎鏋愶紝姝e湪缁勭粐鏈�缁堝洖澶嶅唴瀹广��", null);
+ }
+
+ public void completeCurrentPhase() {
+ if (!StringUtils.hasText(currentPhase) || isTerminalStatus(currentStatus)) {
+ return;
+ }
+ currentStatus = "COMPLETED";
+ emitThinkingEvent(currentPhase, "COMPLETED", resolveCompleteTitle(currentPhase),
+ resolveCompleteContent(currentPhase), null);
+ }
+
+ public void markTerminated(String terminalStatus) {
+ if (!StringUtils.hasText(currentPhase) || isTerminalStatus(currentStatus)) {
+ return;
+ }
+ currentStatus = terminalStatus;
+ emitThinkingEvent(currentPhase, terminalStatus,
+ "ABORTED".equals(terminalStatus) ? "鎬濊�冨凡涓" : "鎬濊�冨け璐�",
+ "ABORTED".equals(terminalStatus)
+ ? "鏈疆瀵硅瘽宸茶涓锛屾�濊�冭繃绋嬫彁鍓嶇粨鏉熴��"
+ : "鏈疆瀵硅瘽鍦ㄧ敓鎴愮瓟妗堝墠澶辫触锛屽綋鍓嶆�濊�冭繃绋嬪凡鍋滄銆�",
+ null);
+ }
+
+ private void switchPhase(String nextPhase, String nextStatus, String title, String content, String toolCallId) {
+ if (!Objects.equals(currentPhase, nextPhase)) {
+ completeCurrentPhase();
+ }
+ currentPhase = nextPhase;
+ currentStatus = nextStatus;
+ emitThinkingEvent(nextPhase, nextStatus, title, content, toolCallId);
+ }
+
+ private void emitThinkingEvent(String phase, String status, String title, String content, String toolCallId) {
+ aiSseEventPublisher.emitSafely(emitter, "thinking", AiChatThinkingEventDto.builder()
+ .requestId(requestId)
+ .sessionId(sessionId)
+ .phase(phase)
+ .status(status)
+ .title(title)
+ .content(content)
+ .toolCallId(toolCallId)
+ .timestamp(Instant.now().toEpochMilli())
+ .build());
+ }
+
+ private boolean isTerminalStatus(String status) {
+ return "COMPLETED".equals(status) || "FAILED".equals(status) || "ABORTED".equals(status);
+ }
+
+ private String resolveCompleteTitle(String phase) {
+ if ("ANSWER".equals(phase)) {
+ return "绛旀鏁寸悊瀹屾垚";
+ }
+ if ("TOOL_CALL".equals(phase)) {
+ return "宸ュ叿鍒嗘瀽瀹屾垚";
+ }
+ return "闂鍒嗘瀽瀹屾垚";
+ }
+
+ private String resolveCompleteContent(String phase) {
+ if ("ANSWER".equals(phase)) {
+ return "鏈�缁堢瓟澶嶅凡鐢熸垚瀹屾垚銆�";
+ }
+ if ("TOOL_CALL".equals(phase)) {
+ return "宸ュ叿璋冪敤闃舵宸茬粨鏉燂紝鐩稿叧淇℃伅宸叉暣鐞嗗畬姣曘��";
+ }
+ return "闂鎰忓浘鍜屽鐞嗘柟鍚戝凡鍒嗘瀽瀹屾垚銆�";
+ }
+
+ private String safeLabel(String value, String fallback) {
+ return StringUtils.hasText(value) ? value : fallback;
+ }
+}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiToolObservationService.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiToolObservationService.java
new file mode 100644
index 0000000..82d2019
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/chat/AiToolObservationService.java
@@ -0,0 +1,205 @@
+package com.vincent.rsf.server.ai.service.impl.chat;
+
+import com.vincent.rsf.framework.exception.CoolException;
+import com.vincent.rsf.server.ai.dto.AiChatToolEventDto;
+import com.vincent.rsf.server.ai.service.AiCallLogService;
+import com.vincent.rsf.server.ai.service.MountedToolCallback;
+import com.vincent.rsf.server.ai.store.AiCachedToolResult;
+import com.vincent.rsf.server.ai.store.AiToolResultStore;
+import lombok.RequiredArgsConstructor;
+import org.springframework.ai.chat.model.ToolContext;
+import org.springframework.ai.tool.ToolCallback;
+import org.springframework.stereotype.Component;
+import org.springframework.util.StringUtils;
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicLong;
+
+@Component
+@RequiredArgsConstructor
+public class AiToolObservationService {
+
+ private final AiSseEventPublisher aiSseEventPublisher;
+ private final AiToolResultStore aiToolResultStore;
+ private final AiCallLogService aiCallLogService;
+
+ public ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, SseEmitter emitter, String requestId,
+ Long sessionId, AtomicLong toolCallSequence,
+ AtomicLong toolSuccessCount, AtomicLong toolFailureCount,
+ Long callLogId, Long userId, Long tenantId,
+ AiThinkingTraceEmitter thinkingTraceEmitter) {
+ if (toolCallbacks == null || toolCallbacks.length == 0) {
+ return toolCallbacks;
+ }
+ List<ToolCallback> wrappedCallbacks = new ArrayList<>();
+ for (ToolCallback callback : toolCallbacks) {
+ if (callback == null) {
+ continue;
+ }
+ wrappedCallbacks.add(new ObservableToolCallback(callback, emitter, requestId, sessionId, toolCallSequence,
+ toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, thinkingTraceEmitter));
+ }
+ return wrappedCallbacks.toArray(new ToolCallback[0]);
+ }
+
+ private String summarizeToolPayload(String content, int maxLength) {
+ if (!StringUtils.hasText(content)) {
+ return null;
+ }
+ String normalized = content.trim()
+ .replace("\r", " ")
+ .replace("\n", " ")
+ .replaceAll("\\s+", " ");
+ return normalized.length() > maxLength ? normalized.substring(0, maxLength) : normalized;
+ }
+
+ private class ObservableToolCallback implements ToolCallback {
+
+ private final ToolCallback delegate;
+ private final SseEmitter emitter;
+ private final String requestId;
+ private final Long sessionId;
+ private final AtomicLong toolCallSequence;
+ private final AtomicLong toolSuccessCount;
+ private final AtomicLong toolFailureCount;
+ private final Long callLogId;
+ private final Long userId;
+ private final Long tenantId;
+ private final AiThinkingTraceEmitter thinkingTraceEmitter;
+
+ private ObservableToolCallback(ToolCallback delegate, SseEmitter emitter, String requestId,
+ Long sessionId, AtomicLong toolCallSequence,
+ AtomicLong toolSuccessCount, AtomicLong toolFailureCount,
+ Long callLogId, Long userId, Long tenantId,
+ AiThinkingTraceEmitter thinkingTraceEmitter) {
+ this.delegate = delegate;
+ this.emitter = emitter;
+ this.requestId = requestId;
+ this.sessionId = sessionId;
+ this.toolCallSequence = toolCallSequence;
+ this.toolSuccessCount = toolSuccessCount;
+ this.toolFailureCount = toolFailureCount;
+ this.callLogId = callLogId;
+ this.userId = userId;
+ this.tenantId = tenantId;
+ this.thinkingTraceEmitter = thinkingTraceEmitter;
+ }
+
+ @Override
+ public org.springframework.ai.tool.definition.ToolDefinition getToolDefinition() {
+ return delegate.getToolDefinition();
+ }
+
+ @Override
+ public org.springframework.ai.tool.metadata.ToolMetadata getToolMetadata() {
+ return delegate.getToolMetadata();
+ }
+
+ @Override
+ public String call(String toolInput) {
+ return call(toolInput, null);
+ }
+
+ @Override
+ public String call(String toolInput, ToolContext toolContext) {
+ String toolName = delegate.getToolDefinition() == null ? "unknown" : delegate.getToolDefinition().name();
+ String mountName = delegate instanceof MountedToolCallback ? ((MountedToolCallback) delegate).getMountName() : null;
+ String toolCallId = requestId + "-tool-" + toolCallSequence.incrementAndGet();
+ long startedAt = System.currentTimeMillis();
+ AiCachedToolResult cachedToolResult = aiToolResultStore.getToolResult(tenantId, requestId, toolName, toolInput);
+ if (cachedToolResult != null) {
+ aiSseEventPublisher.emitSafely(emitter, "tool_result", AiChatToolEventDto.builder()
+ .requestId(requestId)
+ .sessionId(sessionId)
+ .toolCallId(toolCallId)
+ .toolName(toolName)
+ .mountName(mountName)
+ .status(cachedToolResult.isSuccess() ? "COMPLETED" : "FAILED")
+ .inputSummary(summarizeToolPayload(toolInput, 400))
+ .outputSummary(summarizeToolPayload(cachedToolResult.getOutput(), 600))
+ .errorMessage(cachedToolResult.getErrorMessage())
+ .durationMs(0L)
+ .timestamp(System.currentTimeMillis())
+ .build());
+ if (thinkingTraceEmitter != null) {
+ thinkingTraceEmitter.onToolResult(toolName, toolCallId, !cachedToolResult.isSuccess());
+ }
+ if (cachedToolResult.isSuccess()) {
+ toolSuccessCount.incrementAndGet();
+ aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
+ "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(cachedToolResult.getOutput(), 600),
+ null, 0L, userId, tenantId);
+ return cachedToolResult.getOutput();
+ }
+ toolFailureCount.incrementAndGet();
+ aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
+ "FAILED", summarizeToolPayload(toolInput, 400), null, cachedToolResult.getErrorMessage(),
+ 0L, userId, tenantId);
+ throw new CoolException(cachedToolResult.getErrorMessage());
+ }
+ if (thinkingTraceEmitter != null) {
+ thinkingTraceEmitter.onToolStart(toolName, toolCallId);
+ }
+ aiSseEventPublisher.emitSafely(emitter, "tool_start", AiChatToolEventDto.builder()
+ .requestId(requestId)
+ .sessionId(sessionId)
+ .toolCallId(toolCallId)
+ .toolName(toolName)
+ .mountName(mountName)
+ .status("STARTED")
+ .inputSummary(summarizeToolPayload(toolInput, 400))
+ .timestamp(startedAt)
+ .build());
+ try {
+ String output = toolContext == null ? delegate.call(toolInput) : delegate.call(toolInput, toolContext);
+ long durationMs = System.currentTimeMillis() - startedAt;
+ aiSseEventPublisher.emitSafely(emitter, "tool_result", AiChatToolEventDto.builder()
+ .requestId(requestId)
+ .sessionId(sessionId)
+ .toolCallId(toolCallId)
+ .toolName(toolName)
+ .mountName(mountName)
+ .status("COMPLETED")
+ .inputSummary(summarizeToolPayload(toolInput, 400))
+ .outputSummary(summarizeToolPayload(output, 600))
+ .durationMs(durationMs)
+ .timestamp(System.currentTimeMillis())
+ .build());
+ if (thinkingTraceEmitter != null) {
+ thinkingTraceEmitter.onToolResult(toolName, toolCallId, false);
+ }
+ aiToolResultStore.cacheToolResult(tenantId, requestId, toolName, toolInput, true, output, null);
+ toolSuccessCount.incrementAndGet();
+ aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
+ "COMPLETED", summarizeToolPayload(toolInput, 400), summarizeToolPayload(output, 600),
+ null, durationMs, userId, tenantId);
+ return output;
+ } catch (RuntimeException e) {
+ long durationMs = System.currentTimeMillis() - startedAt;
+ aiSseEventPublisher.emitSafely(emitter, "tool_error", AiChatToolEventDto.builder()
+ .requestId(requestId)
+ .sessionId(sessionId)
+ .toolCallId(toolCallId)
+ .toolName(toolName)
+ .mountName(mountName)
+ .status("FAILED")
+ .inputSummary(summarizeToolPayload(toolInput, 400))
+ .errorMessage(e.getMessage())
+ .durationMs(durationMs)
+ .timestamp(System.currentTimeMillis())
+ .build());
+ if (thinkingTraceEmitter != null) {
+ thinkingTraceEmitter.onToolResult(toolName, toolCallId, true);
+ }
+ aiToolResultStore.cacheToolResult(tenantId, requestId, toolName, toolInput, false, null, e.getMessage());
+ toolFailureCount.incrementAndGet();
+ aiCallLogService.saveMcpCallLog(callLogId, requestId, sessionId, toolCallId, mountName, toolName,
+ "FAILED", summarizeToolPayload(toolInput, 400), null, e.getMessage(),
+ durationMs, userId, tenantId);
+ throw e;
+ }
+ }
+ }
+}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/conversation/AiConversationCommandService.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/conversation/AiConversationCommandService.java
new file mode 100644
index 0000000..0f79d6f
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/conversation/AiConversationCommandService.java
@@ -0,0 +1,231 @@
+package com.vincent.rsf.server.ai.service.impl.conversation;
+
+import com.vincent.rsf.framework.exception.CoolException;
+import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
+import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest;
+import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest;
+import com.vincent.rsf.server.ai.dto.AiChatSessionDto;
+import com.vincent.rsf.server.ai.entity.AiChatMessage;
+import com.vincent.rsf.server.ai.entity.AiChatSession;
+import com.vincent.rsf.server.ai.mapper.AiChatMessageMapper;
+import com.vincent.rsf.server.ai.mapper.AiChatSessionMapper;
+import com.vincent.rsf.server.system.enums.StatusType;
+import lombok.RequiredArgsConstructor;
+import org.springframework.stereotype.Service;
+import org.springframework.transaction.support.TransactionSynchronization;
+import org.springframework.transaction.support.TransactionSynchronizationManager;
+import org.springframework.transaction.annotation.Transactional;
+import org.springframework.util.StringUtils;
+
+import java.util.ArrayList;
+import java.util.Date;
+import java.util.List;
+
+@Service
+@RequiredArgsConstructor
+public class AiConversationCommandService {
+
+ private final AiChatSessionMapper aiChatSessionMapper;
+ private final AiChatMessageMapper aiChatMessageMapper;
+ private final AiConversationQueryService aiConversationQueryService;
+ private final AiMemoryProfileService aiMemoryProfileService;
+
+ @Transactional(rollbackFor = Exception.class)
+ public AiChatSession resolveSession(Long userId, Long tenantId, String promptCode, Long sessionId, String titleSeed) {
+ aiConversationQueryService.ensureIdentity(userId, tenantId);
+ String resolvedPromptCode = aiConversationQueryService.requirePromptCode(promptCode);
+ if (sessionId != null) {
+ return aiConversationQueryService.getSession(sessionId, userId, tenantId, resolvedPromptCode);
+ }
+ Date now = new Date();
+ AiChatSession session = new AiChatSession()
+ .setTitle(aiConversationQueryService.buildSessionTitle(titleSeed))
+ .setPromptCode(resolvedPromptCode)
+ .setUserId(userId)
+ .setTenantId(tenantId)
+ .setLastMessageTime(now)
+ .setPinned(0)
+ .setStatus(StatusType.ENABLE.val)
+ .setDeleted(0)
+ .setCreateBy(userId)
+ .setCreateTime(now)
+ .setUpdateBy(userId)
+ .setUpdateTime(now);
+ aiChatSessionMapper.insert(session);
+ afterConversationMutationCommitted(() -> aiConversationQueryService.evictConversationCaches(tenantId, userId));
+ return session;
+ }
+
+ @Transactional(rollbackFor = Exception.class)
+ public void saveRound(AiChatSession session, Long userId, Long tenantId, List<AiChatMessageDto> memoryMessages, String assistantContent) {
+ if (session == null || session.getId() == null) {
+ throw new CoolException("AI 浼氳瘽涓嶅瓨鍦�");
+ }
+ aiConversationQueryService.ensureIdentity(userId, tenantId);
+ List<AiChatMessageDto> normalizedMessages = normalizeMessages(memoryMessages);
+ if (normalizedMessages.isEmpty()) {
+ throw new CoolException("鏈疆娌℃湁鍙繚瀛樼殑瀵硅瘽娑堟伅");
+ }
+ int nextSeqNo = aiConversationQueryService.findNextSeqNo(session.getId());
+ Date now = new Date();
+ List<AiChatMessage> records = new ArrayList<>();
+ for (AiChatMessageDto message : normalizedMessages) {
+ records.add(buildMessageEntity(session.getId(), nextSeqNo++, message.getRole(), message.getContent(), userId, tenantId, now));
+ }
+ if (StringUtils.hasText(assistantContent)) {
+ records.add(buildMessageEntity(session.getId(), nextSeqNo, "assistant", assistantContent, userId, tenantId, now));
+ }
+ aiChatMessageMapper.insertBatch(records);
+ aiChatSessionMapper.updateById(new AiChatSession()
+ .setId(session.getId())
+ .setTitle(aiConversationQueryService.resolveUpdatedTitle(session.getTitle(), normalizedMessages))
+ .setLastMessageTime(now)
+ .setUpdateBy(userId)
+ .setUpdateTime(now));
+ afterConversationMutationCommitted(() -> {
+ aiConversationQueryService.evictConversationCaches(tenantId, userId);
+ aiMemoryProfileService.scheduleMemoryProfileRefresh(session.getId(), userId, tenantId);
+ });
+ }
+
+ @Transactional(rollbackFor = Exception.class)
+ public void removeSession(Long userId, Long tenantId, Long sessionId) {
+ aiConversationQueryService.ensureIdentity(userId, tenantId);
+ AiChatSession session = aiConversationQueryService.requireOwnedSession(sessionId, userId, tenantId);
+ aiChatSessionMapper.updateById(new AiChatSession()
+ .setId(session.getId())
+ .setDeleted(1)
+ .setUpdateBy(userId)
+ .setUpdateTime(new Date()));
+ aiChatMessageMapper.softDeleteBySessionId(sessionId);
+ afterConversationMutationCommitted(() -> aiConversationQueryService.evictConversationCaches(tenantId, userId));
+ }
+
+ @Transactional(rollbackFor = Exception.class)
+ public AiChatSessionDto renameSession(Long userId, Long tenantId, Long sessionId, AiChatSessionRenameRequest request) {
+ aiConversationQueryService.ensureIdentity(userId, tenantId);
+ if (request == null || !StringUtils.hasText(request.getTitle())) {
+ throw new CoolException("浼氳瘽鏍囬涓嶈兘涓虹┖");
+ }
+ AiChatSession session = aiConversationQueryService.requireOwnedSession(sessionId, userId, tenantId);
+ aiChatSessionMapper.updateById(new AiChatSession()
+ .setId(sessionId)
+ .setTitle(aiConversationQueryService.buildSessionTitle(request.getTitle()))
+ .setUpdateBy(userId)
+ .setUpdateTime(new Date()));
+ afterConversationMutationCommitted(() -> aiConversationQueryService.evictConversationCaches(tenantId, userId));
+ return reloadSessionDto(sessionId, userId, tenantId, session.getPromptCode());
+ }
+
+ @Transactional(rollbackFor = Exception.class)
+ public AiChatSessionDto pinSession(Long userId, Long tenantId, Long sessionId, AiChatSessionPinRequest request) {
+ aiConversationQueryService.ensureIdentity(userId, tenantId);
+ if (request == null || request.getPinned() == null) {
+ throw new CoolException("缃《鐘舵�佷笉鑳戒负绌�");
+ }
+ AiChatSession session = aiConversationQueryService.requireOwnedSession(sessionId, userId, tenantId);
+ aiChatSessionMapper.updateById(new AiChatSession()
+ .setId(sessionId)
+ .setPinned(Boolean.TRUE.equals(request.getPinned()) ? 1 : 0)
+ .setUpdateBy(userId)
+ .setUpdateTime(new Date()));
+ afterConversationMutationCommitted(() -> aiConversationQueryService.evictConversationCaches(tenantId, userId));
+ return reloadSessionDto(sessionId, userId, tenantId, session.getPromptCode());
+ }
+
+ @Transactional(rollbackFor = Exception.class)
+ public void clearSessionMemory(Long userId, Long tenantId, Long sessionId) {
+ aiConversationQueryService.ensureIdentity(userId, tenantId);
+ AiChatSession session = aiConversationQueryService.requireOwnedSession(sessionId, userId, tenantId);
+ aiChatMessageMapper.softDeleteBySessionId(sessionId);
+ aiChatSessionMapper.updateById(new AiChatSession()
+ .setId(sessionId)
+ .setMemorySummary(null)
+ .setMemoryFacts(null)
+ .setUpdateBy(userId)
+ .setUpdateTime(new Date())
+ .setLastMessageTime(session.getCreateTime()));
+ afterConversationMutationCommitted(() -> aiConversationQueryService.evictConversationCaches(tenantId, userId));
+ }
+
+ @Transactional(rollbackFor = Exception.class)
+ public void retainLatestRound(Long userId, Long tenantId, Long sessionId) {
+ aiConversationQueryService.ensureIdentity(userId, tenantId);
+ aiConversationQueryService.requireOwnedSession(sessionId, userId, tenantId);
+ List<AiChatMessage> records = aiConversationQueryService.listMessageRecords(sessionId);
+ if (records.isEmpty()) {
+ return;
+ }
+ List<AiChatMessage> retained = aiConversationQueryService.tailMessageRecordsByRounds(records, 1);
+ List<Long> retainedIds = retained.stream().map(AiChatMessage::getId).toList();
+ List<Long> deletedIds = records.stream()
+ .map(AiChatMessage::getId)
+ .filter(id -> !retainedIds.contains(id))
+ .toList();
+ if (!deletedIds.isEmpty()) {
+ aiChatMessageMapper.softDeleteByIds(deletedIds);
+ }
+ afterConversationMutationCommitted(() -> {
+ aiConversationQueryService.evictConversationCaches(tenantId, userId);
+ aiMemoryProfileService.scheduleMemoryProfileRefresh(sessionId, userId, tenantId);
+ });
+ }
+
+ private AiChatSessionDto reloadSessionDto(Long sessionId, Long userId, Long tenantId, String promptCode) {
+ return aiConversationQueryService.listSessions(userId, tenantId, promptCode, null).stream()
+ .filter(item -> sessionId.equals(item.getSessionId()))
+ .findFirst()
+ .orElseThrow(() -> new CoolException("AI 浼氳瘽涓嶅瓨鍦ㄦ垨鏃犳潈璁块棶"));
+ }
+
+ private List<AiChatMessageDto> normalizeMessages(List<AiChatMessageDto> memoryMessages) {
+ List<AiChatMessageDto> normalized = new ArrayList<>();
+ if (memoryMessages == null || memoryMessages.isEmpty()) {
+ return normalized;
+ }
+ for (AiChatMessageDto item : memoryMessages) {
+ if (item == null || !StringUtils.hasText(item.getContent())) {
+ continue;
+ }
+ String role = item.getRole() == null ? "user" : item.getRole().toLowerCase();
+ if ("system".equals(role)) {
+ continue;
+ }
+ AiChatMessageDto normalizedItem = new AiChatMessageDto();
+ normalizedItem.setRole("assistant".equals(role) ? "assistant" : "user");
+ normalizedItem.setContent(item.getContent().trim());
+ normalized.add(normalizedItem);
+ }
+ return normalized;
+ }
+
+ private AiChatMessage buildMessageEntity(Long sessionId, int seqNo, String role, String content, Long userId, Long tenantId, Date createTime) {
+ return new AiChatMessage()
+ .setSessionId(sessionId)
+ .setSeqNo(seqNo)
+ .setRole(role)
+ .setContent(content)
+ .setContentLength(content == null ? 0 : content.length())
+ .setUserId(userId)
+ .setTenantId(tenantId)
+ .setDeleted(0)
+ .setCreateBy(userId)
+ .setCreateTime(createTime);
+ }
+
+ private void afterConversationMutationCommitted(Runnable action) {
+ if (action == null) {
+ return;
+ }
+ if (!TransactionSynchronizationManager.isSynchronizationActive()) {
+ action.run();
+ return;
+ }
+ TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() {
+ @Override
+ public void afterCommit() {
+ action.run();
+ }
+ });
+ }
+}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/conversation/AiConversationQueryService.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/conversation/AiConversationQueryService.java
new file mode 100644
index 0000000..c506652
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/conversation/AiConversationQueryService.java
@@ -0,0 +1,318 @@
+package com.vincent.rsf.server.ai.service.impl.conversation;
+
+import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
+import com.vincent.rsf.framework.common.Cools;
+import com.vincent.rsf.framework.exception.CoolException;
+import com.vincent.rsf.server.ai.config.AiDefaults;
+import com.vincent.rsf.server.ai.dto.AiChatMemoryDto;
+import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
+import com.vincent.rsf.server.ai.dto.AiChatSessionDto;
+import com.vincent.rsf.server.ai.entity.AiChatMessage;
+import com.vincent.rsf.server.ai.entity.AiChatSession;
+import com.vincent.rsf.server.ai.mapper.AiChatMessageMapper;
+import com.vincent.rsf.server.ai.mapper.AiChatSessionMapper;
+import com.vincent.rsf.server.ai.store.AiConversationCacheStore;
+import com.vincent.rsf.server.system.enums.StatusType;
+import lombok.RequiredArgsConstructor;
+import org.springframework.stereotype.Service;
+import org.springframework.util.StringUtils;
+
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+@Service
+@RequiredArgsConstructor
+public class AiConversationQueryService {
+
+ private final AiChatSessionMapper aiChatSessionMapper;
+ private final AiChatMessageMapper aiChatMessageMapper;
+ private final AiConversationCacheStore aiConversationCacheStore;
+
+ public AiChatMemoryDto getMemory(Long userId, Long tenantId, String promptCode, Long sessionId) {
+ ensureIdentity(userId, tenantId);
+ String resolvedPromptCode = requirePromptCode(promptCode);
+ AiChatMemoryDto cached = aiConversationCacheStore.getMemory(tenantId, userId, resolvedPromptCode, sessionId);
+ if (cached != null) {
+ return cached;
+ }
+ AiChatSession session = sessionId == null
+ ? findLatestSession(userId, tenantId, resolvedPromptCode)
+ : getSession(sessionId, userId, tenantId, resolvedPromptCode);
+ AiChatMemoryDto memory;
+ if (session == null) {
+ memory = AiChatMemoryDto.builder()
+ .sessionId(null)
+ .memorySummary(null)
+ .memoryFacts(null)
+ .recentMessageCount(0)
+ .persistedMessages(List.of())
+ .shortMemoryMessages(List.of())
+ .build();
+ aiConversationCacheStore.cacheMemory(tenantId, userId, resolvedPromptCode, sessionId, memory);
+ return memory;
+ }
+ List<AiChatMessageDto> persistedMessages = listMessages(session.getId());
+ List<AiChatMessageDto> shortMemoryMessages = tailMessagesByRounds(persistedMessages, AiDefaults.MEMORY_RECENT_ROUNDS);
+ memory = AiChatMemoryDto.builder()
+ .sessionId(session.getId())
+ .memorySummary(session.getMemorySummary())
+ .memoryFacts(session.getMemoryFacts())
+ .recentMessageCount(shortMemoryMessages.size())
+ .persistedMessages(persistedMessages)
+ .shortMemoryMessages(shortMemoryMessages)
+ .build();
+ aiConversationCacheStore.cacheMemory(tenantId, userId, resolvedPromptCode, session.getId(), memory);
+ if (sessionId == null || !session.getId().equals(sessionId)) {
+ aiConversationCacheStore.cacheMemory(tenantId, userId, resolvedPromptCode, null, memory);
+ }
+ return memory;
+ }
+
+ public List<AiChatSessionDto> listSessions(Long userId, Long tenantId, String promptCode, String keyword) {
+ ensureIdentity(userId, tenantId);
+ String resolvedPromptCode = requirePromptCode(promptCode);
+ List<AiChatSessionDto> cached = aiConversationCacheStore.getSessionList(tenantId, userId, resolvedPromptCode, keyword);
+ if (cached != null) {
+ return cached;
+ }
+ List<AiChatSession> sessions = aiChatSessionMapper.selectList(new LambdaQueryWrapper<AiChatSession>()
+ .eq(AiChatSession::getUserId, userId)
+ .eq(AiChatSession::getTenantId, tenantId)
+ .eq(AiChatSession::getPromptCode, resolvedPromptCode)
+ .eq(AiChatSession::getDeleted, 0)
+ .eq(AiChatSession::getStatus, StatusType.ENABLE.val)
+ .like(StringUtils.hasText(keyword), AiChatSession::getTitle, keyword == null ? null : keyword.trim())
+ .orderByDesc(AiChatSession::getPinned)
+ .orderByDesc(AiChatSession::getLastMessageTime)
+ .orderByDesc(AiChatSession::getId));
+ if (Cools.isEmpty(sessions)) {
+ aiConversationCacheStore.cacheSessionList(tenantId, userId, resolvedPromptCode, keyword, List.of());
+ return List.of();
+ }
+ List<Long> sessionIds = sessions.stream().map(AiChatSession::getId).toList();
+ Map<Long, AiChatMessage> latestMessageMap = new LinkedHashMap<>();
+ for (AiChatMessage message : aiChatMessageMapper.selectLatestMessagesBySessionIds(sessionIds)) {
+ latestMessageMap.put(message.getSessionId(), message);
+ }
+ List<AiChatSessionDto> result = new ArrayList<>();
+ for (AiChatSession session : sessions) {
+ result.add(buildSessionDto(session, latestMessageMap.get(session.getId())));
+ }
+ aiConversationCacheStore.cacheSessionList(tenantId, userId, resolvedPromptCode, keyword, result);
+ return result;
+ }
+
+ public void evictConversationCaches(Long tenantId, Long userId) {
+ aiConversationCacheStore.evictUserConversationCaches(tenantId, userId);
+ }
+
+ public AiChatSession findLatestSession(Long userId, Long tenantId, String promptCode) {
+ return aiChatSessionMapper.selectOne(new LambdaQueryWrapper<AiChatSession>()
+ .eq(AiChatSession::getUserId, userId)
+ .eq(AiChatSession::getTenantId, tenantId)
+ .eq(AiChatSession::getPromptCode, promptCode)
+ .eq(AiChatSession::getDeleted, 0)
+ .eq(AiChatSession::getStatus, StatusType.ENABLE.val)
+ .orderByDesc(AiChatSession::getLastMessageTime)
+ .orderByDesc(AiChatSession::getId)
+ .last("limit 1"));
+ }
+
+ public AiChatSession getSession(Long sessionId, Long userId, Long tenantId, String promptCode) {
+ AiChatSession session = aiChatSessionMapper.selectOne(new LambdaQueryWrapper<AiChatSession>()
+ .eq(AiChatSession::getId, sessionId)
+ .eq(AiChatSession::getUserId, userId)
+ .eq(AiChatSession::getTenantId, tenantId)
+ .eq(AiChatSession::getPromptCode, promptCode)
+ .eq(AiChatSession::getDeleted, 0)
+ .eq(AiChatSession::getStatus, StatusType.ENABLE.val)
+ .last("limit 1"));
+ if (session == null) {
+ throw new CoolException("AI 浼氳瘽涓嶅瓨鍦ㄦ垨鏃犳潈璁块棶");
+ }
+ return session;
+ }
+
+ public AiChatSession requireOwnedSession(Long sessionId, Long userId, Long tenantId) {
+ if (sessionId == null) {
+ throw new CoolException("AI 浼氳瘽 ID 涓嶈兘涓虹┖");
+ }
+ AiChatSession session = aiChatSessionMapper.selectOne(new LambdaQueryWrapper<AiChatSession>()
+ .eq(AiChatSession::getId, sessionId)
+ .eq(AiChatSession::getUserId, userId)
+ .eq(AiChatSession::getTenantId, tenantId)
+ .eq(AiChatSession::getDeleted, 0)
+ .eq(AiChatSession::getStatus, StatusType.ENABLE.val)
+ .last("limit 1"));
+ if (session == null) {
+ throw new CoolException("AI 浼氳瘽涓嶅瓨鍦ㄦ垨鏃犳潈璁块棶");
+ }
+ return session;
+ }
+
+ public List<AiChatMessageDto> listMessages(Long sessionId) {
+ List<AiChatMessage> records = listMessageRecords(sessionId);
+ if (Cools.isEmpty(records)) {
+ return List.of();
+ }
+ List<AiChatMessageDto> messages = new ArrayList<>();
+ for (AiChatMessage record : records) {
+ if (!StringUtils.hasText(record.getContent())) {
+ continue;
+ }
+ AiChatMessageDto item = new AiChatMessageDto();
+ item.setRole(record.getRole());
+ item.setContent(record.getContent());
+ messages.add(item);
+ }
+ return messages;
+ }
+
+ public List<AiChatMessage> listMessageRecords(Long sessionId) {
+ return aiChatMessageMapper.selectList(new LambdaQueryWrapper<AiChatMessage>()
+ .eq(AiChatMessage::getSessionId, sessionId)
+ .eq(AiChatMessage::getDeleted, 0)
+ .orderByAsc(AiChatMessage::getSeqNo)
+ .orderByAsc(AiChatMessage::getId));
+ }
+
+ public int findNextSeqNo(Long sessionId) {
+ AiChatMessage lastMessage = aiChatMessageMapper.selectOne(new LambdaQueryWrapper<AiChatMessage>()
+ .eq(AiChatMessage::getSessionId, sessionId)
+ .eq(AiChatMessage::getDeleted, 0)
+ .orderByDesc(AiChatMessage::getSeqNo)
+ .orderByDesc(AiChatMessage::getId)
+ .last("limit 1"));
+ return lastMessage == null || lastMessage.getSeqNo() == null ? 1 : lastMessage.getSeqNo() + 1;
+ }
+
+ public String resolveUpdatedTitle(String currentTitle, List<AiChatMessageDto> memoryMessages) {
+ if (StringUtils.hasText(currentTitle)) {
+ return currentTitle;
+ }
+ for (AiChatMessageDto item : memoryMessages) {
+ if ("user".equals(item.getRole()) && StringUtils.hasText(item.getContent())) {
+ return buildSessionTitle(item.getContent());
+ }
+ }
+ return null;
+ }
+
+ public String buildSessionTitle(String titleSeed) {
+ if (!StringUtils.hasText(titleSeed)) {
+ throw new CoolException("AI 浼氳瘽鏍囬涓嶈兘涓虹┖");
+ }
+ String title = titleSeed.trim()
+ .replace("\r", " ")
+ .replace("\n", " ")
+ .replaceAll("\\s+", " ");
+ int punctuationIndex = findSummaryBreakIndex(title);
+ if (punctuationIndex > 0) {
+ title = title.substring(0, punctuationIndex).trim();
+ }
+ return title.length() > 48 ? title.substring(0, 48) : title;
+ }
+
+ public List<AiChatMessageDto> tailMessagesByRounds(List<AiChatMessageDto> source, int rounds) {
+ if (Cools.isEmpty(source) || rounds <= 0) {
+ return List.of();
+ }
+ int userCount = 0;
+ int startIndex = source.size();
+ for (int i = source.size() - 1; i >= 0; i--) {
+ AiChatMessageDto item = source.get(i);
+ startIndex = i;
+ if (item != null && "user".equalsIgnoreCase(item.getRole())) {
+ userCount++;
+ if (userCount >= rounds) {
+ break;
+ }
+ }
+ }
+ return new ArrayList<>(source.subList(Math.max(0, startIndex), source.size()));
+ }
+
+ public List<AiChatMessage> tailMessageRecordsByRounds(List<AiChatMessage> source, int rounds) {
+ if (Cools.isEmpty(source) || rounds <= 0) {
+ return List.of();
+ }
+ int userCount = 0;
+ int startIndex = source.size();
+ for (int i = source.size() - 1; i >= 0; i--) {
+ AiChatMessage item = source.get(i);
+ startIndex = i;
+ if (item != null && "user".equalsIgnoreCase(item.getRole())) {
+ userCount++;
+ if (userCount >= rounds) {
+ break;
+ }
+ }
+ }
+ return new ArrayList<>(source.subList(Math.max(0, startIndex), source.size()));
+ }
+
+ public String compactText(String content, int maxLength) {
+ if (!StringUtils.hasText(content)) {
+ return null;
+ }
+ String normalized = content.trim()
+ .replace("\r", " ")
+ .replace("\n", " ")
+ .replaceAll("\\s+", " ");
+ return normalized.length() > maxLength ? normalized.substring(0, maxLength) : normalized;
+ }
+
+ public String requirePromptCode(String promptCode) {
+ if (!StringUtils.hasText(promptCode)) {
+ throw new CoolException("Prompt 缂栫爜涓嶈兘涓虹┖");
+ }
+ return promptCode;
+ }
+
+ public void ensureIdentity(Long userId, Long tenantId) {
+ if (userId == null) {
+ throw new CoolException("褰撳墠鐧诲綍鐢ㄦ埛涓嶅瓨鍦�");
+ }
+ if (tenantId == null) {
+ throw new CoolException("褰撳墠绉熸埛涓嶅瓨鍦�");
+ }
+ }
+
+ private int findSummaryBreakIndex(String title) {
+ String[] separators = {"銆�", "锛�", "锛�", ".", "!", "?"};
+ int result = -1;
+ for (String separator : separators) {
+ int index = title.indexOf(separator);
+ if (index > 0 && (result < 0 || index < result)) {
+ result = index;
+ }
+ }
+ return result;
+ }
+
+ private AiChatSessionDto buildSessionDto(AiChatSession session, AiChatMessage lastMessage) {
+ return AiChatSessionDto.builder()
+ .sessionId(session.getId())
+ .title(session.getTitle())
+ .promptCode(session.getPromptCode())
+ .pinned(session.getPinned() != null && session.getPinned() == 1)
+ .lastMessagePreview(buildLastMessagePreview(lastMessage))
+ .lastMessageTime(session.getLastMessageTime())
+ .build();
+ }
+
+ private String buildLastMessagePreview(AiChatMessage message) {
+ if (message == null || !StringUtils.hasText(message.getContent())) {
+ return null;
+ }
+ String preview = message.getContent().trim()
+ .replace("\r", " ")
+ .replace("\n", " ")
+ .replaceAll("\\s+", " ");
+ String prefix = "assistant".equalsIgnoreCase(message.getRole()) ? "AI: " : "浣�: ";
+ String normalized = prefix + preview;
+ return normalized.length() > 80 ? normalized.substring(0, 80) : normalized;
+ }
+}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/conversation/AiMemoryProfileService.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/conversation/AiMemoryProfileService.java
new file mode 100644
index 0000000..4c512aa
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/conversation/AiMemoryProfileService.java
@@ -0,0 +1,130 @@
+package com.vincent.rsf.server.ai.service.impl.conversation;
+
+import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
+import com.vincent.rsf.framework.common.Cools;
+import com.vincent.rsf.server.ai.config.AiDefaults;
+import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
+import com.vincent.rsf.server.ai.entity.AiChatMessage;
+import com.vincent.rsf.server.ai.entity.AiChatSession;
+import com.vincent.rsf.server.ai.mapper.AiChatMessageMapper;
+import com.vincent.rsf.server.ai.mapper.AiChatSessionMapper;
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.beans.factory.annotation.Qualifier;
+import org.springframework.stereotype.Service;
+import org.springframework.util.StringUtils;
+
+import java.util.Date;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.Executor;
+
+@Slf4j
+@Service
+@RequiredArgsConstructor
+public class AiMemoryProfileService {
+
+ private final AiConversationQueryService aiConversationQueryService;
+ private final AiChatSessionMapper aiChatSessionMapper;
+ private final AiChatMessageMapper aiChatMessageMapper;
+ @Qualifier("aiMemoryTaskExecutor")
+ private final Executor aiMemoryTaskExecutor;
+
+ private final Set<Long> refreshingSessionIds = ConcurrentHashMap.newKeySet();
+ private final Set<Long> pendingRefreshSessionIds = ConcurrentHashMap.newKeySet();
+
+ public void scheduleMemoryProfileRefresh(Long sessionId, Long userId, Long tenantId) {
+ if (sessionId == null) {
+ return;
+ }
+ if (!refreshingSessionIds.add(sessionId)) {
+ pendingRefreshSessionIds.add(sessionId);
+ return;
+ }
+ aiMemoryTaskExecutor.execute(() -> runMemoryProfileRefreshLoop(sessionId, userId, tenantId));
+ }
+
+ private void runMemoryProfileRefreshLoop(Long sessionId, Long userId, Long tenantId) {
+ try {
+ boolean shouldContinue;
+ do {
+ pendingRefreshSessionIds.remove(sessionId);
+ try {
+ refreshMemoryProfile(sessionId, userId);
+ aiConversationQueryService.evictConversationCaches(tenantId, userId);
+ } catch (Exception e) {
+ log.warn("AI memory profile refresh failed, sessionId={}, userId={}, tenantId={}, message={}",
+ sessionId, userId, tenantId, e.getMessage(), e);
+ }
+ shouldContinue = pendingRefreshSessionIds.remove(sessionId);
+ } while (shouldContinue);
+ } finally {
+ refreshingSessionIds.remove(sessionId);
+ if (pendingRefreshSessionIds.remove(sessionId) && refreshingSessionIds.add(sessionId)) {
+ aiMemoryTaskExecutor.execute(() -> runMemoryProfileRefreshLoop(sessionId, userId, tenantId));
+ }
+ }
+ }
+
+ private void refreshMemoryProfile(Long sessionId, Long userId) {
+ List<AiChatMessageDto> messages = aiConversationQueryService.listMessages(sessionId);
+ List<AiChatMessageDto> shortMemoryMessages = aiConversationQueryService.tailMessagesByRounds(messages, AiDefaults.MEMORY_RECENT_ROUNDS);
+ List<AiChatMessageDto> historyMessages = messages.size() > shortMemoryMessages.size()
+ ? messages.subList(0, messages.size() - shortMemoryMessages.size())
+ : List.of();
+ String memorySummary = historyMessages.size() >= AiDefaults.MEMORY_SUMMARY_TRIGGER_MESSAGES
+ ? buildMemorySummary(historyMessages)
+ : null;
+ String memoryFacts = buildMemoryFacts(messages);
+ AiChatMessage lastMessage = aiChatMessageMapper.selectOne(new LambdaQueryWrapper<AiChatMessage>()
+ .eq(AiChatMessage::getSessionId, sessionId)
+ .eq(AiChatMessage::getDeleted, 0)
+ .orderByDesc(AiChatMessage::getSeqNo)
+ .orderByDesc(AiChatMessage::getId)
+ .last("limit 1"));
+ aiChatSessionMapper.updateById(new AiChatSession()
+ .setId(sessionId)
+ .setMemorySummary(memorySummary)
+ .setMemoryFacts(memoryFacts)
+ .setLastMessageTime(lastMessage == null ? null : lastMessage.getCreateTime())
+ .setUpdateBy(userId)
+ .setUpdateTime(new Date()));
+ }
+
+ private String buildMemorySummary(List<AiChatMessageDto> historyMessages) {
+ StringBuilder builder = new StringBuilder("杈冩棭瀵硅瘽鎽樿:\n");
+ for (AiChatMessageDto item : historyMessages) {
+ if (item == null || !StringUtils.hasText(item.getContent())) {
+ continue;
+ }
+ String prefix = "assistant".equalsIgnoreCase(item.getRole()) ? "- AI: " : "- 鐢ㄦ埛: ";
+ String content = aiConversationQueryService.compactText(item.getContent(), 120);
+ if (!StringUtils.hasText(content)) {
+ continue;
+ }
+ builder.append(prefix).append(content).append("\n");
+ if (builder.length() >= AiDefaults.MEMORY_SUMMARY_MAX_LENGTH) {
+ break;
+ }
+ }
+ return aiConversationQueryService.compactText(builder.toString(), AiDefaults.MEMORY_SUMMARY_MAX_LENGTH);
+ }
+
+ private String buildMemoryFacts(List<AiChatMessageDto> messages) {
+ if (Cools.isEmpty(messages)) {
+ return null;
+ }
+ StringBuilder builder = new StringBuilder("鍏抽敭浜嬪疄:\n");
+ int userFacts = 0;
+ for (int i = messages.size() - 1; i >= 0 && userFacts < 4; i--) {
+ AiChatMessageDto item = messages.get(i);
+ if (item == null || !"user".equalsIgnoreCase(item.getRole()) || !StringUtils.hasText(item.getContent())) {
+ continue;
+ }
+ builder.append("- 鐢ㄦ埛鍏虫敞: ").append(aiConversationQueryService.compactText(item.getContent(), 100)).append("\n");
+ userFacts++;
+ }
+ return userFacts == 0 ? null : aiConversationQueryService.compactText(builder.toString(), AiDefaults.MEMORY_FACTS_MAX_LENGTH);
+ }
+}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/mcp/AiMcpAdminService.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/mcp/AiMcpAdminService.java
new file mode 100644
index 0000000..eab0e6d
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/mcp/AiMcpAdminService.java
@@ -0,0 +1,243 @@
+package com.vincent.rsf.server.ai.service.impl.mcp;
+
+import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
+import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.vincent.rsf.framework.exception.CoolException;
+import com.vincent.rsf.server.ai.config.AiDefaults;
+import com.vincent.rsf.server.ai.dto.AiMcpConnectivityTestDto;
+import com.vincent.rsf.server.ai.dto.AiMcpToolPreviewDto;
+import com.vincent.rsf.server.ai.dto.AiMcpToolTestDto;
+import com.vincent.rsf.server.ai.dto.AiMcpToolTestRequest;
+import com.vincent.rsf.server.ai.entity.AiMcpMount;
+import com.vincent.rsf.server.ai.mapper.AiMcpMountMapper;
+import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry;
+import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
+import lombok.RequiredArgsConstructor;
+import org.springframework.ai.chat.model.ToolContext;
+import org.springframework.ai.tool.ToolCallback;
+import org.springframework.stereotype.Service;
+import org.springframework.util.StringUtils;
+
+import java.text.SimpleDateFormat;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Date;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+@Service
+@RequiredArgsConstructor
+public class AiMcpAdminService {
+
+ private final AiMcpMountMapper aiMcpMountMapper;
+ private final BuiltinMcpToolRegistry builtinMcpToolRegistry;
+ private final McpMountRuntimeFactory mcpMountRuntimeFactory;
+ private final ObjectMapper objectMapper;
+
+ public List<AiMcpToolPreviewDto> previewTools(AiMcpMount mount, Long userId) {
+ long startedAt = System.currentTimeMillis();
+ try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(List.of(mount), userId)) {
+ List<AiMcpToolPreviewDto> tools = buildToolPreviewDtos(runtime.getToolCallbacks(),
+ AiDefaults.MCP_TRANSPORT_BUILTIN.equals(mount.getTransportType())
+ ? builtinMcpToolRegistry.listBuiltinToolCatalog(mount.getBuiltinCode())
+ : List.of());
+ if (!runtime.getErrors().isEmpty()) {
+ String message = String.join("锛�", runtime.getErrors());
+ if (mount.getId() != null) {
+ updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY, message, System.currentTimeMillis() - startedAt);
+ }
+ throw new CoolException(message);
+ }
+ if (mount.getId() != null) {
+ updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_HEALTHY,
+ "宸ュ叿瑙f瀽鎴愬姛锛屽叡 " + tools.size() + " 涓伐鍏�", System.currentTimeMillis() - startedAt);
+ }
+ return tools;
+ } catch (CoolException e) {
+ throw e;
+ } catch (Exception e) {
+ if (mount.getId() != null) {
+ updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY,
+ "宸ュ叿瑙f瀽澶辫触: " + e.getMessage(), System.currentTimeMillis() - startedAt);
+ }
+ throw new CoolException("鑾峰彇宸ュ叿鍒楄〃澶辫触: " + e.getMessage());
+ }
+ }
+
+ public AiMcpConnectivityTestDto testConnectivity(AiMcpMount mount, Long userId, boolean persistHealth) {
+ long startedAt = System.currentTimeMillis();
+ try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(List.of(mount), userId)) {
+ long elapsedMs = System.currentTimeMillis() - startedAt;
+ if (!runtime.getErrors().isEmpty()) {
+ String message = String.join("锛�", runtime.getErrors());
+ if (persistHealth && mount.getId() != null) {
+ updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY, message, elapsedMs);
+ AiMcpMount latest = requireMount(mount.getId(), mount.getTenantId());
+ return buildConnectivityDto(latest, message, elapsedMs, runtime.getToolCallbacks().length);
+ }
+ return AiMcpConnectivityTestDto.builder()
+ .mountId(mount.getId())
+ .mountName(mount.getName())
+ .healthStatus(AiDefaults.MCP_HEALTH_UNHEALTHY)
+ .message(message)
+ .initElapsedMs(elapsedMs)
+ .toolCount(runtime.getToolCallbacks().length)
+ .testedAt(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date()))
+ .build();
+ }
+ String message = persistHealth && mount.getId() != null
+ ? "杩為�氭�ф祴璇曟垚鍔燂紝瑙f瀽鍑� " + runtime.getToolCallbacks().length + " 涓伐鍏�"
+ : "鑽夌杩為�氭�ф祴璇曟垚鍔燂紝瑙f瀽鍑� " + runtime.getToolCallbacks().length + " 涓伐鍏�";
+ if (persistHealth && mount.getId() != null) {
+ updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_HEALTHY, message, elapsedMs);
+ AiMcpMount latest = requireMount(mount.getId(), mount.getTenantId());
+ return buildConnectivityDto(latest, message, elapsedMs, runtime.getToolCallbacks().length);
+ }
+ return AiMcpConnectivityTestDto.builder()
+ .mountId(mount.getId())
+ .mountName(mount.getName())
+ .healthStatus(AiDefaults.MCP_HEALTH_HEALTHY)
+ .message(message)
+ .initElapsedMs(elapsedMs)
+ .toolCount(runtime.getToolCallbacks().length)
+ .testedAt(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date()))
+ .build();
+ } catch (CoolException e) {
+ throw e;
+ } catch (Exception e) {
+ long elapsedMs = System.currentTimeMillis() - startedAt;
+ String message = (persistHealth ? "杩為�氭�ф祴璇曞け璐�: " : "鑽夌杩為�氭�ф祴璇曞け璐�: ") + e.getMessage();
+ if (persistHealth && mount.getId() != null) {
+ updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY, message, elapsedMs);
+ AiMcpMount latest = requireMount(mount.getId(), mount.getTenantId());
+ return buildConnectivityDto(latest, message, elapsedMs, 0);
+ }
+ throw new CoolException(message);
+ }
+ }
+
+ public AiMcpToolTestDto testTool(AiMcpMount mount, Long userId, Long tenantId, AiMcpToolTestRequest request) {
+ if (userId == null) {
+ throw new CoolException("褰撳墠鐧诲綍鐢ㄦ埛涓嶅瓨鍦�");
+ }
+ if (tenantId == null) {
+ throw new CoolException("褰撳墠绉熸埛涓嶅瓨鍦�");
+ }
+ if (request == null) {
+ throw new CoolException("宸ュ叿娴嬭瘯鍙傛暟涓嶈兘涓虹┖");
+ }
+ if (!StringUtils.hasText(request.getToolName())) {
+ throw new CoolException("宸ュ叿鍚嶇О涓嶈兘涓虹┖");
+ }
+ if (!StringUtils.hasText(request.getInputJson())) {
+ throw new CoolException("宸ュ叿杈撳叆 JSON 涓嶈兘涓虹┖");
+ }
+ try {
+ objectMapper.readTree(request.getInputJson());
+ } catch (Exception e) {
+ throw new CoolException("宸ュ叿杈撳叆 JSON 鏍煎紡閿欒: " + e.getMessage());
+ }
+ long startedAt = System.currentTimeMillis();
+ try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(List.of(mount), userId)) {
+ ToolCallback callback = Arrays.stream(runtime.getToolCallbacks())
+ .filter(item -> item != null && item.getToolDefinition() != null)
+ .filter(item -> request.getToolName().equals(item.getToolDefinition().name()))
+ .findFirst()
+ .orElseThrow(() -> new CoolException("鏈壘鍒拌娴嬭瘯鐨勫伐鍏�: " + request.getToolName()));
+ String output = callback.call(
+ request.getInputJson(),
+ new ToolContext(Map.of("userId", userId, "tenantId", tenantId, "mountId", mount.getId()))
+ );
+ if (mount.getId() != null) {
+ updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_HEALTHY,
+ "宸ュ叿娴嬭瘯鎴愬姛: " + request.getToolName(), System.currentTimeMillis() - startedAt);
+ }
+ return AiMcpToolTestDto.builder()
+ .toolName(request.getToolName())
+ .inputJson(request.getInputJson())
+ .output(output)
+ .build();
+ } catch (CoolException e) {
+ if (mount.getId() != null) {
+ updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY,
+ "宸ュ叿娴嬭瘯澶辫触: " + e.getMessage(), System.currentTimeMillis() - startedAt);
+ }
+ throw e;
+ } catch (Exception e) {
+ if (mount.getId() != null) {
+ updateHealthStatus(mount.getId(), AiDefaults.MCP_HEALTH_UNHEALTHY,
+ "宸ュ叿娴嬭瘯澶辫触: " + e.getMessage(), System.currentTimeMillis() - startedAt);
+ }
+ throw new CoolException("宸ュ叿娴嬭瘯澶辫触: " + e.getMessage());
+ }
+ }
+
+ public AiMcpMount requireMount(Long mountId, Long tenantId) {
+ if (tenantId == null) {
+ throw new CoolException("褰撳墠绉熸埛涓嶅瓨鍦�");
+ }
+ if (mountId == null) {
+ throw new CoolException("MCP 鎸傝浇 ID 涓嶈兘涓虹┖");
+ }
+ AiMcpMount mount = aiMcpMountMapper.selectOne(new LambdaQueryWrapper<AiMcpMount>()
+ .eq(AiMcpMount::getId, mountId)
+ .eq(AiMcpMount::getTenantId, tenantId)
+ .eq(AiMcpMount::getDeleted, 0)
+ .last("limit 1"));
+ if (mount == null) {
+ throw new CoolException("MCP 鎸傝浇涓嶅瓨鍦�");
+ }
+ return mount;
+ }
+
+ private List<AiMcpToolPreviewDto> buildToolPreviewDtos(ToolCallback[] callbacks, List<AiMcpToolPreviewDto> governedCatalog) {
+ List<AiMcpToolPreviewDto> tools = new ArrayList<>();
+ Map<String, AiMcpToolPreviewDto> catalogMap = new LinkedHashMap<>();
+ for (AiMcpToolPreviewDto item : governedCatalog) {
+ if (item == null || !StringUtils.hasText(item.getName())) {
+ continue;
+ }
+ catalogMap.put(item.getName(), item);
+ }
+ for (ToolCallback callback : callbacks) {
+ if (callback == null || callback.getToolDefinition() == null) {
+ continue;
+ }
+ AiMcpToolPreviewDto governedItem = catalogMap.get(callback.getToolDefinition().name());
+ tools.add(AiMcpToolPreviewDto.builder()
+ .name(callback.getToolDefinition().name())
+ .description(callback.getToolDefinition().description())
+ .inputSchema(callback.getToolDefinition().inputSchema())
+ .returnDirect(callback.getToolMetadata() == null ? null : callback.getToolMetadata().returnDirect())
+ .toolGroup(governedItem == null ? null : governedItem.getToolGroup())
+ .toolPurpose(governedItem == null ? null : governedItem.getToolPurpose())
+ .queryBoundary(governedItem == null ? null : governedItem.getQueryBoundary())
+ .exampleQuestions(governedItem == null ? List.of() : governedItem.getExampleQuestions())
+ .build());
+ }
+ return tools;
+ }
+
+ private void updateHealthStatus(Long mountId, String healthStatus, String message, Long initElapsedMs) {
+ aiMcpMountMapper.update(null, new LambdaUpdateWrapper<AiMcpMount>()
+ .eq(AiMcpMount::getId, mountId)
+ .set(AiMcpMount::getHealthStatus, healthStatus)
+ .set(AiMcpMount::getLastTestTime, new Date())
+ .set(AiMcpMount::getLastTestMessage, message)
+ .set(AiMcpMount::getLastInitElapsedMs, initElapsedMs));
+ }
+
+ private AiMcpConnectivityTestDto buildConnectivityDto(AiMcpMount mount, String message, Long initElapsedMs, Integer toolCount) {
+ return AiMcpConnectivityTestDto.builder()
+ .mountId(mount.getId())
+ .mountName(mount.getName())
+ .healthStatus(mount.getHealthStatus())
+ .message(message)
+ .initElapsedMs(initElapsedMs)
+ .toolCount(toolCount)
+ .testedAt(mount.getLastTestTime$())
+ .build();
+ }
+}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/mcp/BuiltinMcpToolCatalogProvider.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/mcp/BuiltinMcpToolCatalogProvider.java
new file mode 100644
index 0000000..50d8836
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/mcp/BuiltinMcpToolCatalogProvider.java
@@ -0,0 +1,79 @@
+package com.vincent.rsf.server.ai.service.impl.mcp;
+
+import com.vincent.rsf.framework.exception.CoolException;
+import com.vincent.rsf.server.ai.config.AiDefaults;
+import com.vincent.rsf.server.ai.dto.AiMcpToolPreviewDto;
+import org.springframework.stereotype.Component;
+
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+@Component
+public class BuiltinMcpToolCatalogProvider {
+
+ public List<String> supportedBuiltinCodes() {
+ return List.of(AiDefaults.MCP_BUILTIN_RSF_WMS);
+ }
+
+ public Map<String, AiMcpToolPreviewDto> getCatalog(String builtinCode) {
+ if (AiDefaults.MCP_BUILTIN_RSF_WMS.equals(builtinCode)) {
+ Map<String, AiMcpToolPreviewDto> catalog = new LinkedHashMap<>();
+ catalog.put("rsf_query_available_inventory", buildCatalogItem(
+ "rsf_query_available_inventory",
+ "搴撳瓨鏌ヨ",
+ "鏌ヨ鎸囧畾鐗╂枡褰撳墠鍙敤浜庡嚭搴撶殑搴撳瓨鏄庣粏銆�",
+ "蹇呴』鎻愪緵鐗╂枡缂栫爜鎴栫墿鏂欏悕绉帮紝骞朵笖鏈�澶氳繑鍥� 50 鏉″簱瀛樿褰曘��",
+ List.of("鏌ヨ鐗╂枡 MAT001 褰撳墠鍙嚭搴撳簱瀛�", "鎸夌墿鏂欏悕绉版煡璇㈡墭鐩樺簱瀛樻槑缁�")
+ ));
+ catalog.put("rsf_query_station_list", buildCatalogItem(
+ "rsf_query_station_list",
+ "搴撳瓨鏌ヨ",
+ "鏌ヨ鎸囧畾浣滀笟绫诲瀷鍙敤鐨勮澶囩珯鐐广��",
+ "蹇呴』鎻愪緵绔欑偣绫诲瀷鍒楄〃锛岀被鍨嬫暟閲忔渶澶� 10 涓紝鏈�澶氳繑鍥� 50 涓珯鐐广��",
+ List.of("鏌ヨ鍏ュ簱鍜屽嚭搴撲綔涓氬彲鐢ㄧ珯鐐�", "鍒楀嚭 AGV_PICK 绫诲瀷鐨勪綔涓氱珯鐐�")
+ ));
+ catalog.put("rsf_query_task", buildCatalogItem(
+ "rsf_query_task",
+ "浠诲姟鏌ヨ",
+ "鎸変换鍔″彿銆佺姸鎬併�佺被鍨嬫垨绔欑偣鏉′欢鏌ヨ浠诲姟锛涙敮鎸佷粠鑷劧璇█涓嚜鍔ㄦ彁鍙栦换鍔″彿锛岀簿纭懡涓椂杩斿洖浠诲姟鏄庣粏銆�",
+ "杩囨护鏉′欢鍧囦负鍙�夛紝涓嶄紶杩囨护鏉′欢鏃堕粯璁よ繑鍥炴渶杩戜换鍔★紝鏈�澶氳繑鍥� 50 鏉¤褰曘��",
+ List.of("鏌ヨ鏈�杩� 10 鏉′换鍔�", "鏌ヨ浠诲姟鍙� TASK24001 鐨勮鎯�")
+ ));
+ catalog.put("rsf_query_warehouses", buildCatalogItem(
+ "rsf_query_warehouses",
+ "鍩虹璧勬枡",
+ "鏌ヨ浠撳簱鍩虹淇℃伅銆�",
+ "鑷冲皯鎻愪緵浠撳簱缂栫爜鎴栧悕绉帮紝鏈�澶氳繑鍥� 50 鏉′粨搴撹褰曘��",
+ List.of("鏌ヨ缂栫爜鍖呭惈 WH 鐨勪粨搴�", "鎸変粨搴撳悕绉版煡璇粨搴撳湴鍧�")
+ ));
+ catalog.put("rsf_query_bas_stations", buildCatalogItem(
+ "rsf_query_bas_stations",
+ "鍩虹璧勬枡",
+ "鏌ヨ鍩虹绔欑偣淇℃伅銆�",
+ "鑷冲皯鎻愪緵绔欑偣缂栧彿銆佺珯鐐瑰悕绉版垨浣跨敤鐘舵�佷箣涓�锛屾渶澶氳繑鍥� 50 鏉$珯鐐硅褰曘��",
+ List.of("鏌ヨ浣跨敤涓殑鍩虹绔欑偣", "鎸夌珯鐐圭紪鍙锋煡璇㈠熀纭�绔欑偣")
+ ));
+ catalog.put("rsf_query_dict_data", buildCatalogItem(
+ "rsf_query_dict_data",
+ "鍩虹璧勬枡",
+ "鏌ヨ鎸囧畾瀛楀吀绫诲瀷涓嬬殑瀛楀吀鏁版嵁銆�",
+ "蹇呴』鎻愪緵瀛楀吀绫诲瀷缂栫爜锛屾渶澶氳繑鍥� 100 鏉″瓧鍏歌褰曘��",
+ List.of("鏌ヨ task_status 瀛楀吀", "鎸夊瓧鍏告爣绛捐繃婊� task_type 瀛楀吀鏁版嵁")
+ ));
+ return catalog;
+ }
+ throw new CoolException("涓嶆敮鎸佺殑鍐呯疆 MCP 缂栫爜: " + builtinCode);
+ }
+
+ private AiMcpToolPreviewDto buildCatalogItem(String name, String toolGroup, String toolPurpose,
+ String queryBoundary, List<String> exampleQuestions) {
+ return AiMcpToolPreviewDto.builder()
+ .name(name)
+ .toolGroup(toolGroup)
+ .toolPurpose(toolPurpose)
+ .queryBoundary(queryBoundary)
+ .exampleQuestions(exampleQuestions)
+ .build();
+ }
+}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/mcp/McpClientFactory.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/mcp/McpClientFactory.java
new file mode 100644
index 0000000..15f8749
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/mcp/McpClientFactory.java
@@ -0,0 +1,104 @@
+package com.vincent.rsf.server.ai.service.impl.mcp;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.vincent.rsf.framework.exception.CoolException;
+import com.vincent.rsf.server.ai.config.AiDefaults;
+import com.vincent.rsf.server.ai.entity.AiMcpMount;
+import io.modelcontextprotocol.client.McpClient;
+import io.modelcontextprotocol.client.McpSyncClient;
+import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
+import io.modelcontextprotocol.client.transport.ServerParameters;
+import io.modelcontextprotocol.client.transport.StdioClientTransport;
+import io.modelcontextprotocol.json.jackson.JacksonMcpJsonMapper;
+import io.modelcontextprotocol.spec.McpSchema;
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.stereotype.Component;
+import org.springframework.util.StringUtils;
+
+import java.time.Duration;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+@Slf4j
+@Component
+@RequiredArgsConstructor
+public class McpClientFactory {
+
+ private final ObjectMapper objectMapper;
+
+ public McpSyncClient createClient(AiMcpMount mount) {
+ Duration timeout = Duration.ofMillis(mount.getRequestTimeoutMs() == null
+ ? AiDefaults.DEFAULT_TIMEOUT_MS
+ : mount.getRequestTimeoutMs());
+ JacksonMcpJsonMapper jsonMapper = new JacksonMcpJsonMapper(objectMapper);
+ if (AiDefaults.MCP_TRANSPORT_STDIO.equals(mount.getTransportType())) {
+ ServerParameters.Builder parametersBuilder = ServerParameters.builder(mount.getCommand());
+ List<String> args = readStringList(mount.getArgsJson());
+ if (!args.isEmpty()) {
+ parametersBuilder.args(args);
+ }
+ Map<String, String> env = readStringMap(mount.getEnvJson());
+ if (!env.isEmpty()) {
+ parametersBuilder.env(env);
+ }
+ StdioClientTransport transport = new StdioClientTransport(parametersBuilder.build(), jsonMapper);
+ transport.setStdErrorHandler(message -> log.warn("MCP STDIO stderr [{}]: {}", mount.getName(), message));
+ return McpClient.sync(transport)
+ .requestTimeout(timeout)
+ .initializationTimeout(timeout)
+ .clientInfo(new McpSchema.Implementation("rsf-ai-client", "RSF AI Client", "1.0.0"))
+ .build();
+ }
+ if (!AiDefaults.MCP_TRANSPORT_SSE_HTTP.equals(mount.getTransportType())) {
+ throw new CoolException("涓嶆敮鎸佺殑 MCP 浼犺緭绫诲瀷: " + mount.getTransportType());
+ }
+
+ if (!StringUtils.hasText(mount.getServerUrl())) {
+ throw new CoolException("MCP 鏈嶅姟鍦板潃涓嶈兘涓虹┖");
+ }
+ HttpClientSseClientTransport.Builder transportBuilder = HttpClientSseClientTransport.builder(mount.getServerUrl())
+ .jsonMapper(jsonMapper)
+ .connectTimeout(timeout);
+ if (StringUtils.hasText(mount.getEndpoint())) {
+ transportBuilder.sseEndpoint(mount.getEndpoint());
+ }
+ Map<String, String> headers = readStringMap(mount.getHeadersJson());
+ if (!headers.isEmpty()) {
+ transportBuilder.customizeRequest(builder -> headers.forEach(builder::header));
+ }
+ return McpClient.sync(transportBuilder.build())
+ .requestTimeout(timeout)
+ .initializationTimeout(timeout)
+ .clientInfo(new McpSchema.Implementation("rsf-ai-client", "RSF AI Client", "1.0.0"))
+ .build();
+ }
+
+ private List<String> readStringList(String json) {
+ if (!StringUtils.hasText(json)) {
+ return Collections.emptyList();
+ }
+ try {
+ return objectMapper.readValue(json, new TypeReference<List<String>>() {
+ });
+ } catch (Exception e) {
+ throw new CoolException("瑙f瀽 MCP 鍒楄〃閰嶇疆澶辫触: " + e.getMessage());
+ }
+ }
+
+ private Map<String, String> readStringMap(String json) {
+ if (!StringUtils.hasText(json)) {
+ return Collections.emptyMap();
+ }
+ try {
+ Map<String, String> result = objectMapper.readValue(json, new TypeReference<LinkedHashMap<String, String>>() {
+ });
+ return result == null ? Collections.emptyMap() : result;
+ } catch (Exception e) {
+ throw new CoolException("瑙f瀽 MCP Map 閰嶇疆澶辫触: " + e.getMessage());
+ }
+ }
+}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiCachedToolResult.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiCachedToolResult.java
new file mode 100644
index 0000000..e698510
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiCachedToolResult.java
@@ -0,0 +1,19 @@
+package com.vincent.rsf.server.ai.store;
+
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+
+@Data
+@Builder
+@NoArgsConstructor
+@AllArgsConstructor
+public class AiCachedToolResult {
+
+ private boolean success;
+
+ private String output;
+
+ private String errorMessage;
+}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiChatRateLimiter.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiChatRateLimiter.java
new file mode 100644
index 0000000..177218d
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiChatRateLimiter.java
@@ -0,0 +1,36 @@
+package com.vincent.rsf.server.ai.store;
+
+import com.vincent.rsf.server.ai.config.AiDefaults;
+import com.vincent.rsf.server.ai.store.support.AiRedisExecutor;
+import com.vincent.rsf.server.ai.store.support.AiRedisKeys;
+import lombok.RequiredArgsConstructor;
+import org.springframework.stereotype.Component;
+
+import java.time.Instant;
+import java.util.UUID;
+
+@Component
+@RequiredArgsConstructor
+public class AiChatRateLimiter {
+
+ private final AiRedisExecutor aiRedisExecutor;
+ private final AiRedisKeys aiRedisKeys;
+
+ public boolean allowChatRequest(Long tenantId, Long userId, String promptCode) {
+ String key = aiRedisKeys.buildRateLimitKey(tenantId, userId, promptCode);
+ long now = Instant.now().toEpochMilli();
+ long windowStart = now - (AiDefaults.CHAT_RATE_LIMIT_WINDOW_SECONDS * 1000L);
+ Boolean allowed = aiRedisExecutor.execute(jedis -> {
+ jedis.zremrangeByScore(key, 0, windowStart);
+ long count = jedis.zcard(key);
+ if (count >= AiDefaults.CHAT_RATE_LIMIT_MAX_REQUESTS) {
+ jedis.expire(key, AiDefaults.CHAT_RATE_LIMIT_WINDOW_SECONDS);
+ return Boolean.FALSE;
+ }
+ jedis.zadd(key, now, now + ":" + UUID.randomUUID());
+ jedis.expire(key, AiDefaults.CHAT_RATE_LIMIT_WINDOW_SECONDS);
+ return Boolean.TRUE;
+ });
+ return Boolean.TRUE.equals(allowed);
+ }
+}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiConfigCacheStore.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiConfigCacheStore.java
new file mode 100644
index 0000000..f2093d2
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiConfigCacheStore.java
@@ -0,0 +1,32 @@
+package com.vincent.rsf.server.ai.store;
+
+import com.vincent.rsf.server.ai.config.AiDefaults;
+import com.vincent.rsf.server.ai.dto.AiResolvedConfig;
+import com.vincent.rsf.server.ai.store.support.AiRedisExecutor;
+import com.vincent.rsf.server.ai.store.support.AiRedisIndexSupport;
+import com.vincent.rsf.server.ai.store.support.AiRedisKeys;
+import lombok.RequiredArgsConstructor;
+import org.springframework.stereotype.Component;
+
+@Component
+@RequiredArgsConstructor
+public class AiConfigCacheStore {
+
+ private final AiRedisExecutor aiRedisExecutor;
+ private final AiRedisIndexSupport aiRedisIndexSupport;
+ private final AiRedisKeys aiRedisKeys;
+
+ public AiResolvedConfig getResolvedConfig(Long tenantId, String promptCode, Long aiParamId) {
+ return aiRedisExecutor.readJson(aiRedisKeys.buildConfigKey(tenantId, promptCode, aiParamId), AiResolvedConfig.class);
+ }
+
+ public void cacheResolvedConfig(Long tenantId, String promptCode, Long aiParamId, AiResolvedConfig config) {
+ String key = aiRedisKeys.buildConfigKey(tenantId, promptCode, aiParamId);
+ aiRedisExecutor.writeJson(key, config, AiDefaults.CONFIG_CACHE_TTL_SECONDS);
+ aiRedisIndexSupport.remember(aiRedisKeys.buildConfigIndexKey(tenantId), key);
+ }
+
+ public void evictTenantConfigCaches(Long tenantId) {
+ aiRedisIndexSupport.deleteTrackedKeys(aiRedisKeys.buildConfigIndexKey(tenantId));
+ }
+}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiConversationCacheStore.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiConversationCacheStore.java
new file mode 100644
index 0000000..c415c87
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiConversationCacheStore.java
@@ -0,0 +1,67 @@
+package com.vincent.rsf.server.ai.store;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.vincent.rsf.server.ai.config.AiDefaults;
+import com.vincent.rsf.server.ai.dto.AiChatMemoryDto;
+import com.vincent.rsf.server.ai.dto.AiChatRuntimeDto;
+import com.vincent.rsf.server.ai.dto.AiChatSessionDto;
+import com.vincent.rsf.server.ai.store.support.AiRedisExecutor;
+import com.vincent.rsf.server.ai.store.support.AiRedisIndexSupport;
+import com.vincent.rsf.server.ai.store.support.AiRedisKeys;
+import lombok.RequiredArgsConstructor;
+import org.springframework.stereotype.Component;
+
+import java.util.List;
+
+@Component
+@RequiredArgsConstructor
+public class AiConversationCacheStore {
+
+ private final AiRedisExecutor aiRedisExecutor;
+ private final AiRedisIndexSupport aiRedisIndexSupport;
+ private final AiRedisKeys aiRedisKeys;
+
+ public AiChatRuntimeDto getRuntime(Long tenantId, Long userId, String promptCode, Long sessionId, Long aiParamId) {
+ return aiRedisExecutor.readJson(aiRedisKeys.buildRuntimeKey(tenantId, userId, promptCode, sessionId, aiParamId), AiChatRuntimeDto.class);
+ }
+
+ public void cacheRuntime(Long tenantId, Long userId, String promptCode, Long sessionId, Long aiParamId, AiChatRuntimeDto runtime) {
+ String key = aiRedisKeys.buildRuntimeKey(tenantId, userId, promptCode, sessionId, aiParamId);
+ aiRedisExecutor.writeJson(key, runtime, AiDefaults.RUNTIME_CACHE_TTL_SECONDS);
+ rememberConversationKey(tenantId, userId, key);
+ aiRedisIndexSupport.remember(aiRedisKeys.buildTenantRuntimeIndexKey(tenantId), key);
+ }
+
+ public AiChatMemoryDto getMemory(Long tenantId, Long userId, String promptCode, Long sessionId) {
+ return aiRedisExecutor.readJson(aiRedisKeys.buildMemoryKey(tenantId, userId, promptCode, sessionId), AiChatMemoryDto.class);
+ }
+
+ public void cacheMemory(Long tenantId, Long userId, String promptCode, Long sessionId, AiChatMemoryDto memory) {
+ String key = aiRedisKeys.buildMemoryKey(tenantId, userId, promptCode, sessionId);
+ aiRedisExecutor.writeJson(key, memory, AiDefaults.MEMORY_CACHE_TTL_SECONDS);
+ rememberConversationKey(tenantId, userId, key);
+ }
+
+ public List<AiChatSessionDto> getSessionList(Long tenantId, Long userId, String promptCode, String keyword) {
+ return aiRedisExecutor.readJson(aiRedisKeys.buildSessionsKey(tenantId, userId, promptCode, keyword), new TypeReference<List<AiChatSessionDto>>() {
+ });
+ }
+
+ public void cacheSessionList(Long tenantId, Long userId, String promptCode, String keyword, List<AiChatSessionDto> sessions) {
+ String key = aiRedisKeys.buildSessionsKey(tenantId, userId, promptCode, keyword);
+ aiRedisExecutor.writeJson(key, sessions, AiDefaults.SESSION_LIST_CACHE_TTL_SECONDS);
+ rememberConversationKey(tenantId, userId, key);
+ }
+
+ public void evictUserConversationCaches(Long tenantId, Long userId) {
+ aiRedisIndexSupport.deleteTrackedKeys(aiRedisKeys.buildConversationIndexKey(tenantId, userId));
+ }
+
+ public void evictTenantRuntimeCaches(Long tenantId) {
+ aiRedisIndexSupport.deleteTrackedKeys(aiRedisKeys.buildTenantRuntimeIndexKey(tenantId));
+ }
+
+ private void rememberConversationKey(Long tenantId, Long userId, String key) {
+ aiRedisIndexSupport.remember(aiRedisKeys.buildConversationIndexKey(tenantId, userId), key);
+ }
+}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiMcpCacheStore.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiMcpCacheStore.java
new file mode 100644
index 0000000..8a0d509
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiMcpCacheStore.java
@@ -0,0 +1,45 @@
+package com.vincent.rsf.server.ai.store;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.vincent.rsf.server.ai.config.AiDefaults;
+import com.vincent.rsf.server.ai.dto.AiMcpConnectivityTestDto;
+import com.vincent.rsf.server.ai.dto.AiMcpToolPreviewDto;
+import com.vincent.rsf.server.ai.store.support.AiRedisExecutor;
+import com.vincent.rsf.server.ai.store.support.AiRedisKeys;
+import lombok.RequiredArgsConstructor;
+import org.springframework.stereotype.Component;
+
+import java.util.List;
+
+@Component
+@RequiredArgsConstructor
+public class AiMcpCacheStore {
+
+ private final AiRedisExecutor aiRedisExecutor;
+ private final AiRedisKeys aiRedisKeys;
+
+ public List<AiMcpToolPreviewDto> getToolPreview(Long tenantId, Long mountId) {
+ return aiRedisExecutor.readJson(aiRedisKeys.buildMcpPreviewKey(tenantId, mountId), new TypeReference<List<AiMcpToolPreviewDto>>() {
+ });
+ }
+
+ public void cacheToolPreview(Long tenantId, Long mountId, List<AiMcpToolPreviewDto> tools) {
+ aiRedisExecutor.writeJson(aiRedisKeys.buildMcpPreviewKey(tenantId, mountId), tools, AiDefaults.MCP_PREVIEW_CACHE_TTL_SECONDS);
+ }
+
+ public AiMcpConnectivityTestDto getConnectivity(Long tenantId, Long mountId) {
+ return aiRedisExecutor.readJson(aiRedisKeys.buildMcpHealthKey(tenantId, mountId), AiMcpConnectivityTestDto.class);
+ }
+
+ public void cacheConnectivity(Long tenantId, Long mountId, AiMcpConnectivityTestDto connectivity) {
+ aiRedisExecutor.writeJson(aiRedisKeys.buildMcpHealthKey(tenantId, mountId), connectivity, AiDefaults.MCP_HEALTH_CACHE_TTL_SECONDS);
+ }
+
+ public void evictMcpMountCaches(Long tenantId, Long mountId) {
+ if (mountId == null) {
+ return;
+ }
+ aiRedisExecutor.delete(aiRedisKeys.buildMcpPreviewKey(tenantId, mountId));
+ aiRedisExecutor.delete(aiRedisKeys.buildMcpHealthKey(tenantId, mountId));
+ }
+}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiObserveStatsStore.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiObserveStatsStore.java
new file mode 100644
index 0000000..f6c095a
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiObserveStatsStore.java
@@ -0,0 +1,163 @@
+package com.vincent.rsf.server.ai.store;
+
+import com.vincent.rsf.server.ai.dto.AiObserveStatsDto;
+import com.vincent.rsf.server.ai.store.support.AiRedisExecutor;
+import com.vincent.rsf.server.ai.store.support.AiRedisKeys;
+import lombok.RequiredArgsConstructor;
+import org.springframework.stereotype.Component;
+import org.springframework.util.StringUtils;
+
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.function.Supplier;
+
+@Component
+@RequiredArgsConstructor
+public class AiObserveStatsStore {
+
+ private static final String FIELD_CALL_COUNT = "callCount";
+ private static final String FIELD_SUCCESS_COUNT = "successCount";
+ private static final String FIELD_FAILURE_COUNT = "failureCount";
+ private static final String FIELD_ELAPSED_SUM = "elapsedSum";
+ private static final String FIELD_ELAPSED_COUNT = "elapsedCount";
+ private static final String FIELD_FIRST_TOKEN_SUM = "firstTokenSum";
+ private static final String FIELD_FIRST_TOKEN_COUNT = "firstTokenCount";
+ private static final String FIELD_TOTAL_TOKENS_SUM = "totalTokensSum";
+ private static final String FIELD_TOTAL_TOKENS_COUNT = "totalTokensCount";
+ private static final String FIELD_TOOL_CALL_COUNT = "toolCallCount";
+ private static final String FIELD_TOOL_SUCCESS_COUNT = "toolSuccessCount";
+ private static final String FIELD_TOOL_FAILURE_COUNT = "toolFailureCount";
+
+ private final AiRedisExecutor aiRedisExecutor;
+ private final AiRedisKeys aiRedisKeys;
+
+ public void recordObserveCallStarted(Long tenantId) {
+ aiRedisExecutor.executeVoid(jedis -> jedis.hincrBy(aiRedisKeys.buildObserveStatsKey(tenantId), FIELD_CALL_COUNT, 1));
+ }
+
+ public void recordObserveCallFinished(Long tenantId, String status, Long elapsedMs, Long firstTokenLatencyMs, Integer totalTokens) {
+ aiRedisExecutor.executeVoid(jedis -> {
+ String key = aiRedisKeys.buildObserveStatsKey(tenantId);
+ if ("COMPLETED".equals(status)) {
+ jedis.hincrBy(key, FIELD_SUCCESS_COUNT, 1);
+ } else if ("FAILED".equals(status)) {
+ jedis.hincrBy(key, FIELD_FAILURE_COUNT, 1);
+ }
+ if (elapsedMs != null) {
+ jedis.hincrBy(key, FIELD_ELAPSED_SUM, elapsedMs);
+ jedis.hincrBy(key, FIELD_ELAPSED_COUNT, 1);
+ }
+ if (firstTokenLatencyMs != null) {
+ jedis.hincrBy(key, FIELD_FIRST_TOKEN_SUM, firstTokenLatencyMs);
+ jedis.hincrBy(key, FIELD_FIRST_TOKEN_COUNT, 1);
+ }
+ if (totalTokens != null) {
+ jedis.hincrBy(key, FIELD_TOTAL_TOKENS_SUM, totalTokens.longValue());
+ jedis.hincrBy(key, FIELD_TOTAL_TOKENS_COUNT, 1);
+ }
+ });
+ }
+
+ public void recordObserveToolCall(Long tenantId, String toolName, String status) {
+ aiRedisExecutor.executeVoid(jedis -> {
+ String key = aiRedisKeys.buildObserveStatsKey(tenantId);
+ jedis.hincrBy(key, FIELD_TOOL_CALL_COUNT, 1);
+ if ("COMPLETED".equals(status)) {
+ jedis.hincrBy(key, FIELD_TOOL_SUCCESS_COUNT, 1);
+ } else if ("FAILED".equals(status)) {
+ jedis.hincrBy(key, FIELD_TOOL_FAILURE_COUNT, 1);
+ }
+ if (StringUtils.hasText(toolName)) {
+ jedis.zincrby(aiRedisKeys.buildToolRankKey(tenantId), 1D, toolName);
+ if ("FAILED".equals(status)) {
+ jedis.zincrby(aiRedisKeys.buildToolFailRankKey(tenantId), 1D, toolName);
+ }
+ }
+ });
+ }
+
+ public AiObserveStatsDto getObserveStats(Long tenantId, Supplier<AiObserveStatsDto> fallbackLoader) {
+ AiObserveStatsDto cached = readObserveStats(tenantId);
+ if (cached != null) {
+ return cached;
+ }
+ AiObserveStatsDto snapshot = fallbackLoader.get();
+ if (snapshot != null) {
+ seedObserveStats(tenantId, snapshot);
+ }
+ return snapshot;
+ }
+
+ private AiObserveStatsDto readObserveStats(Long tenantId) {
+ Map<String, String> fields = aiRedisExecutor.execute(jedis -> {
+ String key = aiRedisKeys.buildObserveStatsKey(tenantId);
+ if (!jedis.exists(key)) {
+ return null;
+ }
+ return jedis.hgetAll(key);
+ });
+ if (fields == null || fields.isEmpty()) {
+ return null;
+ }
+ long callCount = parseLong(fields.get(FIELD_CALL_COUNT));
+ long successCount = parseLong(fields.get(FIELD_SUCCESS_COUNT));
+ long failureCount = parseLong(fields.get(FIELD_FAILURE_COUNT));
+ long elapsedSum = parseLong(fields.get(FIELD_ELAPSED_SUM));
+ long elapsedCount = parseLong(fields.get(FIELD_ELAPSED_COUNT));
+ long firstTokenSum = parseLong(fields.get(FIELD_FIRST_TOKEN_SUM));
+ long firstTokenCount = parseLong(fields.get(FIELD_FIRST_TOKEN_COUNT));
+ long totalTokensSum = parseLong(fields.get(FIELD_TOTAL_TOKENS_SUM));
+ long totalTokensCount = parseLong(fields.get(FIELD_TOTAL_TOKENS_COUNT));
+ long toolCallCount = parseLong(fields.get(FIELD_TOOL_CALL_COUNT));
+ long toolSuccessCount = parseLong(fields.get(FIELD_TOOL_SUCCESS_COUNT));
+ long toolFailureCount = parseLong(fields.get(FIELD_TOOL_FAILURE_COUNT));
+ return AiObserveStatsDto.builder()
+ .callCount(callCount)
+ .successCount(successCount)
+ .failureCount(failureCount)
+ .avgElapsedMs(elapsedCount == 0 ? 0L : elapsedSum / elapsedCount)
+ .avgFirstTokenLatencyMs(firstTokenCount == 0 ? 0L : firstTokenSum / firstTokenCount)
+ .totalTokens(totalTokensSum)
+ .avgTotalTokens(totalTokensCount == 0 ? 0L : totalTokensSum / totalTokensCount)
+ .toolCallCount(toolCallCount)
+ .toolSuccessCount(toolSuccessCount)
+ .toolFailureCount(toolFailureCount)
+ .toolSuccessRate(toolCallCount == 0 ? 0D : (toolSuccessCount * 100D) / toolCallCount)
+ .build();
+ }
+
+ private void seedObserveStats(Long tenantId, AiObserveStatsDto snapshot) {
+ aiRedisExecutor.executeVoid(jedis -> {
+ String key = aiRedisKeys.buildObserveStatsKey(tenantId);
+ Map<String, String> values = new LinkedHashMap<>();
+ values.put(FIELD_CALL_COUNT, String.valueOf(defaultLong(snapshot.getCallCount())));
+ values.put(FIELD_SUCCESS_COUNT, String.valueOf(defaultLong(snapshot.getSuccessCount())));
+ values.put(FIELD_FAILURE_COUNT, String.valueOf(defaultLong(snapshot.getFailureCount())));
+ values.put(FIELD_ELAPSED_SUM, String.valueOf(defaultLong(snapshot.getAvgElapsedMs()) * defaultLong(snapshot.getCallCount())));
+ values.put(FIELD_ELAPSED_COUNT, String.valueOf(defaultLong(snapshot.getCallCount())));
+ values.put(FIELD_FIRST_TOKEN_SUM, String.valueOf(defaultLong(snapshot.getAvgFirstTokenLatencyMs()) * defaultLong(snapshot.getCallCount())));
+ values.put(FIELD_FIRST_TOKEN_COUNT, String.valueOf(defaultLong(snapshot.getCallCount())));
+ values.put(FIELD_TOTAL_TOKENS_SUM, String.valueOf(defaultLong(snapshot.getTotalTokens())));
+ values.put(FIELD_TOTAL_TOKENS_COUNT, String.valueOf(defaultLong(snapshot.getCallCount())));
+ values.put(FIELD_TOOL_CALL_COUNT, String.valueOf(defaultLong(snapshot.getToolCallCount())));
+ values.put(FIELD_TOOL_SUCCESS_COUNT, String.valueOf(defaultLong(snapshot.getToolSuccessCount())));
+ values.put(FIELD_TOOL_FAILURE_COUNT, String.valueOf(defaultLong(snapshot.getToolFailureCount())));
+ jedis.hset(key, values);
+ });
+ }
+
+ private long parseLong(String source) {
+ if (!StringUtils.hasText(source)) {
+ return 0L;
+ }
+ try {
+ return Long.parseLong(source);
+ } catch (Exception e) {
+ return 0L;
+ }
+ }
+
+ private long defaultLong(Long value) {
+ return value == null ? 0L : value;
+ }
+}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiStreamStateStore.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiStreamStateStore.java
new file mode 100644
index 0000000..9901393
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiStreamStateStore.java
@@ -0,0 +1,62 @@
+package com.vincent.rsf.server.ai.store;
+
+import com.vincent.rsf.server.ai.config.AiDefaults;
+import com.vincent.rsf.server.ai.store.support.AiRedisExecutor;
+import com.vincent.rsf.server.ai.store.support.AiRedisKeys;
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import lombok.RequiredArgsConstructor;
+import org.springframework.stereotype.Component;
+import org.springframework.util.StringUtils;
+
+import java.time.Instant;
+
+@Component
+@RequiredArgsConstructor
+public class AiStreamStateStore {
+
+ private final AiRedisExecutor aiRedisExecutor;
+ private final AiRedisKeys aiRedisKeys;
+
+ public void markStreamState(String requestId, Long tenantId, Long userId, Long sessionId, String promptCode,
+ String status, String errorMessage) {
+ if (!StringUtils.hasText(requestId)) {
+ return;
+ }
+ aiRedisExecutor.writeJson(aiRedisKeys.buildStreamStateKey(tenantId, requestId), AiStreamState.builder()
+ .requestId(requestId)
+ .tenantId(tenantId)
+ .userId(userId)
+ .sessionId(sessionId)
+ .promptCode(promptCode)
+ .status(status)
+ .errorMessage(errorMessage)
+ .timestamp(Instant.now().toEpochMilli())
+ .build(), AiDefaults.STREAM_STATE_TTL_SECONDS);
+ }
+
+ @Data
+ @Builder
+ @NoArgsConstructor
+ @AllArgsConstructor
+ private static class AiStreamState {
+
+ private String requestId;
+
+ private Long tenantId;
+
+ private Long userId;
+
+ private Long sessionId;
+
+ private String promptCode;
+
+ private String status;
+
+ private String errorMessage;
+
+ private Long timestamp;
+ }
+}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiToolResultStore.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiToolResultStore.java
new file mode 100644
index 0000000..a29c958
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/AiToolResultStore.java
@@ -0,0 +1,30 @@
+package com.vincent.rsf.server.ai.store;
+
+import com.vincent.rsf.server.ai.config.AiDefaults;
+import com.vincent.rsf.server.ai.store.support.AiRedisExecutor;
+import com.vincent.rsf.server.ai.store.support.AiRedisKeys;
+import lombok.RequiredArgsConstructor;
+import org.springframework.stereotype.Component;
+
+@Component
+@RequiredArgsConstructor
+public class AiToolResultStore {
+
+ private final AiRedisExecutor aiRedisExecutor;
+ private final AiRedisKeys aiRedisKeys;
+
+ public AiCachedToolResult getToolResult(Long tenantId, String requestId, String toolName, String toolInput) {
+ return aiRedisExecutor.readJson(aiRedisKeys.buildToolResultKey(tenantId, requestId, toolName, toolInput), AiCachedToolResult.class);
+ }
+
+ public void cacheToolResult(Long tenantId, String requestId, String toolName, String toolInput,
+ boolean success, String output, String errorMessage) {
+ aiRedisExecutor.writeJson(aiRedisKeys.buildToolResultKey(tenantId, requestId, toolName, toolInput),
+ AiCachedToolResult.builder()
+ .success(success)
+ .output(output)
+ .errorMessage(errorMessage)
+ .build(),
+ AiDefaults.TOOL_RESULT_CACHE_TTL_SECONDS);
+ }
+}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/support/AiRedisExecutor.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/support/AiRedisExecutor.java
new file mode 100644
index 0000000..287ddfe
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/support/AiRedisExecutor.java
@@ -0,0 +1,94 @@
+package com.vincent.rsf.server.ai.store.support;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.vincent.rsf.server.common.service.RedisService;
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.stereotype.Component;
+import org.springframework.util.StringUtils;
+import redis.clients.jedis.Jedis;
+
+import java.util.function.Consumer;
+import java.util.function.Function;
+
+@Slf4j
+@Component
+@RequiredArgsConstructor
+public class AiRedisExecutor {
+
+ private final RedisService redisService;
+ private final ObjectMapper objectMapper;
+
+ public <T> T readJson(String key, Class<T> type) {
+ return readJson(key, value -> objectMapper.readValue(value, type));
+ }
+
+ public <T> T readJson(String key, TypeReference<T> typeReference) {
+ return readJson(key, value -> objectMapper.readValue(value, typeReference));
+ }
+
+ public <T> T readJson(String key, JsonReader<T> reader) {
+ return execute(jedis -> {
+ String value = jedis.get(key);
+ if (!StringUtils.hasText(value)) {
+ return null;
+ }
+ try {
+ return reader.read(value);
+ } catch (Exception e) {
+ log.warn("AI redis cache deserialize failed, key={}, message={}", key, e.getMessage());
+ jedis.del(key);
+ return null;
+ }
+ });
+ }
+
+ public void writeJson(String key, Object value, int ttlSeconds) {
+ if (value == null) {
+ delete(key);
+ return;
+ }
+ executeVoid(jedis -> {
+ try {
+ jedis.setex(key, ttlSeconds, objectMapper.writeValueAsString(value));
+ } catch (Exception e) {
+ log.warn("AI redis cache serialize failed, key={}, message={}", key, e.getMessage());
+ }
+ });
+ }
+
+ public void delete(String key) {
+ executeVoid(jedis -> jedis.del(key));
+ }
+
+ public <T> T execute(Function<Jedis, T> action) {
+ Jedis jedis = redisService.getJedis();
+ if (jedis == null) {
+ return null;
+ }
+ try (jedis) {
+ return action.apply(jedis);
+ } catch (Exception e) {
+ log.warn("AI redis operation skipped, message={}", e.getMessage());
+ return null;
+ }
+ }
+
+ public void executeVoid(Consumer<Jedis> action) {
+ Jedis jedis = redisService.getJedis();
+ if (jedis == null) {
+ return;
+ }
+ try (jedis) {
+ action.accept(jedis);
+ } catch (Exception e) {
+ log.warn("AI redis operation skipped, message={}", e.getMessage());
+ }
+ }
+
+ @FunctionalInterface
+ public interface JsonReader<T> {
+ T read(String value) throws Exception;
+ }
+}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/support/AiRedisIndexSupport.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/support/AiRedisIndexSupport.java
new file mode 100644
index 0000000..ee08698
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/support/AiRedisIndexSupport.java
@@ -0,0 +1,40 @@
+package com.vincent.rsf.server.ai.store.support;
+
+import lombok.RequiredArgsConstructor;
+import org.springframework.stereotype.Component;
+
+import java.util.Set;
+
+@Component
+@RequiredArgsConstructor
+public class AiRedisIndexSupport {
+
+ private final AiRedisExecutor aiRedisExecutor;
+
+ public void remember(String indexKey, String cacheKey) {
+ if (indexKey == null || cacheKey == null) {
+ return;
+ }
+ aiRedisExecutor.executeVoid(jedis -> jedis.sadd(indexKey, cacheKey));
+ }
+
+ public void forget(String indexKey, String cacheKey) {
+ if (indexKey == null || cacheKey == null) {
+ return;
+ }
+ aiRedisExecutor.executeVoid(jedis -> jedis.srem(indexKey, cacheKey));
+ }
+
+ public void deleteTrackedKeys(String indexKey) {
+ if (indexKey == null) {
+ return;
+ }
+ aiRedisExecutor.executeVoid(jedis -> {
+ Set<String> keys = jedis.smembers(indexKey);
+ if (keys != null && !keys.isEmpty()) {
+ jedis.del(keys.toArray(new String[0]));
+ }
+ jedis.del(indexKey);
+ });
+ }
+}
diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/support/AiRedisKeys.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/support/AiRedisKeys.java
new file mode 100644
index 0000000..9f5bbd7
--- /dev/null
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/store/support/AiRedisKeys.java
@@ -0,0 +1,117 @@
+package com.vincent.rsf.server.ai.store.support;
+
+import org.springframework.stereotype.Component;
+import org.springframework.util.StringUtils;
+
+import java.net.URLEncoder;
+import java.nio.charset.StandardCharsets;
+import java.security.MessageDigest;
+
+@Component
+public class AiRedisKeys {
+
+ private static final String CONFIG_KEY_PREFIX = "AI:CONFIG:";
+ private static final String RUNTIME_KEY_PREFIX = "AI:RUNTIME:";
+ private static final String MEMORY_KEY_PREFIX = "AI:MEMORY:";
+ private static final String SESSIONS_KEY_PREFIX = "AI:SESSIONS:";
+ private static final String MCP_PREVIEW_KEY_PREFIX = "AI:MCP:PREVIEW:";
+ private static final String MCP_HEALTH_KEY_PREFIX = "AI:MCP:HEALTH:";
+ private static final String STREAM_STATE_KEY_PREFIX = "AI:STREAM:";
+ private static final String TOOL_RESULT_KEY_PREFIX = "AI:TOOL:RESULT:";
+ private static final String RATE_LIMIT_KEY_PREFIX = "AI:RATE:";
+ private static final String OBSERVE_STATS_KEY_PREFIX = "AI:OBSERVE:STATS:";
+ private static final String OBSERVE_TOOL_RANK_KEY_PREFIX = "AI:OBSERVE:TOOL:RANK:";
+ private static final String OBSERVE_TOOL_FAIL_RANK_KEY_PREFIX = "AI:OBSERVE:TOOL:FAIL:RANK:";
+ private static final String CONFIG_INDEX_PREFIX = "AI:IDX:CONFIG:";
+ private static final String CONVERSATION_INDEX_PREFIX = "AI:IDX:CONVERSATION:";
+ private static final String TENANT_RUNTIME_INDEX_PREFIX = "AI:IDX:RUNTIME:TENANT:";
+
+ public String buildConfigKey(Long tenantId, String promptCode, Long aiParamId) {
+ return CONFIG_KEY_PREFIX + tenantId + ":" + safeToken(promptCode) + ":" + aiParamToken(aiParamId);
+ }
+
+ public String buildRuntimeKey(Long tenantId, Long userId, String promptCode, Long sessionId, Long aiParamId) {
+ return RUNTIME_KEY_PREFIX + tenantId + ":" + userId + ":" + safeToken(promptCode) + ":" + sessionToken(sessionId) + ":" + aiParamToken(aiParamId);
+ }
+
+ public String buildMemoryKey(Long tenantId, Long userId, String promptCode, Long sessionId) {
+ return MEMORY_KEY_PREFIX + tenantId + ":" + userId + ":" + safeToken(promptCode) + ":" + sessionToken(sessionId);
+ }
+
+ public String buildSessionsKey(Long tenantId, Long userId, String promptCode, String keyword) {
+ return SESSIONS_KEY_PREFIX + tenantId + ":" + userId + ":" + safeToken(promptCode) + ":" + safeToken(keyword);
+ }
+
+ public String buildMcpPreviewKey(Long tenantId, Long mountId) {
+ return MCP_PREVIEW_KEY_PREFIX + tenantId + ":" + mountId;
+ }
+
+ public String buildMcpHealthKey(Long tenantId, Long mountId) {
+ return MCP_HEALTH_KEY_PREFIX + tenantId + ":" + mountId;
+ }
+
+ public String buildStreamStateKey(Long tenantId, String requestId) {
+ return STREAM_STATE_KEY_PREFIX + tenantId + ":" + safeToken(requestId);
+ }
+
+ public String buildToolResultKey(Long tenantId, String requestId, String toolName, String toolInput) {
+ return TOOL_RESULT_KEY_PREFIX + tenantId + ":" + safeToken(requestId) + ":" + safeToken(toolName) + ":" + digest(toolInput);
+ }
+
+ public String buildRateLimitKey(Long tenantId, Long userId, String promptCode) {
+ return RATE_LIMIT_KEY_PREFIX + tenantId + ":" + userId + ":" + safeToken(promptCode);
+ }
+
+ public String buildObserveStatsKey(Long tenantId) {
+ return OBSERVE_STATS_KEY_PREFIX + tenantId;
+ }
+
+ public String buildToolRankKey(Long tenantId) {
+ return OBSERVE_TOOL_RANK_KEY_PREFIX + tenantId;
+ }
+
+ public String buildToolFailRankKey(Long tenantId) {
+ return OBSERVE_TOOL_FAIL_RANK_KEY_PREFIX + tenantId;
+ }
+
+ public String buildConfigIndexKey(Long tenantId) {
+ return CONFIG_INDEX_PREFIX + tenantId;
+ }
+
+ public String buildConversationIndexKey(Long tenantId, Long userId) {
+ return CONVERSATION_INDEX_PREFIX + tenantId + ":" + userId;
+ }
+
+ public String buildTenantRuntimeIndexKey(Long tenantId) {
+ return TENANT_RUNTIME_INDEX_PREFIX + tenantId;
+ }
+
+ public String sessionToken(Long sessionId) {
+ return sessionId == null ? "LATEST" : String.valueOf(sessionId);
+ }
+
+ public String aiParamToken(Long aiParamId) {
+ return aiParamId == null ? "DEFAULT" : String.valueOf(aiParamId);
+ }
+
+ public String safeToken(String source) {
+ if (!StringUtils.hasText(source)) {
+ return "_";
+ }
+ return URLEncoder.encode(source.trim(), StandardCharsets.UTF_8);
+ }
+
+ public String digest(String source) {
+ try {
+ MessageDigest messageDigest = MessageDigest.getInstance("SHA-256");
+ byte[] bytes = messageDigest.digest((source == null ? "" : source).getBytes(StandardCharsets.UTF_8));
+ StringBuilder builder = new StringBuilder();
+ for (byte value : bytes) {
+ builder.append(String.format("%02x", value));
+ }
+ return builder.toString();
+ } catch (Exception e) {
+ return safeToken(source);
+ }
+ }
+}
diff --git a/rsf-server/src/main/resources/mapper/ai/AiChatMessageMapper.xml b/rsf-server/src/main/resources/mapper/ai/AiChatMessageMapper.xml
new file mode 100644
index 0000000..c1c187c
--- /dev/null
+++ b/rsf-server/src/main/resources/mapper/ai/AiChatMessageMapper.xml
@@ -0,0 +1,66 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!DOCTYPE mapper
+ PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
+ "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
+<mapper namespace="com.vincent.rsf.server.ai.mapper.AiChatMessageMapper">
+
+ <insert id="insertBatch">
+ insert into sys_ai_chat_message
+ (session_id, seq_no, role, content, content_length, user_id, tenant_id, deleted, create_by, create_time)
+ values
+ <foreach collection="list" item="item" separator=",">
+ (#{item.sessionId}, #{item.seqNo}, #{item.role}, #{item.content}, #{item.contentLength},
+ #{item.userId}, #{item.tenantId}, #{item.deleted}, #{item.createBy}, #{item.createTime})
+ </foreach>
+ </insert>
+
+ <update id="softDeleteByIds">
+ update sys_ai_chat_message
+ set deleted = 1
+ where deleted = 0
+ and id in
+ <foreach collection="ids" item="id" open="(" separator="," close=")">
+ #{id}
+ </foreach>
+ </update>
+
+ <update id="softDeleteBySessionId">
+ update sys_ai_chat_message
+ set deleted = 1
+ where session_id = #{sessionId}
+ and deleted = 0
+ </update>
+
+ <select id="selectLatestMessagesBySessionIds" resultType="com.vincent.rsf.server.ai.entity.AiChatMessage">
+ select m.id,
+ m.session_id,
+ m.seq_no,
+ m.role,
+ m.content,
+ m.content_length,
+ m.user_id,
+ m.tenant_id,
+ m.deleted,
+ m.create_by,
+ m.create_time
+ from sys_ai_chat_message m
+ inner join (
+ select latest.session_id, max(candidate.id) as max_id
+ from sys_ai_chat_message candidate
+ inner join (
+ select session_id, max(seq_no) as max_seq_no
+ from sys_ai_chat_message
+ where deleted = 0
+ and session_id in
+ <foreach collection="sessionIds" item="sessionId" open="(" separator="," close=")">
+ #{sessionId}
+ </foreach>
+ group by session_id
+ ) latest on latest.session_id = candidate.session_id
+ and latest.max_seq_no = candidate.seq_no
+ where candidate.deleted = 0
+ group by latest.session_id
+ ) latest on latest.max_id = m.id
+ where m.deleted = 0
+ </select>
+</mapper>
diff --git a/rsf-server/src/main/resources/sql/ai/20260323_ai_indexes.sql b/rsf-server/src/main/resources/sql/ai/20260323_ai_indexes.sql
new file mode 100644
index 0000000..6afa077
--- /dev/null
+++ b/rsf-server/src/main/resources/sql/ai/20260323_ai_indexes.sql
@@ -0,0 +1,83 @@
+-- AI module performance indexes
+-- MySQL repeatable script: create each index only when the same index name does not already exist.
+-- If an existing index uses the same name but different columns, compare it manually before changing it.
+
+set @db_name = database();
+
+set @sql = if (
+ exists (
+ select 1
+ from information_schema.statistics
+ where table_schema = @db_name
+ and table_name = 'sys_ai_chat_session'
+ and index_name = 'idx_sys_ai_chat_session_user_prompt_active'
+ ),
+ 'select ''skip idx_sys_ai_chat_session_user_prompt_active''',
+ 'create index idx_sys_ai_chat_session_user_prompt_active on sys_ai_chat_session (tenant_id, user_id, prompt_code, deleted, status, pinned, last_message_time, id)'
+);
+prepare stmt from @sql;
+execute stmt;
+deallocate prepare stmt;
+
+set @sql = if (
+ exists (
+ select 1
+ from information_schema.statistics
+ where table_schema = @db_name
+ and table_name = 'sys_ai_chat_message'
+ and index_name = 'idx_sys_ai_chat_message_session_active_seq'
+ ),
+ 'select ''skip idx_sys_ai_chat_message_session_active_seq''',
+ 'create index idx_sys_ai_chat_message_session_active_seq on sys_ai_chat_message (session_id, deleted, seq_no, id)'
+);
+prepare stmt from @sql;
+execute stmt;
+deallocate prepare stmt;
+
+set @sql = if (
+ exists (
+ select 1
+ from information_schema.statistics
+ where table_schema = @db_name
+ and table_name = 'sys_ai_call_log'
+ and index_name = 'idx_sys_ai_call_log_tenant_status'
+ ),
+ 'select ''skip idx_sys_ai_call_log_tenant_status''',
+ 'create index idx_sys_ai_call_log_tenant_status on sys_ai_call_log (tenant_id, deleted, status, id)'
+);
+prepare stmt from @sql;
+execute stmt;
+deallocate prepare stmt;
+
+set @sql = if (
+ exists (
+ select 1
+ from information_schema.statistics
+ where table_schema = @db_name
+ and table_name = 'sys_ai_mcp_call_log'
+ and index_name = 'idx_sys_ai_mcp_call_log_call'
+ ),
+ 'select ''skip idx_sys_ai_mcp_call_log_call''',
+ 'create index idx_sys_ai_mcp_call_log_call on sys_ai_mcp_call_log (call_log_id, tenant_id, id)'
+);
+prepare stmt from @sql;
+execute stmt;
+deallocate prepare stmt;
+
+set @sql = if (
+ exists (
+ select 1
+ from information_schema.statistics
+ where table_schema = @db_name
+ and table_name = 'sys_ai_mcp_call_log'
+ and index_name = 'idx_sys_ai_mcp_call_log_tenant_status'
+ ),
+ 'select ''skip idx_sys_ai_mcp_call_log_tenant_status''',
+ 'create index idx_sys_ai_mcp_call_log_tenant_status on sys_ai_mcp_call_log (tenant_id, status, id)'
+);
+prepare stmt from @sql;
+execute stmt;
+deallocate prepare stmt;
+
+-- Deployment note:
+-- Clear legacy AI cache keys before the new version takes traffic to avoid stale payload compatibility issues.
diff --git a/rsf-server/src/test/java/com/vincent/rsf/server/AI/mapper/AiChatMessageMapperIntegrationTest.java b/rsf-server/src/test/java/com/vincent/rsf/server/AI/mapper/AiChatMessageMapperIntegrationTest.java
new file mode 100644
index 0000000..513ce32
--- /dev/null
+++ b/rsf-server/src/test/java/com/vincent/rsf/server/AI/mapper/AiChatMessageMapperIntegrationTest.java
@@ -0,0 +1,123 @@
+package com.vincent.rsf.server.ai.mapper;
+
+import com.baomidou.mybatisplus.core.MybatisConfiguration;
+import com.baomidou.mybatisplus.extension.spring.MybatisSqlSessionFactoryBean;
+import com.vincent.rsf.server.ai.entity.AiChatMessage;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
+import org.springframework.jdbc.core.JdbcTemplate;
+import org.springframework.jdbc.datasource.DriverManagerDataSource;
+import org.springframework.test.context.ContextConfiguration;
+import org.springframework.test.context.junit.jupiter.SpringExtension;
+import org.mybatis.spring.annotation.MapperScan;
+
+import javax.sql.DataSource;
+import java.util.Date;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+@ExtendWith(SpringExtension.class)
+@ContextConfiguration(classes = AiChatMessageMapperIntegrationTest.TestConfig.class)
+class AiChatMessageMapperIntegrationTest {
+
+ @Autowired
+ private AiChatMessageMapper aiChatMessageMapper;
+
+ @Autowired
+ private JdbcTemplate jdbcTemplate;
+
+ @BeforeEach
+ void setUpSchema() {
+ jdbcTemplate.execute("drop table if exists sys_ai_chat_message");
+ jdbcTemplate.execute("""
+ create table sys_ai_chat_message (
+ id bigint auto_increment primary key,
+ session_id bigint,
+ seq_no integer,
+ role varchar(32),
+ content varchar(1000),
+ content_length integer,
+ user_id bigint,
+ tenant_id bigint,
+ deleted integer,
+ create_by bigint,
+ create_time timestamp
+ )
+ """);
+ }
+
+ @Test
+ void shouldInsertSelectLatestAndSoftDeleteMessages() {
+ Date now = new Date();
+ aiChatMessageMapper.insertBatch(List.of(
+ message(10L, 1, "user", "hello", now),
+ message(10L, 2, "assistant", "world", now),
+ message(10L, 2, "assistant", "world latest", now),
+ message(20L, 1, "user", "other", now)
+ ));
+
+ List<AiChatMessage> latest = aiChatMessageMapper.selectLatestMessagesBySessionIds(List.of(10L, 20L));
+
+ assertEquals(List.of(10L, 20L), latest.stream().map(AiChatMessage::getSessionId).sorted().collect(Collectors.toList()));
+ assertEquals("world latest", latest.stream().filter(item -> item.getSessionId().equals(10L)).findFirst().orElseThrow().getContent());
+
+ aiChatMessageMapper.softDeleteByIds(List.of(latest.stream().filter(item -> item.getSessionId().equals(20L)).findFirst().orElseThrow().getId()));
+ assertEquals(3, jdbcTemplate.queryForObject("select count(*) from sys_ai_chat_message where deleted = 0", Integer.class));
+
+ aiChatMessageMapper.softDeleteBySessionId(10L);
+ assertEquals(0, jdbcTemplate.queryForObject("select count(*) from sys_ai_chat_message where deleted = 0", Integer.class));
+ }
+
+ private AiChatMessage message(Long sessionId, int seqNo, String role, String content, Date createTime) {
+ return new AiChatMessage()
+ .setSessionId(sessionId)
+ .setSeqNo(seqNo)
+ .setRole(role)
+ .setContent(content)
+ .setContentLength(content.length())
+ .setUserId(1L)
+ .setTenantId(1L)
+ .setDeleted(0)
+ .setCreateBy(1L)
+ .setCreateTime(createTime);
+ }
+
+ @Configuration
+ @MapperScan(basePackageClasses = AiChatMessageMapper.class)
+ static class TestConfig {
+
+ @Bean
+ DataSource dataSource() {
+ DriverManagerDataSource dataSource = new DriverManagerDataSource();
+ dataSource.setDriverClassName("org.h2.Driver");
+ dataSource.setUrl("jdbc:h2:mem:ai-chat-message;MODE=MySQL;DB_CLOSE_DELAY=-1");
+ dataSource.setUsername("sa");
+ dataSource.setPassword("");
+ return dataSource;
+ }
+
+ @Bean
+ JdbcTemplate jdbcTemplate(DataSource dataSource) {
+ return new JdbcTemplate(dataSource);
+ }
+
+ @Bean
+ MybatisSqlSessionFactoryBean sqlSessionFactory(DataSource dataSource) throws Exception {
+ MybatisSqlSessionFactoryBean factoryBean = new MybatisSqlSessionFactoryBean();
+ factoryBean.setDataSource(dataSource);
+ factoryBean.setMapperLocations(new PathMatchingResourcePatternResolver()
+ .getResources("classpath*:mapper/ai/*.xml"));
+ MybatisConfiguration configuration = new MybatisConfiguration();
+ configuration.setMapUnderscoreToCamelCase(true);
+ factoryBean.setConfiguration(configuration);
+ return factoryBean;
+ }
+ }
+}
diff --git a/rsf-server/src/test/java/com/vincent/rsf/server/AI/service/impl/chat/AiChatOrchestratorTest.java b/rsf-server/src/test/java/com/vincent/rsf/server/AI/service/impl/chat/AiChatOrchestratorTest.java
new file mode 100644
index 0000000..9e137e4
--- /dev/null
+++ b/rsf-server/src/test/java/com/vincent/rsf/server/AI/service/impl/chat/AiChatOrchestratorTest.java
@@ -0,0 +1,105 @@
+package com.vincent.rsf.server.ai.service.impl.chat;
+
+import com.vincent.rsf.server.ai.dto.AiChatRequest;
+import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
+import com.vincent.rsf.server.ai.dto.AiResolvedConfig;
+import com.vincent.rsf.server.ai.entity.AiParam;
+import com.vincent.rsf.server.ai.entity.AiPrompt;
+import com.vincent.rsf.server.ai.service.AiCallLogService;
+import com.vincent.rsf.server.ai.service.AiChatMemoryService;
+import com.vincent.rsf.server.ai.service.AiConfigResolverService;
+import com.vincent.rsf.server.ai.service.AiParamService;
+import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
+import com.vincent.rsf.server.ai.store.AiChatRateLimiter;
+import com.vincent.rsf.server.ai.store.AiStreamStateStore;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
+
+import java.util.List;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+@ExtendWith(MockitoExtension.class)
+class AiChatOrchestratorTest {
+
+ @Mock
+ private AiConfigResolverService aiConfigResolverService;
+ @Mock
+ private AiChatMemoryService aiChatMemoryService;
+ @Mock
+ private AiParamService aiParamService;
+ @Mock
+ private McpMountRuntimeFactory mcpMountRuntimeFactory;
+ @Mock
+ private AiCallLogService aiCallLogService;
+ @Mock
+ private AiChatRateLimiter aiChatRateLimiter;
+ @Mock
+ private AiStreamStateStore aiStreamStateStore;
+ @Mock
+ private AiChatRuntimeAssembler aiChatRuntimeAssembler;
+ @Mock
+ private AiPromptMessageBuilder aiPromptMessageBuilder;
+ @Mock
+ private AiOpenAiChatModelFactory aiOpenAiChatModelFactory;
+ @Mock
+ private AiToolObservationService aiToolObservationService;
+ @Mock
+ private AiSseEventPublisher aiSseEventPublisher;
+ @Mock
+ private AiChatFailureHandler aiChatFailureHandler;
+
+ @Test
+ void shouldShortCircuitWhenRateLimited() {
+ AiChatOrchestrator aiChatOrchestrator = new AiChatOrchestrator(
+ aiConfigResolverService,
+ aiChatMemoryService,
+ aiParamService,
+ mcpMountRuntimeFactory,
+ aiCallLogService,
+ aiChatRateLimiter,
+ aiStreamStateStore,
+ aiChatRuntimeAssembler,
+ aiPromptMessageBuilder,
+ aiOpenAiChatModelFactory,
+ aiToolObservationService,
+ aiSseEventPublisher,
+ aiChatFailureHandler
+ );
+ AiChatRequest request = new AiChatRequest();
+ request.setRequestId("req-1");
+ request.setPromptCode("home.default");
+ request.setMessages(List.of(message("user", "hello")));
+ AiResolvedConfig config = AiResolvedConfig.builder()
+ .promptCode("home.default")
+ .aiParam(new AiParam().setModel("gpt-test"))
+ .prompt(new AiPrompt().setName("default"))
+ .mcpMounts(List.of())
+ .build();
+ when(aiConfigResolverService.resolve("home.default", 1L, null)).thenReturn(config);
+ when(aiParamService.listChatModelOptions(1L)).thenReturn(List.of());
+ when(aiChatRateLimiter.allowChatRequest(1L, 2L, "home.default")).thenReturn(false);
+ when(aiChatFailureHandler.buildAiException(eq("AI_RATE_LIMITED"), any(), eq("RATE_LIMIT"), any(), eq(null)))
+ .thenCallRealMethod();
+
+ aiChatOrchestrator.executeStream(request, 2L, 1L, new SseEmitter(1000L));
+
+ verify(aiChatFailureHandler).handleStreamFailure(any(), eq("req-1"), eq(null), eq(null), any(Long.class),
+ eq(null), any(), eq(null), eq(0L), eq(0L), eq(null), eq(1L), eq(2L), eq("home.default"));
+ verify(aiCallLogService, never()).startCallLog(any(), any(), any(), any(), any(), any(), any(), any(), any(), any());
+ }
+
+ private AiChatMessageDto message(String role, String content) {
+ AiChatMessageDto dto = new AiChatMessageDto();
+ dto.setRole(role);
+ dto.setContent(content);
+ return dto;
+ }
+}
diff --git a/rsf-server/src/test/java/com/vincent/rsf/server/AI/service/impl/chat/AiPromptMessageBuilderTest.java b/rsf-server/src/test/java/com/vincent/rsf/server/AI/service/impl/chat/AiPromptMessageBuilderTest.java
new file mode 100644
index 0000000..41cd3c0
--- /dev/null
+++ b/rsf-server/src/test/java/com/vincent/rsf/server/AI/service/impl/chat/AiPromptMessageBuilderTest.java
@@ -0,0 +1,68 @@
+package com.vincent.rsf.server.ai.service.impl.chat;
+
+import com.vincent.rsf.server.ai.dto.AiChatMemoryDto;
+import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
+import com.vincent.rsf.server.ai.entity.AiPrompt;
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.SystemMessage;
+import org.springframework.ai.chat.messages.UserMessage;
+
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertInstanceOf;
+
+class AiPromptMessageBuilderTest {
+
+ private final AiPromptMessageBuilder builder = new AiPromptMessageBuilder();
+
+ @Test
+ void shouldBuildPromptMessagesInExpectedOrderAndRenderLastUserPrompt() {
+ AiChatMemoryDto memory = AiChatMemoryDto.builder()
+ .memorySummary("summary")
+ .memoryFacts("facts")
+ .build();
+ AiPrompt prompt = new AiPrompt()
+ .setSystemPrompt("system")
+ .setUserPromptTemplate("鐢ㄦ埛闂: {{input}} | 浠撳簱: {{warehouse}}");
+ List<AiChatMessageDto> messages = List.of(
+ message("user", "old question"),
+ message("assistant", "old answer"),
+ message("user", "latest question")
+ );
+
+ List<Message> built = builder.buildPromptMessages(memory, messages, prompt, Map.of("warehouse", "WH1"));
+
+ assertEquals(6, built.size());
+ assertInstanceOf(SystemMessage.class, built.get(0));
+ assertEquals("system", built.get(0).getText());
+ assertEquals("鍘嗗彶鎽樿:\nsummary", built.get(1).getText());
+ assertEquals("鍏抽敭浜嬪疄:\nfacts", built.get(2).getText());
+ assertInstanceOf(UserMessage.class, built.get(3));
+ assertEquals("old question", built.get(3).getText());
+ assertEquals("old answer", built.get(4).getText());
+ assertInstanceOf(UserMessage.class, built.get(5));
+ assertEquals("鐢ㄦ埛闂: latest question | 浠撳簱: WH1", built.get(5).getText());
+ }
+
+ @Test
+ void shouldMergePersistedAndMemoryMessages() {
+ List<AiChatMessageDto> merged = builder.mergeMessages(
+ List.of(message("user", "persisted")),
+ List.of(message("assistant", "memory"))
+ );
+
+ assertEquals(2, merged.size());
+ assertEquals("persisted", merged.get(0).getContent());
+ assertEquals("memory", merged.get(1).getContent());
+ }
+
+ private AiChatMessageDto message(String role, String content) {
+ AiChatMessageDto dto = new AiChatMessageDto();
+ dto.setRole(role);
+ dto.setContent(content);
+ return dto;
+ }
+}
diff --git a/rsf-server/src/test/java/com/vincent/rsf/server/AI/service/impl/conversation/AiConversationCommandServiceTest.java b/rsf-server/src/test/java/com/vincent/rsf/server/AI/service/impl/conversation/AiConversationCommandServiceTest.java
new file mode 100644
index 0000000..fb31dd1
--- /dev/null
+++ b/rsf-server/src/test/java/com/vincent/rsf/server/AI/service/impl/conversation/AiConversationCommandServiceTest.java
@@ -0,0 +1,113 @@
+package com.vincent.rsf.server.ai.service.impl.conversation;
+
+import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
+import com.vincent.rsf.server.ai.entity.AiChatMessage;
+import com.vincent.rsf.server.ai.entity.AiChatSession;
+import com.vincent.rsf.server.ai.mapper.AiChatMessageMapper;
+import com.vincent.rsf.server.ai.mapper.AiChatSessionMapper;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.InjectMocks;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.springframework.transaction.support.TransactionSynchronization;
+import org.springframework.transaction.support.TransactionSynchronizationManager;
+
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.Mockito.never;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+@ExtendWith(MockitoExtension.class)
+class AiConversationCommandServiceTest {
+
+ @Mock
+ private AiChatSessionMapper aiChatSessionMapper;
+ @Mock
+ private AiChatMessageMapper aiChatMessageMapper;
+ @Mock
+ private AiConversationQueryService aiConversationQueryService;
+ @Mock
+ private AiMemoryProfileService aiMemoryProfileService;
+
+ @InjectMocks
+ private AiConversationCommandService aiConversationCommandService;
+
+ @Test
+ void shouldBatchInsertMessagesAndScheduleRefreshWhenSavingRound() {
+ AiChatSession session = new AiChatSession().setId(10L).setTitle(null);
+ when(aiConversationQueryService.findNextSeqNo(10L)).thenReturn(3);
+ when(aiConversationQueryService.resolveUpdatedTitle(eq(null), any())).thenReturn("new title");
+
+ aiConversationCommandService.saveRound(session, 7L, 8L,
+ List.of(message("user", "hello"), message("assistant", "draft answer")),
+ "final answer");
+
+ ArgumentCaptor<List<AiChatMessage>> captor = ArgumentCaptor.forClass(List.class);
+ verify(aiChatMessageMapper).insertBatch(captor.capture());
+ assertEquals(3, captor.getValue().size());
+ assertEquals(3, captor.getValue().get(0).getSeqNo());
+ assertEquals(5, captor.getValue().get(2).getSeqNo());
+ verify(aiMemoryProfileService).scheduleMemoryProfileRefresh(10L, 7L, 8L);
+ verify(aiConversationQueryService).evictConversationCaches(8L, 7L);
+ }
+
+ @Test
+ void shouldDeferConversationSideEffectsUntilAfterCommitWhenSynchronizationActive() {
+ AiChatSession session = new AiChatSession().setId(10L).setTitle(null);
+ when(aiConversationQueryService.findNextSeqNo(10L)).thenReturn(1);
+ when(aiConversationQueryService.resolveUpdatedTitle(eq(null), any())).thenReturn("new title");
+
+ TransactionSynchronizationManager.initSynchronization();
+ try {
+ aiConversationCommandService.saveRound(session, 7L, 8L, List.of(message("user", "hello")), "final answer");
+
+ verify(aiConversationQueryService, never()).evictConversationCaches(8L, 7L);
+ verify(aiMemoryProfileService, never()).scheduleMemoryProfileRefresh(10L, 7L, 8L);
+
+ for (TransactionSynchronization synchronization : TransactionSynchronizationManager.getSynchronizations()) {
+ synchronization.afterCommit();
+ }
+
+ verify(aiConversationQueryService).evictConversationCaches(8L, 7L);
+ verify(aiMemoryProfileService).scheduleMemoryProfileRefresh(10L, 7L, 8L);
+ } finally {
+ TransactionSynchronizationManager.clearSynchronization();
+ }
+ }
+
+ @Test
+ void shouldSoftDeleteOnlyOlderMessagesWhenRetainingLatestRound() {
+ when(aiConversationQueryService.listMessageRecords(10L)).thenReturn(List.of(
+ record(1L, "user"),
+ record(2L, "assistant"),
+ record(3L, "user"),
+ record(4L, "assistant")
+ ));
+ when(aiConversationQueryService.tailMessageRecordsByRounds(any(), eq(1))).thenReturn(List.of(
+ record(3L, "user"),
+ record(4L, "assistant")
+ ));
+
+ aiConversationCommandService.retainLatestRound(7L, 8L, 10L);
+
+ verify(aiChatMessageMapper).softDeleteByIds(List.of(1L, 2L));
+ verify(aiMemoryProfileService).scheduleMemoryProfileRefresh(10L, 7L, 8L);
+ }
+
+ private AiChatMessageDto message(String role, String content) {
+ AiChatMessageDto dto = new AiChatMessageDto();
+ dto.setRole(role);
+ dto.setContent(content);
+ return dto;
+ }
+
+ private AiChatMessage record(Long id, String role) {
+ return new AiChatMessage().setId(id).setRole(role);
+ }
+}
diff --git a/rsf-server/src/test/java/com/vincent/rsf/server/AI/service/impl/mcp/AiMcpAdminServiceTest.java b/rsf-server/src/test/java/com/vincent/rsf/server/AI/service/impl/mcp/AiMcpAdminServiceTest.java
new file mode 100644
index 0000000..e983b0b
--- /dev/null
+++ b/rsf-server/src/test/java/com/vincent/rsf/server/AI/service/impl/mcp/AiMcpAdminServiceTest.java
@@ -0,0 +1,76 @@
+package com.vincent.rsf.server.ai.service.impl.mcp;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.vincent.rsf.server.ai.config.AiDefaults;
+import com.vincent.rsf.server.ai.dto.AiMcpToolPreviewDto;
+import com.vincent.rsf.server.ai.entity.AiMcpMount;
+import com.vincent.rsf.server.ai.mapper.AiMcpMountMapper;
+import com.vincent.rsf.server.ai.service.BuiltinMcpToolRegistry;
+import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.springframework.ai.tool.ToolCallback;
+import org.springframework.ai.tool.definition.ToolDefinition;
+import org.springframework.ai.tool.metadata.ToolMetadata;
+
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+@ExtendWith(MockitoExtension.class)
+class AiMcpAdminServiceTest {
+
+ @Mock
+ private AiMcpMountMapper aiMcpMountMapper;
+ @Mock
+ private BuiltinMcpToolRegistry builtinMcpToolRegistry;
+ @Mock
+ private McpMountRuntimeFactory mcpMountRuntimeFactory;
+
+ @Test
+ void shouldMergeGovernedCatalogIntoPreviewResult() {
+ AiMcpAdminService aiMcpAdminService = new AiMcpAdminService(
+ aiMcpMountMapper,
+ builtinMcpToolRegistry,
+ mcpMountRuntimeFactory,
+ new ObjectMapper()
+ );
+ AiMcpMount mount = new AiMcpMount()
+ .setTransportType(AiDefaults.MCP_TRANSPORT_BUILTIN)
+ .setBuiltinCode(AiDefaults.MCP_BUILTIN_RSF_WMS)
+ .setName("builtin");
+ ToolCallback callback = mock(ToolCallback.class);
+ ToolDefinition toolDefinition = mock(ToolDefinition.class);
+ ToolMetadata toolMetadata = mock(ToolMetadata.class);
+ McpMountRuntimeFactory.McpMountRuntime runtime = mock(McpMountRuntimeFactory.McpMountRuntime.class);
+ when(toolDefinition.name()).thenReturn("rsf_query_task");
+ when(toolDefinition.description()).thenReturn("desc");
+ when(toolDefinition.inputSchema()).thenReturn("{schema}");
+ when(toolMetadata.returnDirect()).thenReturn(Boolean.FALSE);
+ when(callback.getToolDefinition()).thenReturn(toolDefinition);
+ when(callback.getToolMetadata()).thenReturn(toolMetadata);
+ when(runtime.getToolCallbacks()).thenReturn(new ToolCallback[]{callback});
+ when(runtime.getErrors()).thenReturn(List.of());
+ when(mcpMountRuntimeFactory.create(List.of(mount), 1L)).thenReturn(runtime);
+ when(builtinMcpToolRegistry.listBuiltinToolCatalog(AiDefaults.MCP_BUILTIN_RSF_WMS)).thenReturn(List.of(
+ AiMcpToolPreviewDto.builder()
+ .name("rsf_query_task")
+ .toolGroup("浠诲姟鏌ヨ")
+ .toolPurpose("purpose")
+ .queryBoundary("boundary")
+ .exampleQuestions(List.of("q1"))
+ .build()
+ ));
+
+ List<AiMcpToolPreviewDto> result = aiMcpAdminService.previewTools(mount, 1L);
+
+ assertEquals(1, result.size());
+ assertEquals("浠诲姟鏌ヨ", result.get(0).getToolGroup());
+ assertEquals("purpose", result.get(0).getToolPurpose());
+ assertEquals("{schema}", result.get(0).getInputSchema());
+ }
+}
diff --git a/rsf-server/src/test/java/com/vincent/rsf/server/AI/store/support/AiRedisKeysTest.java b/rsf-server/src/test/java/com/vincent/rsf/server/AI/store/support/AiRedisKeysTest.java
new file mode 100644
index 0000000..ba6e361
--- /dev/null
+++ b/rsf-server/src/test/java/com/vincent/rsf/server/AI/store/support/AiRedisKeysTest.java
@@ -0,0 +1,30 @@
+package com.vincent.rsf.server.ai.store.support;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+class AiRedisKeysTest {
+
+ private final AiRedisKeys aiRedisKeys = new AiRedisKeys();
+
+ @Test
+ void shouldBuildStableRuntimeAndToolKeys() {
+ String runtimeKey = aiRedisKeys.buildRuntimeKey(1L, 2L, "home.default", 3L, 4L);
+ String toolKeyA = aiRedisKeys.buildToolResultKey(1L, "req-1", "tool", "{\"a\":1}");
+ String toolKeyB = aiRedisKeys.buildToolResultKey(1L, "req-1", "tool", "{\"a\":2}");
+
+ assertEquals("AI:RUNTIME:1:2:home.default:3:4", runtimeKey);
+ assertTrue(toolKeyA.startsWith("AI:TOOL:RESULT:1:req-1:tool:"));
+ assertNotEquals(toolKeyA, toolKeyB);
+ }
+
+ @Test
+ void shouldUseSentinelTokensForNullableValues() {
+ assertEquals("LATEST", aiRedisKeys.sessionToken(null));
+ assertEquals("DEFAULT", aiRedisKeys.aiParamToken(null));
+ assertEquals("_", aiRedisKeys.safeToken(null));
+ }
+}
--
Gitblit v1.9.1