zhou zhou
12 小时以前 82624affb0251b75b62b35567d3eb260c06efe78
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
@@ -1,96 +1,81 @@
package com.vincent.rsf.server.ai.service.impl;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.vincent.rsf.framework.common.Cools;
import com.vincent.rsf.framework.exception.CoolException;
import com.vincent.rsf.server.ai.config.AiDefaults;
import com.vincent.rsf.server.ai.dto.AiChatDoneDto;
import com.vincent.rsf.server.ai.dto.AiChatMemoryDto;
import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
import com.vincent.rsf.server.ai.dto.AiChatModelOptionDto;
import com.vincent.rsf.server.ai.dto.AiChatRequest;
import com.vincent.rsf.server.ai.dto.AiChatRuntimeDto;
import com.vincent.rsf.server.ai.dto.AiChatSessionDto;
import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest;
import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest;
import com.vincent.rsf.server.ai.dto.AiResolvedConfig;
import com.vincent.rsf.server.ai.entity.AiParam;
import com.vincent.rsf.server.ai.entity.AiPrompt;
import com.vincent.rsf.server.ai.entity.AiChatSession;
import com.vincent.rsf.server.ai.service.AiChatService;
import com.vincent.rsf.server.ai.service.AiChatMemoryService;
import com.vincent.rsf.server.ai.service.AiChatService;
import com.vincent.rsf.server.ai.service.AiConfigResolverService;
import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
import io.micrometer.observation.ObservationRegistry;
import com.vincent.rsf.server.ai.service.AiParamService;
import com.vincent.rsf.server.ai.service.impl.chat.AiChatOrchestrator;
import com.vincent.rsf.server.ai.service.impl.chat.AiChatRuntimeAssembler;
import com.vincent.rsf.server.ai.store.AiConversationCacheStore;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.tool.DefaultToolCallingManager;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver;
import org.springframework.ai.util.json.schema.SchemaType;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.http.MediaType;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import reactor.core.publisher.Flux;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.Executor;
@Slf4j
@Service
@RequiredArgsConstructor
public class AiChatServiceImpl implements AiChatService {
    // Collaborators; all final, so they are injected through the Lombok-generated @RequiredArgsConstructor.
    private final AiConfigResolverService aiConfigResolverService;
    private final AiChatMemoryService aiChatMemoryService;
    // The next four are referenced in this view only by the legacy helpers doStream / createChatModel / emit.
    private final McpMountRuntimeFactory mcpMountRuntimeFactory;
    private final GenericApplicationContext applicationContext;
    private final ObservationRegistry observationRegistry;
    private final ObjectMapper objectMapper;
    private final AiParamService aiParamService;
    private final AiConversationCacheStore aiConversationCacheStore;
    private final AiChatRuntimeAssembler aiChatRuntimeAssembler;
    private final AiChatOrchestrator aiChatOrchestrator;
    // Executor that stream(...) hands the long-running SSE work to.
    // NOTE(review): Lombok copies @Qualifier onto the generated constructor parameter only when lombok.config
    // lists it under lombok.copyableAnnotations; otherwise Spring falls back to matching the parameter name
    // "aiChatTaskExecutor" against bean names. Confirm the intended executor bean is the one injected.
    @Qualifier("aiChatTaskExecutor")
    private final Executor aiChatTaskExecutor;
    @Override
    public AiChatRuntimeDto getRuntime(String promptCode, Long sessionId, Long userId, Long tenantId) {
        AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId);
    public AiChatRuntimeDto getRuntime(String promptCode, Long sessionId, Long aiParamId, Long userId, Long tenantId) {
        AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId, aiParamId);
        AiChatRuntimeDto cached = aiConversationCacheStore.getRuntime(tenantId, userId, config.getPromptCode(), sessionId, aiParamId);
        if (cached != null) {
            return cached;
        }
        AiChatMemoryDto memory = aiChatMemoryService.getMemory(userId, tenantId, config.getPromptCode(), sessionId);
        return AiChatRuntimeDto.builder()
                .sessionId(memory.getSessionId())
                .promptCode(config.getPromptCode())
                .promptName(config.getPrompt().getName())
                .model(config.getAiParam().getModel())
                .configuredMcpCount(config.getMcpMounts().size())
                .mountedMcpCount(config.getMcpMounts().size())
                .mountedMcpNames(config.getMcpMounts().stream().map(item -> item.getName()).toList())
                .mountErrors(List.of())
                .persistedMessages(memory.getPersistedMessages())
                .build();
        List<AiChatModelOptionDto> modelOptions = aiParamService.listChatModelOptions(tenantId);
        AiChatRuntimeDto runtime = aiChatRuntimeAssembler.buildRuntimeSnapshot(
                null,
                memory.getSessionId(),
                config,
                modelOptions,
                config.getMcpMounts().size(),
                config.getMcpMounts().stream().map(item -> item.getName()).toList(),
                List.of(),
                memory
        );
        aiConversationCacheStore.cacheRuntime(tenantId, userId, config.getPromptCode(), sessionId, aiParamId, runtime);
        if (memory.getSessionId() != null && !Objects.equals(memory.getSessionId(), sessionId)) {
            aiConversationCacheStore.cacheRuntime(tenantId, userId, config.getPromptCode(), memory.getSessionId(), aiParamId, runtime);
        }
        return runtime;
    }
    @Override
    public List<AiChatSessionDto> listSessions(String promptCode, Long userId, Long tenantId) {
    public List<AiChatSessionDto> listSessions(String promptCode, String keyword, Long userId, Long tenantId) {
        AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId);
        return aiChatMemoryService.listSessions(userId, tenantId, config.getPromptCode());
        return aiChatMemoryService.listSessions(userId, tenantId, config.getPromptCode(), keyword);
    }
    @Override
    public SseEmitter stream(AiChatRequest request, Long userId, Long tenantId) {
        SseEmitter emitter = new SseEmitter(AiDefaults.SSE_TIMEOUT_MS);
        CompletableFuture.runAsync(() -> aiChatOrchestrator.executeStream(request, userId, tenantId, emitter), aiChatTaskExecutor);
        return emitter;
    }
    @Override
