package com.vincent.rsf.server.ai.service.impl; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.vincent.rsf.server.ai.config.AiDefaults; import com.vincent.rsf.server.ai.dto.AiChatMemoryDto; import com.vincent.rsf.server.ai.dto.AiChatRuntimeDto; import com.vincent.rsf.server.ai.dto.AiChatSessionDto; import com.vincent.rsf.server.ai.dto.AiMcpConnectivityTestDto; import com.vincent.rsf.server.ai.dto.AiMcpToolPreviewDto; import com.vincent.rsf.server.ai.dto.AiObserveStatsDto; import com.vincent.rsf.server.ai.dto.AiResolvedConfig; import com.vincent.rsf.server.common.service.RedisService; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; import org.springframework.util.StringUtils; import redis.clients.jedis.Jedis; import java.net.URLEncoder; import java.nio.charset.StandardCharsets; import java.security.MessageDigest; import java.time.Instant; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; import java.util.UUID; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; @Slf4j @Service @RequiredArgsConstructor public class AiRedisSupport { /** 统一收口 AI 模块的 Redis key、TTL 和序列化策略,避免业务类直接散写 Redis。 */ private static final String CONFIG_KEY_PREFIX = "AI:CONFIG:"; private static final String RUNTIME_KEY_PREFIX = "AI:RUNTIME:"; private static final String MEMORY_KEY_PREFIX = "AI:MEMORY:"; private static final String SESSIONS_KEY_PREFIX = "AI:SESSIONS:"; private static final String MCP_PREVIEW_KEY_PREFIX = "AI:MCP:PREVIEW:"; private static final String MCP_HEALTH_KEY_PREFIX = "AI:MCP:HEALTH:"; private static final String STREAM_STATE_KEY_PREFIX = "AI:STREAM:"; private static final String TOOL_RESULT_KEY_PREFIX = "AI:TOOL:RESULT:"; private static final String RATE_LIMIT_KEY_PREFIX = "AI:RATE:"; private static final String OBSERVE_STATS_KEY_PREFIX = "AI:OBSERVE:STATS:"; private static final String OBSERVE_TOOL_RANK_KEY_PREFIX = "AI:OBSERVE:TOOL:RANK:"; private static final String OBSERVE_TOOL_FAIL_RANK_KEY_PREFIX = "AI:OBSERVE:TOOL:FAIL:RANK:"; private static final String FIELD_CALL_COUNT = "callCount"; private static final String FIELD_SUCCESS_COUNT = "successCount"; private static final String FIELD_FAILURE_COUNT = "failureCount"; private static final String FIELD_ELAPSED_SUM = "elapsedSum"; private static final String FIELD_ELAPSED_COUNT = "elapsedCount"; private static final String FIELD_FIRST_TOKEN_SUM = "firstTokenSum"; private static final String FIELD_FIRST_TOKEN_COUNT = "firstTokenCount"; private static final String FIELD_TOTAL_TOKENS_SUM = "totalTokensSum"; private static final String FIELD_TOTAL_TOKENS_COUNT = "totalTokensCount"; private static final String FIELD_TOOL_CALL_COUNT = "toolCallCount"; private static final String FIELD_TOOL_SUCCESS_COUNT = "toolSuccessCount"; private static final String FIELD_TOOL_FAILURE_COUNT = "toolFailureCount"; private final RedisService redisService; private final ObjectMapper objectMapper; public AiResolvedConfig getResolvedConfig(Long tenantId, String promptCode, Long aiParamId) { return readJson(buildConfigKey(tenantId, promptCode, aiParamId), AiResolvedConfig.class); } public void cacheResolvedConfig(Long tenantId, String promptCode, Long aiParamId, AiResolvedConfig config) { writeJson(buildConfigKey(tenantId, promptCode, aiParamId), config, AiDefaults.CONFIG_CACHE_TTL_SECONDS); } public void evictTenantConfigCaches(Long tenantId) { deleteByPrefix(CONFIG_KEY_PREFIX + tenantId + ":"); deleteByPrefix(RUNTIME_KEY_PREFIX + tenantId + ":"); } public AiChatRuntimeDto getRuntime(Long tenantId, Long userId, String promptCode, Long sessionId, Long aiParamId) { return readJson(buildRuntimeKey(tenantId, userId, promptCode, sessionId, aiParamId), AiChatRuntimeDto.class); } public void cacheRuntime(Long tenantId, Long userId, String promptCode, Long sessionId, Long aiParamId, AiChatRuntimeDto runtime) { writeJson(buildRuntimeKey(tenantId, userId, promptCode, sessionId, aiParamId), runtime, AiDefaults.RUNTIME_CACHE_TTL_SECONDS); } public AiChatMemoryDto getMemory(Long tenantId, Long userId, String promptCode, Long sessionId) { return readJson(buildMemoryKey(tenantId, userId, promptCode, sessionId), AiChatMemoryDto.class); } public void cacheMemory(Long tenantId, Long userId, String promptCode, Long sessionId, AiChatMemoryDto memory) { writeJson(buildMemoryKey(tenantId, userId, promptCode, sessionId), memory, AiDefaults.MEMORY_CACHE_TTL_SECONDS); } public List getSessionList(Long tenantId, Long userId, String promptCode, String keyword) { return readJson(buildSessionsKey(tenantId, userId, promptCode, keyword), new TypeReference>() { }); } public void cacheSessionList(Long tenantId, Long userId, String promptCode, String keyword, List sessions) { writeJson(buildSessionsKey(tenantId, userId, promptCode, keyword), sessions, AiDefaults.SESSION_LIST_CACHE_TTL_SECONDS); } public void evictUserConversationCaches(Long tenantId, Long userId) { deleteByPrefix(RUNTIME_KEY_PREFIX + tenantId + ":" + userId + ":"); deleteByPrefix(MEMORY_KEY_PREFIX + tenantId + ":" + userId + ":"); deleteByPrefix(SESSIONS_KEY_PREFIX + tenantId + ":" + userId + ":"); } public List getToolPreview(Long tenantId, Long mountId) { return readJson(buildMcpPreviewKey(tenantId, mountId), new TypeReference>() { }); } public void cacheToolPreview(Long tenantId, Long mountId, List tools) { writeJson(buildMcpPreviewKey(tenantId, mountId), tools, AiDefaults.MCP_PREVIEW_CACHE_TTL_SECONDS); } public AiMcpConnectivityTestDto getConnectivity(Long tenantId, Long mountId) { return readJson(buildMcpHealthKey(tenantId, mountId), AiMcpConnectivityTestDto.class); } public void cacheConnectivity(Long tenantId, Long mountId, AiMcpConnectivityTestDto connectivity) { writeJson(buildMcpHealthKey(tenantId, mountId), connectivity, AiDefaults.MCP_HEALTH_CACHE_TTL_SECONDS); } public void evictMcpMountCaches(Long tenantId, Long mountId) { if (mountId != null) { delete(buildMcpPreviewKey(tenantId, mountId)); delete(buildMcpHealthKey(tenantId, mountId)); } else { deleteByPrefix(MCP_PREVIEW_KEY_PREFIX + tenantId + ":"); deleteByPrefix(MCP_HEALTH_KEY_PREFIX + tenantId + ":"); } evictTenantConfigCaches(tenantId); } public boolean allowChatRequest(Long tenantId, Long userId, String promptCode) { String key = buildRateLimitKey(tenantId, userId, promptCode); long now = Instant.now().toEpochMilli(); long windowStart = now - (AiDefaults.CHAT_RATE_LIMIT_WINDOW_SECONDS * 1000L); Boolean allowed = execute(jedis -> { // 用 zset 维护滑动窗口,而不是简单计数器,避免窗口边界出现突刺误判。 jedis.zremrangeByScore(key, 0, windowStart); long count = jedis.zcard(key); if (count >= AiDefaults.CHAT_RATE_LIMIT_MAX_REQUESTS) { jedis.expire(key, AiDefaults.CHAT_RATE_LIMIT_WINDOW_SECONDS); return Boolean.FALSE; } jedis.zadd(key, now, now + ":" + UUID.randomUUID()); jedis.expire(key, AiDefaults.CHAT_RATE_LIMIT_WINDOW_SECONDS); return Boolean.TRUE; }); return Boolean.TRUE.equals(allowed); } public void markStreamState(String requestId, Long tenantId, Long userId, Long sessionId, String promptCode, String status, String errorMessage) { if (!StringUtils.hasText(requestId)) { return; } writeJson(buildStreamStateKey(tenantId, requestId), AiStreamState.builder() .requestId(requestId) .tenantId(tenantId) .userId(userId) .sessionId(sessionId) .promptCode(promptCode) .status(status) .errorMessage(errorMessage) .timestamp(Instant.now().toEpochMilli()) .build(), AiDefaults.STREAM_STATE_TTL_SECONDS); } public CachedToolResult getToolResult(Long tenantId, String requestId, String toolName, String toolInput) { return readJson(buildToolResultKey(tenantId, requestId, toolName, toolInput), CachedToolResult.class); } public void cacheToolResult(Long tenantId, String requestId, String toolName, String toolInput, boolean success, String output, String errorMessage) { writeJson(buildToolResultKey(tenantId, requestId, toolName, toolInput), CachedToolResult.builder() .success(success) .output(output) .errorMessage(errorMessage) .build(), AiDefaults.TOOL_RESULT_CACHE_TTL_SECONDS); } public void recordObserveCallStarted(Long tenantId) { executeVoid(jedis -> jedis.hincrBy(buildObserveStatsKey(tenantId), FIELD_CALL_COUNT, 1)); } public void recordObserveCallFinished(Long tenantId, String status, Long elapsedMs, Long firstTokenLatencyMs, Integer totalTokens) { executeVoid(jedis -> { String key = buildObserveStatsKey(tenantId); if ("COMPLETED".equals(status)) { jedis.hincrBy(key, FIELD_SUCCESS_COUNT, 1); } else if ("FAILED".equals(status)) { jedis.hincrBy(key, FIELD_FAILURE_COUNT, 1); } if (elapsedMs != null) { jedis.hincrBy(key, FIELD_ELAPSED_SUM, elapsedMs); jedis.hincrBy(key, FIELD_ELAPSED_COUNT, 1); } if (firstTokenLatencyMs != null) { jedis.hincrBy(key, FIELD_FIRST_TOKEN_SUM, firstTokenLatencyMs); jedis.hincrBy(key, FIELD_FIRST_TOKEN_COUNT, 1); } if (totalTokens != null) { jedis.hincrBy(key, FIELD_TOTAL_TOKENS_SUM, totalTokens.longValue()); jedis.hincrBy(key, FIELD_TOTAL_TOKENS_COUNT, 1); } }); } public void recordObserveToolCall(Long tenantId, String toolName, String status) { executeVoid(jedis -> { String key = buildObserveStatsKey(tenantId); jedis.hincrBy(key, FIELD_TOOL_CALL_COUNT, 1); if ("COMPLETED".equals(status)) { jedis.hincrBy(key, FIELD_TOOL_SUCCESS_COUNT, 1); } else if ("FAILED".equals(status)) { jedis.hincrBy(key, FIELD_TOOL_FAILURE_COUNT, 1); } if (StringUtils.hasText(toolName)) { jedis.zincrby(buildToolRankKey(tenantId), 1D, toolName); if ("FAILED".equals(status)) { jedis.zincrby(buildToolFailRankKey(tenantId), 1D, toolName); } } }); } public AiObserveStatsDto getObserveStats(Long tenantId, Supplier fallbackLoader) { AiObserveStatsDto cached = readObserveStats(tenantId); if (cached != null) { return cached; } // Redis 为空时再回源数据库,避免管理端看板每次都扫全量日志表。 AiObserveStatsDto snapshot = fallbackLoader.get(); if (snapshot != null) { seedObserveStats(tenantId, snapshot); } return snapshot; } private AiObserveStatsDto readObserveStats(Long tenantId) { Map fields = execute(jedis -> { String key = buildObserveStatsKey(tenantId); if (!jedis.exists(key)) { return null; } return jedis.hgetAll(key); }); if (fields == null || fields.isEmpty()) { return null; } long callCount = parseLong(fields.get(FIELD_CALL_COUNT)); long successCount = parseLong(fields.get(FIELD_SUCCESS_COUNT)); long failureCount = parseLong(fields.get(FIELD_FAILURE_COUNT)); long elapsedSum = parseLong(fields.get(FIELD_ELAPSED_SUM)); long elapsedCount = parseLong(fields.get(FIELD_ELAPSED_COUNT)); long firstTokenSum = parseLong(fields.get(FIELD_FIRST_TOKEN_SUM)); long firstTokenCount = parseLong(fields.get(FIELD_FIRST_TOKEN_COUNT)); long totalTokensSum = parseLong(fields.get(FIELD_TOTAL_TOKENS_SUM)); long totalTokensCount = parseLong(fields.get(FIELD_TOTAL_TOKENS_COUNT)); long toolCallCount = parseLong(fields.get(FIELD_TOOL_CALL_COUNT)); long toolSuccessCount = parseLong(fields.get(FIELD_TOOL_SUCCESS_COUNT)); long toolFailureCount = parseLong(fields.get(FIELD_TOOL_FAILURE_COUNT)); return AiObserveStatsDto.builder() .callCount(callCount) .successCount(successCount) .failureCount(failureCount) .avgElapsedMs(elapsedCount == 0 ? 0L : elapsedSum / elapsedCount) .avgFirstTokenLatencyMs(firstTokenCount == 0 ? 0L : firstTokenSum / firstTokenCount) .totalTokens(totalTokensSum) .avgTotalTokens(totalTokensCount == 0 ? 0L : totalTokensSum / totalTokensCount) .toolCallCount(toolCallCount) .toolSuccessCount(toolSuccessCount) .toolFailureCount(toolFailureCount) .toolSuccessRate(toolCallCount == 0 ? 0D : (toolSuccessCount * 100D) / toolCallCount) .build(); } private void seedObserveStats(Long tenantId, AiObserveStatsDto snapshot) { executeVoid(jedis -> { String key = buildObserveStatsKey(tenantId); Map values = new LinkedHashMap<>(); values.put(FIELD_CALL_COUNT, String.valueOf(defaultLong(snapshot.getCallCount()))); values.put(FIELD_SUCCESS_COUNT, String.valueOf(defaultLong(snapshot.getSuccessCount()))); values.put(FIELD_FAILURE_COUNT, String.valueOf(defaultLong(snapshot.getFailureCount()))); values.put(FIELD_ELAPSED_SUM, String.valueOf(defaultLong(snapshot.getAvgElapsedMs()) * defaultLong(snapshot.getCallCount()))); values.put(FIELD_ELAPSED_COUNT, String.valueOf(defaultLong(snapshot.getCallCount()))); values.put(FIELD_FIRST_TOKEN_SUM, String.valueOf(defaultLong(snapshot.getAvgFirstTokenLatencyMs()) * defaultLong(snapshot.getCallCount()))); values.put(FIELD_FIRST_TOKEN_COUNT, String.valueOf(defaultLong(snapshot.getCallCount()))); values.put(FIELD_TOTAL_TOKENS_SUM, String.valueOf(defaultLong(snapshot.getTotalTokens()))); values.put(FIELD_TOTAL_TOKENS_COUNT, String.valueOf(defaultLong(snapshot.getCallCount()))); values.put(FIELD_TOOL_CALL_COUNT, String.valueOf(defaultLong(snapshot.getToolCallCount()))); values.put(FIELD_TOOL_SUCCESS_COUNT, String.valueOf(defaultLong(snapshot.getToolSuccessCount()))); values.put(FIELD_TOOL_FAILURE_COUNT, String.valueOf(defaultLong(snapshot.getToolFailureCount()))); jedis.hset(key, values); }); } private String buildConfigKey(Long tenantId, String promptCode, Long aiParamId) { return CONFIG_KEY_PREFIX + tenantId + ":" + safeToken(promptCode) + ":" + aiParamToken(aiParamId); } private String buildRuntimeKey(Long tenantId, Long userId, String promptCode, Long sessionId, Long aiParamId) { return RUNTIME_KEY_PREFIX + tenantId + ":" + userId + ":" + safeToken(promptCode) + ":" + sessionToken(sessionId) + ":" + aiParamToken(aiParamId); } private String buildMemoryKey(Long tenantId, Long userId, String promptCode, Long sessionId) { return MEMORY_KEY_PREFIX + tenantId + ":" + userId + ":" + safeToken(promptCode) + ":" + sessionToken(sessionId); } private String buildSessionsKey(Long tenantId, Long userId, String promptCode, String keyword) { return SESSIONS_KEY_PREFIX + tenantId + ":" + userId + ":" + safeToken(promptCode) + ":" + safeToken(keyword); } private String buildMcpPreviewKey(Long tenantId, Long mountId) { return MCP_PREVIEW_KEY_PREFIX + tenantId + ":" + mountId; } private String buildMcpHealthKey(Long tenantId, Long mountId) { return MCP_HEALTH_KEY_PREFIX + tenantId + ":" + mountId; } private String buildStreamStateKey(Long tenantId, String requestId) { return STREAM_STATE_KEY_PREFIX + tenantId + ":" + safeToken(requestId); } private String buildToolResultKey(Long tenantId, String requestId, String toolName, String toolInput) { return TOOL_RESULT_KEY_PREFIX + tenantId + ":" + safeToken(requestId) + ":" + safeToken(toolName) + ":" + digest(toolInput); } private String buildRateLimitKey(Long tenantId, Long userId, String promptCode) { return RATE_LIMIT_KEY_PREFIX + tenantId + ":" + userId + ":" + safeToken(promptCode); } private String buildObserveStatsKey(Long tenantId) { return OBSERVE_STATS_KEY_PREFIX + tenantId; } private String buildToolRankKey(Long tenantId) { return OBSERVE_TOOL_RANK_KEY_PREFIX + tenantId; } private String buildToolFailRankKey(Long tenantId) { return OBSERVE_TOOL_FAIL_RANK_KEY_PREFIX + tenantId; } private String sessionToken(Long sessionId) { return sessionId == null ? "LATEST" : String.valueOf(sessionId); } private String aiParamToken(Long aiParamId) { return aiParamId == null ? "DEFAULT" : String.valueOf(aiParamId); } private String safeToken(String source) { if (!StringUtils.hasText(source)) { return "_"; } return URLEncoder.encode(source.trim(), StandardCharsets.UTF_8); } private String digest(String source) { try { MessageDigest messageDigest = MessageDigest.getInstance("SHA-256"); byte[] bytes = messageDigest.digest((source == null ? "" : source).getBytes(StandardCharsets.UTF_8)); StringBuilder builder = new StringBuilder(); for (byte value : bytes) { builder.append(String.format("%02x", value)); } return builder.toString(); } catch (Exception e) { return safeToken(source); } } private T readJson(String key, Class type) { return readJson(key, value -> objectMapper.readValue(value, type)); } private T readJson(String key, TypeReference typeReference) { return readJson(key, value -> objectMapper.readValue(value, typeReference)); } private T readJson(String key, JsonReader reader) { return execute(jedis -> { String value = jedis.get(key); if (!StringUtils.hasText(value)) { return null; } try { return reader.read(value); } catch (Exception e) { // 反序列化失败时直接删坏缓存,让下次请求自然回源重建。 log.warn("AI redis cache deserialize failed, key={}, message={}", key, e.getMessage()); jedis.del(key); return null; } }); } private void writeJson(String key, Object value, int ttlSeconds) { if (value == null) { delete(key); return; } executeVoid(jedis -> { try { jedis.setex(key, ttlSeconds, objectMapper.writeValueAsString(value)); } catch (Exception e) { log.warn("AI redis cache serialize failed, key={}, message={}", key, e.getMessage()); } }); } private void delete(String key) { executeVoid(jedis -> jedis.del(key)); } private void deleteByPrefix(String prefix) { executeVoid(jedis -> { Set keys = jedis.keys(prefix + "*"); if (keys == null || keys.isEmpty()) { return; } jedis.del(keys.toArray(new String[0])); }); } private long parseLong(String source) { if (!StringUtils.hasText(source)) { return 0L; } try { return Long.parseLong(source); } catch (Exception e) { return 0L; } } private long defaultLong(Long value) { return value == null ? 0L : value; } private T execute(Function action) { Jedis jedis = redisService.getJedis(); if (jedis == null) { return null; } try (jedis) { return action.apply(jedis); } catch (Exception e) { log.warn("AI redis operation skipped, message={}", e.getMessage()); return null; } } private void executeVoid(Consumer action) { Jedis jedis = redisService.getJedis(); if (jedis == null) { return; } try (jedis) { action.accept(jedis); } catch (Exception e) { log.warn("AI redis operation skipped, message={}", e.getMessage()); } } @FunctionalInterface private interface JsonReader { T read(String value) throws Exception; } @Data @Builder @NoArgsConstructor @AllArgsConstructor public static class CachedToolResult { private boolean success; private String output; private String errorMessage; } @Data @Builder @NoArgsConstructor @AllArgsConstructor private static class AiStreamState { private String requestId; private Long tenantId; private Long userId; private Long sessionId; private String promptCode; private String status; private String errorMessage; private Long timestamp; } }