zhou zhou
12 小时以前 3d81df739dc45599c257d8cdefe0996f66ccdeae
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
@@ -12,6 +12,9 @@
import com.vincent.rsf.server.ai.dto.AiChatRuntimeDto;
import com.vincent.rsf.server.ai.dto.AiChatStatusDto;
import com.vincent.rsf.server.ai.dto.AiChatSessionDto;
import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest;
import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest;
import com.vincent.rsf.server.ai.dto.AiChatToolEventDto;
import com.vincent.rsf.server.ai.dto.AiResolvedConfig;
import com.vincent.rsf.server.ai.entity.AiParam;
import com.vincent.rsf.server.ai.entity.AiPrompt;
@@ -32,6 +35,7 @@
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.tool.DefaultToolCallingManager;
import org.springframework.ai.model.tool.ToolCallingManager;
@@ -64,6 +68,7 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.atomic.AtomicLong;
@Slf4j
@Service
@@ -93,19 +98,42 @@
                .mountedMcpCount(config.getMcpMounts().size())
                .mountedMcpNames(config.getMcpMounts().stream().map(item -> item.getName()).toList())
                .mountErrors(List.of())
                .memorySummary(memory.getMemorySummary())
                .memoryFacts(memory.getMemoryFacts())
                .recentMessageCount(memory.getRecentMessageCount())
                .persistedMessages(memory.getPersistedMessages())
                .build();
    }
    @Override
    public List<AiChatSessionDto> listSessions(String promptCode, Long userId, Long tenantId) {
    public List<AiChatSessionDto> listSessions(String promptCode, String keyword, Long userId, Long tenantId) {
        AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId);
        return aiChatMemoryService.listSessions(userId, tenantId, config.getPromptCode());
        return aiChatMemoryService.listSessions(userId, tenantId, config.getPromptCode(), keyword);
    }
    @Override
    public void removeSession(Long sessionId, Long userId, Long tenantId) {
        aiChatMemoryService.removeSession(userId, tenantId, sessionId);
    }
    @Override
    public AiChatSessionDto renameSession(Long sessionId, AiChatSessionRenameRequest request, Long userId, Long tenantId) {
        return aiChatMemoryService.renameSession(userId, tenantId, sessionId, request);
    }
    @Override
    public AiChatSessionDto pinSession(Long sessionId, AiChatSessionPinRequest request, Long userId, Long tenantId) {
        return aiChatMemoryService.pinSession(userId, tenantId, sessionId, request);
    }
    @Override
    public void clearSessionMemory(Long sessionId, Long userId, Long tenantId) {
        aiChatMemoryService.clearSessionMemory(userId, tenantId, sessionId);
    }
    @Override
    public void retainLatestRound(Long sessionId, Long userId, Long tenantId) {
        aiChatMemoryService.retainLatestRound(userId, tenantId, sessionId);
    }
    @Override