@@ -99,251 +84,22 @@
    }
    @Override
    public SseEmitter stream(AiChatRequest request, Long userId, Long tenantId) {
        SseEmitter emitter = new SseEmitter(AiDefaults.SSE_TIMEOUT_MS);
        CompletableFuture.runAsync(() -> doStream(request, userId, tenantId, emitter));
        return emitter;
    public AiChatSessionDto renameSession(Long sessionId, AiChatSessionRenameRequest request, Long userId, Long tenantId) {
        return aiChatMemoryService.renameSession(userId, tenantId, sessionId, request);
    }
    private void doStream(AiChatRequest request, Long userId, Long tenantId, SseEmitter emitter) {
        try {
            AiResolvedConfig config = aiConfigResolverService.resolve(request.getPromptCode(), tenantId);
            AiChatSession session = aiChatMemoryService.resolveSession(userId, tenantId, config.getPromptCode(), request.getSessionId(), resolveTitleSeed(request.getMessages()));
            AiChatMemoryDto memory = aiChatMemoryService.getMemory(userId, tenantId, config.getPromptCode(), session.getId());
            List<AiChatMessageDto> mergedMessages = mergeMessages(memory.getPersistedMessages(), request.getMessages());
            try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(config.getMcpMounts(), userId)) {
                emit(emitter, "start", AiChatRuntimeDto.builder()
                        .sessionId(session.getId())
                        .promptCode(config.getPromptCode())
                        .promptName(config.getPrompt().getName())
                        .model(config.getAiParam().getModel())
                        .configuredMcpCount(config.getMcpMounts().size())
                        .mountedMcpCount(runtime.getMountedCount())
                        .mountedMcpNames(runtime.getMountedNames())
                        .mountErrors(runtime.getErrors())
                        .persistedMessages(memory.getPersistedMessages())
                        .build());
                Prompt prompt = new Prompt(
                        buildPromptMessages(mergedMessages, config.getPrompt(), request.getMetadata()),
                        buildChatOptions(config.getAiParam(), runtime.getToolCallbacks(), userId, request.getMetadata())
                );
                OpenAiChatModel chatModel = createChatModel(config.getAiParam());
                if (Boolean.FALSE.equals(config.getAiParam().getStreamingEnabled())) {
                    ChatResponse response = chatModel.call(prompt);
                    String content = extractContent(response);
                    aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), content);
                    if (StringUtils.hasText(content)) {
                        emit(emitter, "delta", buildMessagePayload("content", content));
                    }
                    emitDone(emitter, response.getMetadata(), config.getAiParam().getModel(), session.getId());
                    emitter.complete();
                    return;
                }
                Flux<ChatResponse> responseFlux = chatModel.stream(prompt);
                AtomicReference<ChatResponseMetadata> lastMetadata = new AtomicReference<>();
                StringBuilder assistantContent = new StringBuilder();
                responseFlux.doOnNext(response -> {
                            lastMetadata.set(response.getMetadata());
                            String content = extractContent(response);
                            if (StringUtils.hasText(content)) {
                                assistantContent.append(content);
                                emit(emitter, "delta", buildMessagePayload("content", content));
                            }
                        })
                        .doOnError(error -> emit(emitter, "error", buildMessagePayload("message", error == null ? "AI 对话失败" : error.getMessage())))
                        .blockLast();
                aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), assistantContent.toString());
                emitDone(emitter, lastMetadata.get(), config.getAiParam().getModel(), session.getId());
                emitter.complete();
            }
        } catch (Exception e) {
            log.error("AI stream error", e);
            emit(emitter, "error", buildMessagePayload("message", e == null ? "AI 对话失败" : e.getMessage()));
            emitter.completeWithError(e);
        }
    @Override
    public AiChatSessionDto pinSession(Long sessionId, AiChatSessionPinRequest request, Long userId, Long tenantId) {
        return aiChatMemoryService.pinSession(userId, tenantId, sessionId, request);
    }
    private OpenAiChatModel createChatModel(AiParam aiParam) {
        OpenAiApi openAiApi = buildOpenAiApi(aiParam);
        ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder()
                .observationRegistry(observationRegistry)
                .toolCallbackResolver(new SpringBeanToolCallbackResolver(applicationContext, SchemaType.OPEN_API_SCHEMA))
                .toolExecutionExceptionProcessor(new DefaultToolExecutionExceptionProcessor(false))
                .build();
        return new OpenAiChatModel(
                openAiApi,
                OpenAiChatOptions.builder()
                        .model(aiParam.getModel())
                        .temperature(aiParam.getTemperature())
                        .topP(aiParam.getTopP())
                        .maxTokens(aiParam.getMaxTokens())
                        .streamUsage(true)
                        .build(),
                toolCallingManager,
                org.springframework.retry.support.RetryTemplate.builder().maxAttempts(1).build(),
                observationRegistry
        );
    @Override
    public void clearSessionMemory(Long sessionId, Long userId, Long tenantId) {
        aiChatMemoryService.clearSessionMemory(userId, tenantId, sessionId);
    }
    private OpenAiApi buildOpenAiApi(AiParam aiParam) {
        int timeoutMs = aiParam.getTimeoutMs() == null ? AiDefaults.DEFAULT_TIMEOUT_MS : aiParam.getTimeoutMs();
        SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
        requestFactory.setConnectTimeout(timeoutMs);
        requestFactory.setReadTimeout(timeoutMs);
        return OpenAiApi.builder()
                .baseUrl(aiParam.getBaseUrl())
                .apiKey(aiParam.getApiKey())
                .restClientBuilder(RestClient.builder().requestFactory(requestFactory))
                .webClientBuilder(WebClient.builder())
                .build();
    }
    private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Map<String, Object> metadata) {
        if (userId == null) {
            throw new CoolException("当前登录用户不存在");
        }
        OpenAiChatOptions.Builder builder = OpenAiChatOptions.builder()
                .model(aiParam.getModel())
                .temperature(aiParam.getTemperature())
                .topP(aiParam.getTopP())
                .maxTokens(aiParam.getMaxTokens())
                .streamUsage(true)
                .user(String.valueOf(userId));
        if (!Cools.isEmpty(toolCallbacks)) {
            builder.toolCallbacks(Arrays.asList(toolCallbacks));
        }
        Map<String, String> metadataMap = new LinkedHashMap<>();
        if (metadata != null) {
            metadata.forEach((key, value) -> metadataMap.put(key, value == null ? "" : String.valueOf(value)));
        }
        if (!metadataMap.isEmpty()) {
            builder.metadata(metadataMap);
        }
        return builder.build();
    }
    private List<Message> buildPromptMessages(List<AiChatMessageDto> sourceMessages, AiPrompt aiPrompt, Map<String, Object> metadata) {
        if (Cools.isEmpty(sourceMessages)) {
            throw new CoolException("对话消息不能为空");
        }
        List<Message> messages = new ArrayList<>();
        if (StringUtils.hasText(aiPrompt.getSystemPrompt())) {
            messages.add(new SystemMessage(aiPrompt.getSystemPrompt()));
        }
        int lastUserIndex = -1;
        for (int i = 0; i < sourceMessages.size(); i++) {
            AiChatMessageDto item = sourceMessages.get(i);
            if (item != null && "user".equalsIgnoreCase(item.getRole())) {
                lastUserIndex = i;
            }
        }
        for (int i = 0; i < sourceMessages.size(); i++) {
            AiChatMessageDto item = sourceMessages.get(i);
            if (item == null || !StringUtils.hasText(item.getContent())) {
                continue;
            }
            String role = item.getRole() == null ? "user" : item.getRole().toLowerCase();
            if ("system".equals(role)) {
                continue;
            }
            String content = item.getContent();
            if ("user".equals(role) && i == lastUserIndex) {
                content = renderUserPrompt(aiPrompt.getUserPromptTemplate(), content, metadata);
            }
            if ("assistant".equals(role)) {
                messages.add(new AssistantMessage(content));
            } else {
                messages.add(new UserMessage(content));
            }
        }
        if (messages.stream().noneMatch(item -> item instanceof UserMessage)) {
            throw new CoolException("至少需要一条用户消息");
        }
        return messages;
    }
    private List<AiChatMessageDto> mergeMessages(List<AiChatMessageDto> persistedMessages, List<AiChatMessageDto> memoryMessages) {
        List<AiChatMessageDto> merged = new ArrayList<>();
        if (!Cools.isEmpty(persistedMessages)) {
            merged.addAll(persistedMessages);
        }
        if (!Cools.isEmpty(memoryMessages)) {
            merged.addAll(memoryMessages);
        }
        if (merged.isEmpty()) {
            throw new CoolException("对话消息不能为空");
        }
        return merged;
    }
    private String resolveTitleSeed(List<AiChatMessageDto> messages) {
        if (Cools.isEmpty(messages)) {
            throw new CoolException("对话消息不能为空");
        }
        for (int i = messages.size() - 1; i >= 0; i--) {
            AiChatMessageDto item = messages.get(i);
            if (item != null && "user".equalsIgnoreCase(item.getRole()) && StringUtils.hasText(item.getContent())) {
                return item.getContent();
            }
        }
        throw new CoolException("至少需要一条用户消息");
    }
    private String renderUserPrompt(String userPromptTemplate, String content, Map<String, Object> metadata) {
        if (!StringUtils.hasText(userPromptTemplate)) {
            return content;
        }
        String rendered = userPromptTemplate
                .replace("{{input}}", content)
                .replace("{input}", content);
        if (metadata != null) {
            for (Map.Entry<String, Object> entry : metadata.entrySet()) {
                String value = entry.getValue() == null ? "" : String.valueOf(entry.getValue());
                rendered = rendered.replace("{{" + entry.getKey() + "}}", value);
                rendered = rendered.replace("{" + entry.getKey() + "}", value);
            }
        }
        if (Objects.equals(rendered, userPromptTemplate)) {
            return userPromptTemplate + "\n\n" + content;
        }
        return rendered;
    }
    private String extractContent(ChatResponse response) {
        if (response == null || response.getResult() == null || response.getResult().getOutput() == null) {
            return null;
        }
        return response.getResult().getOutput().getText();
    }
    private void emitDone(SseEmitter emitter, ChatResponseMetadata metadata, String fallbackModel, Long sessionId) {
        Usage usage = metadata == null ? null : metadata.getUsage();
        emit(emitter, "done", AiChatDoneDto.builder()
                .sessionId(sessionId)
                .model(metadata != null && StringUtils.hasText(metadata.getModel()) ? metadata.getModel() : fallbackModel)
                .promptTokens(usage == null ? null : usage.getPromptTokens())
                .completionTokens(usage == null ? null : usage.getCompletionTokens())
                .totalTokens(usage == null ? null : usage.getTotalTokens())
                .build());
    }
    private Map<String, String> buildMessagePayload(String key, String value) {
        Map<String, String> payload = new LinkedHashMap<>();
        payload.put(key, value == null ? "" : value);
        return payload;
    }
    private void emit(SseEmitter emitter, String eventName, Object payload) {
        try {
            String data = objectMapper.writeValueAsString(payload);
            emitter.send(SseEmitter.event()
                    .name(eventName)
                    .data(data, MediaType.APPLICATION_JSON));
        } catch (IOException e) {
            throw new CoolException("SSE 输出失败: " + e.getMessage());
        }
    @Override
    public void retainLatestRound(Long sessionId, Long userId, Long tenantId) {
        aiChatMemoryService.retainLatestRound(userId, tenantId, sessionId);
    }
}