@@ -119,6 +147,7 @@
        String requestId = request.getRequestId();
        long startedAt = System.currentTimeMillis();
        AtomicReference<Long> firstTokenAtRef = new AtomicReference<>();
        AtomicLong toolCallSequence = new AtomicLong(0);
        Long sessionId = request.getSessionId();
        String model = null;
        try {
@@ -129,7 +158,7 @@
            AiChatSession session = resolveSession(request, userId, tenantId, config.getPromptCode());
            sessionId = session.getId();
            AiChatMemoryDto memory = loadMemory(userId, tenantId, config.getPromptCode(), session.getId());
            List<AiChatMessageDto> mergedMessages = mergeMessages(memory.getPersistedMessages(), request.getMessages());
            List<AiChatMessageDto> mergedMessages = mergeMessages(memory.getShortMemoryMessages(), request.getMessages());
            try (McpMountRuntimeFactory.McpMountRuntime runtime = createRuntime(config, userId)) {
                emitStrict(emitter, "start", AiChatRuntimeDto.builder()
                        .requestId(requestId)
@@ -141,6 +170,9 @@
                        .mountedMcpCount(runtime.getMountedCount())
                        .mountedMcpNames(runtime.getMountedNames())
                        .mountErrors(runtime.getErrors())
                        .memorySummary(memory.getMemorySummary())
                        .memoryFacts(memory.getMemoryFacts())
                        .recentMessageCount(memory.getRecentMessageCount())
                        .persistedMessages(memory.getPersistedMessages())
                        .build());
                emitSafely(emitter, "status", AiChatStatusDto.builder()
@@ -154,9 +186,13 @@
                log.info("AI chat started, requestId={}, userId={}, tenantId={}, sessionId={}, model={}",
                        requestId, userId, tenantId, session.getId(), resolvedModel);
                ToolCallback[] observableToolCallbacks = wrapToolCallbacks(
                        runtime.getToolCallbacks(), emitter, requestId, session.getId(), toolCallSequence
                );
                Prompt prompt = new Prompt(
                        buildPromptMessages(mergedMessages, config.getPrompt(), request.getMetadata()),
                        buildChatOptions(config.getAiParam(), runtime.getToolCallbacks(), userId, request.getMetadata())
                        buildPromptMessages(memory, mergedMessages, config.getPrompt(), request.getMetadata()),
                        buildChatOptions(config.getAiParam(), observableToolCallbacks, userId, tenantId,
                                requestId, session.getId(), request.getMetadata())
                );
                OpenAiChatModel chatModel = createChatModel(config.getAiParam());
                if (Boolean.FALSE.equals(config.getAiParam().getStreamingEnabled())) {
@@ -364,7 +400,8 @@
                .build();
    }
    private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Map<String, Object> metadata) {
    private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Long tenantId,
                                               String requestId, Long sessionId, Map<String, Object> metadata) {
        if (userId == null) {
            throw buildAiException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "OPTIONS_BUILD", "当前登录用户不存在", null);
        }
@@ -376,25 +413,56 @@
                .streamUsage(true)
                .user(String.valueOf(userId));
        if (!Cools.isEmpty(toolCallbacks)) {
            builder.toolCallbacks(Arrays.asList(toolCallbacks));
            builder.toolCallbacks(Arrays.stream(toolCallbacks).toList());
        }
        Map<String, Object> toolContext = new LinkedHashMap<>();
        toolContext.put("userId", userId);
        toolContext.put("tenantId", tenantId);
        toolContext.put("requestId", requestId);
        toolContext.put("sessionId", sessionId);
        Map<String, String> metadataMap = new LinkedHashMap<>();
        if (metadata != null) {
            metadata.forEach((key, value) -> metadataMap.put(key, value == null ? "" : String.valueOf(value)));
            metadata.forEach((key, value) -> {
                String normalized = value == null ? "" : String.valueOf(value);
                metadataMap.put(key, normalized);
                toolContext.put(key, normalized);
            });
        }
        builder.toolContext(toolContext);
        if (!metadataMap.isEmpty()) {
            builder.metadata(metadataMap);
        }
        return builder.build();
    }
    private List<Message> buildPromptMessages(List<AiChatMessageDto> sourceMessages, AiPrompt aiPrompt, Map<String, Object> metadata) {
    private ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, SseEmitter emitter, String requestId,
                                             Long sessionId, AtomicLong toolCallSequence) {
        if (Cools.isEmpty(toolCallbacks)) {
            return toolCallbacks;
        }
        List<ToolCallback> wrappedCallbacks = new ArrayList<>();
        for (ToolCallback callback : toolCallbacks) {
            if (callback == null) {
                continue;
            }
            wrappedCallbacks.add(new ObservableToolCallback(callback, emitter, requestId, sessionId, toolCallSequence));
        }
        return wrappedCallbacks.toArray(new ToolCallback[0]);
    }
    private List<Message> buildPromptMessages(AiChatMemoryDto memory, List<AiChatMessageDto> sourceMessages, AiPrompt aiPrompt, Map<String, Object> metadata) {
        if (Cools.isEmpty(sourceMessages)) {
            throw new CoolException("对话消息不能为空");
        }
        List<Message> messages = new ArrayList<>();
        if (StringUtils.hasText(aiPrompt.getSystemPrompt())) {
            messages.add(new SystemMessage(aiPrompt.getSystemPrompt()));
        }
        if (memory != null && StringUtils.hasText(memory.getMemorySummary())) {
            messages.add(new SystemMessage("历史摘要:\n" + memory.getMemorySummary()));
        }
        if (memory != null && StringUtils.hasText(memory.getMemoryFacts())) {
            messages.add(new SystemMessage("关键事实:\n" + memory.getMemoryFacts()));
        }
        int lastUserIndex = -1;
        for (int i = 0; i < sourceMessages.size(); i++) {
@@ -482,6 +550,17 @@
        return response.getResult().getOutput().getText();
    }
    private String summarizeToolPayload(String content, int maxLength) {
        if (!StringUtils.hasText(content)) {
            return null;
        }
        String normalized = content.trim()
                .replace("\r", " ")
                .replace("\n", " ")
                .replaceAll("\\s+", " ");
        return normalized.length() > maxLength ? normalized.substring(0, maxLength) : normalized;
    }
    private void emitDone(SseEmitter emitter, String requestId, ChatResponseMetadata metadata, String fallbackModel, Long sessionId, long startedAt, Long firstTokenAt) {
        Usage usage = metadata == null ? null : metadata.getUsage();
        emitStrict(emitter, "done", AiChatDoneDto.builder()
@@ -550,4 +629,81 @@
        }
        return false;
    }
    private class ObservableToolCallback implements ToolCallback {
        private final ToolCallback delegate;
        private final SseEmitter emitter;
        private final String requestId;
        private final Long sessionId;
        private final AtomicLong toolCallSequence;
        private ObservableToolCallback(ToolCallback delegate, SseEmitter emitter, String requestId,
                                       Long sessionId, AtomicLong toolCallSequence) {
            this.delegate = delegate;
            this.emitter = emitter;
            this.requestId = requestId;
            this.sessionId = sessionId;
            this.toolCallSequence = toolCallSequence;
        }
        @Override
        public org.springframework.ai.tool.definition.ToolDefinition getToolDefinition() {
            return delegate.getToolDefinition();
        }
        @Override
        public org.springframework.ai.tool.metadata.ToolMetadata getToolMetadata() {
            return delegate.getToolMetadata();
        }
        @Override
        public String call(String toolInput) {
            return call(toolInput, null);
        }
        @Override
        public String call(String toolInput, ToolContext toolContext) {
            String toolName = delegate.getToolDefinition() == null ? "unknown" : delegate.getToolDefinition().name();
            String toolCallId = requestId + "-tool-" + toolCallSequence.incrementAndGet();
            long startedAt = System.currentTimeMillis();
            emitSafely(emitter, "tool_start", AiChatToolEventDto.builder()
                    .requestId(requestId)
                    .sessionId(sessionId)
                    .toolCallId(toolCallId)
                    .toolName(toolName)
                    .status("STARTED")
                    .inputSummary(summarizeToolPayload(toolInput, 400))
                    .timestamp(startedAt)
                    .build());
            try {
                String output = toolContext == null ? delegate.call(toolInput) : delegate.call(toolInput, toolContext);
                emitSafely(emitter, "tool_result", AiChatToolEventDto.builder()
                        .requestId(requestId)
                        .sessionId(sessionId)
                        .toolCallId(toolCallId)
                        .toolName(toolName)
                        .status("COMPLETED")
                        .inputSummary(summarizeToolPayload(toolInput, 400))
                        .outputSummary(summarizeToolPayload(output, 600))
                        .durationMs(System.currentTimeMillis() - startedAt)
                        .timestamp(System.currentTimeMillis())
                        .build());
                return output;
            } catch (RuntimeException e) {
                emitSafely(emitter, "tool_error", AiChatToolEventDto.builder()
                        .requestId(requestId)
                        .sessionId(sessionId)
                        .toolCallId(toolCallId)
                        .toolName(toolName)
                        .status("FAILED")
                        .inputSummary(summarizeToolPayload(toolInput, 400))
                        .errorMessage(e.getMessage())
                        .durationMs(System.currentTimeMillis() - startedAt)
                        .timestamp(System.currentTimeMillis())
                        .build());
                throw e;
            }
        }
    }
}