| | |
| | | import com.vincent.rsf.framework.exception.CoolException; |
| | | import com.vincent.rsf.server.ai.config.AiDefaults; |
| | | import com.vincent.rsf.server.ai.dto.AiChatDoneDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatErrorDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatMemoryDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatMessageDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatRequest; |
| | | import com.vincent.rsf.server.ai.dto.AiChatRuntimeDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatStatusDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatSessionDto; |
| | | import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest; |
| | | import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest; |
| | | import com.vincent.rsf.server.ai.dto.AiResolvedConfig; |
| | | import com.vincent.rsf.server.ai.entity.AiParam; |
| | | import com.vincent.rsf.server.ai.entity.AiPrompt; |
| | | import com.vincent.rsf.server.ai.entity.AiChatSession; |
| | | import com.vincent.rsf.server.ai.enums.AiErrorCategory; |
| | | import com.vincent.rsf.server.ai.exception.AiChatException; |
| | | import com.vincent.rsf.server.ai.service.AiChatService; |
| | | import com.vincent.rsf.server.ai.service.AiChatMemoryService; |
| | | import com.vincent.rsf.server.ai.service.AiConfigResolverService; |
| | |
| | | import org.springframework.context.support.GenericApplicationContext; |
| | | import org.springframework.http.MediaType; |
| | | import org.springframework.http.client.SimpleClientHttpRequestFactory; |
| | | import org.springframework.beans.factory.annotation.Qualifier; |
| | | import org.springframework.stereotype.Service; |
| | | import org.springframework.util.StringUtils; |
| | | import org.springframework.web.client.RestClient; |
| | |
| | | import reactor.core.publisher.Flux; |
| | | |
import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
| | | |
| | | @Slf4j |
| | |
| | | private final GenericApplicationContext applicationContext; |
| | | private final ObservationRegistry observationRegistry; |
| | | private final ObjectMapper objectMapper; |
| | | @Qualifier("aiChatTaskExecutor") |
| | | private final Executor aiChatTaskExecutor; |
| | | |
| | | @Override |
| | | public AiChatRuntimeDto getRuntime(String promptCode, Long sessionId, Long userId, Long tenantId) { |
| | | AiResolvedConfig config = aiConfigResolverService.resolve(promptCode); |
| | | AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId); |
| | | AiChatMemoryDto memory = aiChatMemoryService.getMemory(userId, tenantId, config.getPromptCode(), sessionId); |
| | | return AiChatRuntimeDto.builder() |
| | | .requestId(null) |
| | | .sessionId(memory.getSessionId()) |
| | | .promptCode(config.getPromptCode()) |
| | | .promptName(config.getPrompt().getName()) |
| | |
| | | } |
| | | |
| | | @Override |
| | | public List<AiChatSessionDto> listSessions(String promptCode, Long userId, Long tenantId) { |
| | | AiResolvedConfig config = aiConfigResolverService.resolve(promptCode); |
| | | return aiChatMemoryService.listSessions(userId, tenantId, config.getPromptCode()); |
| | | public List<AiChatSessionDto> listSessions(String promptCode, String keyword, Long userId, Long tenantId) { |
| | | AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId); |
| | | return aiChatMemoryService.listSessions(userId, tenantId, config.getPromptCode(), keyword); |
| | | } |
| | | |
| | | @Override |
| | |
| | | } |
| | | |
| | | @Override |
| | | public AiChatSessionDto renameSession(Long sessionId, AiChatSessionRenameRequest request, Long userId, Long tenantId) { |
| | | return aiChatMemoryService.renameSession(userId, tenantId, sessionId, request); |
| | | } |
| | | |
| | | @Override |
| | | public AiChatSessionDto pinSession(Long sessionId, AiChatSessionPinRequest request, Long userId, Long tenantId) { |
| | | return aiChatMemoryService.pinSession(userId, tenantId, sessionId, request); |
| | | } |
| | | |
| | | @Override |
| | | public SseEmitter stream(AiChatRequest request, Long userId, Long tenantId) { |
| | | SseEmitter emitter = new SseEmitter(AiDefaults.SSE_TIMEOUT_MS); |
| | | CompletableFuture.runAsync(() -> doStream(request, userId, tenantId, emitter)); |
| | | CompletableFuture.runAsync(() -> doStream(request, userId, tenantId, emitter), aiChatTaskExecutor); |
| | | return emitter; |
| | | } |
| | | |
| | | private void doStream(AiChatRequest request, Long userId, Long tenantId, SseEmitter emitter) { |
| | | String requestId = request.getRequestId(); |
| | | long startedAt = System.currentTimeMillis(); |
| | | AtomicReference<Long> firstTokenAtRef = new AtomicReference<>(); |
| | | Long sessionId = request.getSessionId(); |
| | | String model = null; |
| | | try { |
| | | AiResolvedConfig config = aiConfigResolverService.resolve(request.getPromptCode()); |
| | | AiChatSession session = aiChatMemoryService.resolveSession(userId, tenantId, config.getPromptCode(), request.getSessionId(), resolveTitleSeed(request.getMessages())); |
| | | AiChatMemoryDto memory = aiChatMemoryService.getMemory(userId, tenantId, config.getPromptCode(), session.getId()); |
| | | ensureIdentity(userId, tenantId); |
| | | AiResolvedConfig config = resolveConfig(request, tenantId); |
| | | final String resolvedModel = config.getAiParam().getModel(); |
| | | model = resolvedModel; |
| | | AiChatSession session = resolveSession(request, userId, tenantId, config.getPromptCode()); |
| | | sessionId = session.getId(); |
| | | AiChatMemoryDto memory = loadMemory(userId, tenantId, config.getPromptCode(), session.getId()); |
| | | List<AiChatMessageDto> mergedMessages = mergeMessages(memory.getPersistedMessages(), request.getMessages()); |
| | | try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(config.getMcpMounts(), userId)) { |
| | | emit(emitter, "start", AiChatRuntimeDto.builder() |
| | | try (McpMountRuntimeFactory.McpMountRuntime runtime = createRuntime(config, userId)) { |
| | | emitStrict(emitter, "start", AiChatRuntimeDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(session.getId()) |
| | | .promptCode(config.getPromptCode()) |
| | | .promptName(config.getPrompt().getName()) |
| | |
| | | .mountErrors(runtime.getErrors()) |
| | | .persistedMessages(memory.getPersistedMessages()) |
| | | .build()); |
| | | emitSafely(emitter, "status", AiChatStatusDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(session.getId()) |
| | | .status("STARTED") |
| | | .model(resolvedModel) |
| | | .timestamp(Instant.now().toEpochMilli()) |
| | | .elapsedMs(0L) |
| | | .build()); |
| | | log.info("AI chat started, requestId={}, userId={}, tenantId={}, sessionId={}, model={}", |
| | | requestId, userId, tenantId, session.getId(), resolvedModel); |
| | | |
| | | Prompt prompt = new Prompt( |
| | | buildPromptMessages(mergedMessages, config.getPrompt(), request.getMetadata()), |
| | |
| | | ); |
| | | OpenAiChatModel chatModel = createChatModel(config.getAiParam()); |
| | | if (Boolean.FALSE.equals(config.getAiParam().getStreamingEnabled())) { |
| | | ChatResponse response = chatModel.call(prompt); |
| | | ChatResponse response = invokeChatCall(chatModel, prompt); |
| | | String content = extractContent(response); |
| | | aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), content); |
| | | if (StringUtils.hasText(content)) { |
| | | emit(emitter, "delta", buildMessagePayload("content", content)); |
| | | markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt); |
| | | emitStrict(emitter, "delta", buildMessagePayload("requestId", requestId, "content", content)); |
| | | } |
| | | emitDone(emitter, response.getMetadata(), config.getAiParam().getModel(), session.getId()); |
| | | emitDone(emitter, requestId, response.getMetadata(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get()); |
| | | emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get())); |
| | | log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}", |
| | | requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get())); |
| | | emitter.complete(); |
| | | return; |
| | | } |
| | | |
| | | Flux<ChatResponse> responseFlux = chatModel.stream(prompt); |
| | | Flux<ChatResponse> responseFlux = invokeChatStream(chatModel, prompt); |
| | | AtomicReference<ChatResponseMetadata> lastMetadata = new AtomicReference<>(); |
| | | StringBuilder assistantContent = new StringBuilder(); |
| | | responseFlux.doOnNext(response -> { |
| | | try { |
| | | responseFlux.doOnNext(response -> { |
| | | lastMetadata.set(response.getMetadata()); |
| | | String content = extractContent(response); |
| | | if (StringUtils.hasText(content)) { |
| | | markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt); |
| | | assistantContent.append(content); |
| | | emit(emitter, "delta", buildMessagePayload("content", content)); |
| | | emitStrict(emitter, "delta", buildMessagePayload("requestId", requestId, "content", content)); |
| | | } |
| | | }) |
| | | .doOnError(error -> emit(emitter, "error", buildMessagePayload("message", error == null ? "AI 对话失败" : error.getMessage()))) |
| | | .blockLast(); |
| | | } catch (Exception e) { |
| | | throw buildAiException("AI_MODEL_STREAM_ERROR", AiErrorCategory.MODEL, "MODEL_STREAM", |
| | | e == null ? "AI 模型流式调用失败" : e.getMessage(), e); |
| | | } |
| | | aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), assistantContent.toString()); |
| | | emitDone(emitter, lastMetadata.get(), config.getAiParam().getModel(), session.getId()); |
| | | emitDone(emitter, requestId, lastMetadata.get(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get()); |
| | | emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get())); |
| | | log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}", |
| | | requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get())); |
| | | emitter.complete(); |
| | | } |
| | | } catch (AiChatException e) { |
| | | handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e); |
| | | } catch (Exception e) { |
| | | log.error("AI stream error", e); |
| | | emit(emitter, "error", buildMessagePayload("message", e == null ? "AI 对话失败" : e.getMessage())); |
| | | emitter.completeWithError(e); |
| | | handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), |
| | | buildAiException("AI_INTERNAL_ERROR", AiErrorCategory.INTERNAL, "INTERNAL", |
| | | e == null ? "AI 对话失败" : e.getMessage(), e)); |
| | | } finally { |
| | | log.debug("AI chat stream finished, requestId={}", requestId); |
| | | } |
| | | } |
| | | |
| | | private void ensureIdentity(Long userId, Long tenantId) { |
| | | if (userId == null) { |
| | | throw buildAiException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "AUTH_VALIDATE", "当前登录用户不存在", null); |
| | | } |
| | | if (tenantId == null) { |
| | | throw buildAiException("AI_AUTH_TENANT_MISSING", AiErrorCategory.AUTH, "AUTH_VALIDATE", "当前租户不存在", null); |
| | | } |
| | | } |
| | | |
| | | private AiResolvedConfig resolveConfig(AiChatRequest request, Long tenantId) { |
| | | try { |
| | | return aiConfigResolverService.resolve(request.getPromptCode(), tenantId); |
| | | } catch (Exception e) { |
| | | throw buildAiException("AI_CONFIG_RESOLVE_ERROR", AiErrorCategory.CONFIG, "CONFIG_RESOLVE", |
| | | e == null ? "AI 配置解析失败" : e.getMessage(), e); |
| | | } |
| | | } |
| | | |
| | | private AiChatSession resolveSession(AiChatRequest request, Long userId, Long tenantId, String promptCode) { |
| | | try { |
| | | return aiChatMemoryService.resolveSession(userId, tenantId, promptCode, request.getSessionId(), resolveTitleSeed(request.getMessages())); |
| | | } catch (Exception e) { |
| | | throw buildAiException("AI_SESSION_RESOLVE_ERROR", AiErrorCategory.REQUEST, "SESSION_RESOLVE", |
| | | e == null ? "AI 会话解析失败" : e.getMessage(), e); |
| | | } |
| | | } |
| | | |
| | | private AiChatMemoryDto loadMemory(Long userId, Long tenantId, String promptCode, Long sessionId) { |
| | | try { |
| | | return aiChatMemoryService.getMemory(userId, tenantId, promptCode, sessionId); |
| | | } catch (Exception e) { |
| | | throw buildAiException("AI_MEMORY_LOAD_ERROR", AiErrorCategory.REQUEST, "MEMORY_LOAD", |
| | | e == null ? "AI 会话记忆加载失败" : e.getMessage(), e); |
| | | } |
| | | } |
| | | |
| | | private McpMountRuntimeFactory.McpMountRuntime createRuntime(AiResolvedConfig config, Long userId) { |
| | | try { |
| | | return mcpMountRuntimeFactory.create(config.getMcpMounts(), userId); |
| | | } catch (Exception e) { |
| | | throw buildAiException("AI_MCP_MOUNT_ERROR", AiErrorCategory.MCP, "MCP_MOUNT", |
| | | e == null ? "MCP 挂载失败" : e.getMessage(), e); |
| | | } |
| | | } |
| | | |
| | | private ChatResponse invokeChatCall(OpenAiChatModel chatModel, Prompt prompt) { |
| | | try { |
| | | return chatModel.call(prompt); |
| | | } catch (Exception e) { |
| | | throw buildAiException("AI_MODEL_CALL_ERROR", AiErrorCategory.MODEL, "MODEL_CALL", |
| | | e == null ? "AI 模型调用失败" : e.getMessage(), e); |
| | | } |
| | | } |
| | | |
| | | private Flux<ChatResponse> invokeChatStream(OpenAiChatModel chatModel, Prompt prompt) { |
| | | try { |
| | | return chatModel.stream(prompt); |
| | | } catch (Exception e) { |
| | | throw buildAiException("AI_MODEL_STREAM_ERROR", AiErrorCategory.MODEL, "MODEL_STREAM_INIT", |
| | | e == null ? "AI 模型流式调用失败" : e.getMessage(), e); |
| | | } |
| | | } |
| | | |
| | | private void markFirstToken(AtomicReference<Long> firstTokenAtRef, SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt) { |
| | | if (!firstTokenAtRef.compareAndSet(null, System.currentTimeMillis())) { |
| | | return; |
| | | } |
| | | emitSafely(emitter, "status", AiChatStatusDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .status("FIRST_TOKEN") |
| | | .model(model) |
| | | .timestamp(Instant.now().toEpochMilli()) |
| | | .elapsedMs(System.currentTimeMillis() - startedAt) |
| | | .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAtRef.get())) |
| | | .build()); |
| | | } |
| | | |
| | | private AiChatStatusDto buildTerminalStatus(String requestId, Long sessionId, String status, String model, long startedAt, Long firstTokenAt) { |
| | | return AiChatStatusDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .status(status) |
| | | .model(model) |
| | | .timestamp(Instant.now().toEpochMilli()) |
| | | .elapsedMs(System.currentTimeMillis() - startedAt) |
| | | .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAt)) |
| | | .build(); |
| | | } |
| | | |
| | | private Long resolveFirstTokenLatency(long startedAt, Long firstTokenAt) { |
| | | return firstTokenAt == null ? null : Math.max(0L, firstTokenAt - startedAt); |
| | | } |
| | | |
| | | private void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt, Long firstTokenAt, AiChatException exception) { |
| | | if (isClientAbortException(exception)) { |
| | | log.warn("AI chat aborted by client, requestId={}, sessionId={}, stage={}, message={}", |
| | | requestId, sessionId, exception.getStage(), exception.getMessage()); |
| | | emitSafely(emitter, "status", buildTerminalStatus(requestId, sessionId, "ABORTED", model, startedAt, firstTokenAt)); |
| | | emitter.completeWithError(exception); |
| | | return; |
| | | } |
| | | log.error("AI chat failed, requestId={}, sessionId={}, category={}, stage={}, message={}", |
| | | requestId, sessionId, exception.getCategory(), exception.getStage(), exception.getMessage(), exception); |
| | | emitSafely(emitter, "status", buildTerminalStatus(requestId, sessionId, "FAILED", model, startedAt, firstTokenAt)); |
| | | emitSafely(emitter, "error", AiChatErrorDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .code(exception.getCode()) |
| | | .category(exception.getCategory().name()) |
| | | .stage(exception.getStage()) |
| | | .message(exception.getMessage()) |
| | | .timestamp(Instant.now().toEpochMilli()) |
| | | .build()); |
| | | emitter.completeWithError(exception); |
| | | } |
| | | |
| | | private OpenAiChatModel createChatModel(AiParam aiParam) { |
| | |
| | | |
| | | private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Map<String, Object> metadata) { |
| | | if (userId == null) { |
| | | throw new CoolException("当前登录用户不存在"); |
| | | throw buildAiException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "OPTIONS_BUILD", "当前登录用户不存在", null); |
| | | } |
| | | OpenAiChatOptions.Builder builder = OpenAiChatOptions.builder() |
| | | .model(aiParam.getModel()) |
| | |
| | | return response.getResult().getOutput().getText(); |
| | | } |
| | | |
| | | private void emitDone(SseEmitter emitter, ChatResponseMetadata metadata, String fallbackModel, Long sessionId) { |
| | | private void emitDone(SseEmitter emitter, String requestId, ChatResponseMetadata metadata, String fallbackModel, Long sessionId, long startedAt, Long firstTokenAt) { |
| | | Usage usage = metadata == null ? null : metadata.getUsage(); |
| | | emit(emitter, "done", AiChatDoneDto.builder() |
| | | emitStrict(emitter, "done", AiChatDoneDto.builder() |
| | | .requestId(requestId) |
| | | .sessionId(sessionId) |
| | | .model(metadata != null && StringUtils.hasText(metadata.getModel()) ? metadata.getModel() : fallbackModel) |
| | | .elapsedMs(System.currentTimeMillis() - startedAt) |
| | | .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAt)) |
| | | .promptTokens(usage == null ? null : usage.getPromptTokens()) |
| | | .completionTokens(usage == null ? null : usage.getCompletionTokens()) |
| | | .totalTokens(usage == null ? null : usage.getTotalTokens()) |
| | | .build()); |
| | | } |
| | | |
| | | private Map<String, String> buildMessagePayload(String key, String value) { |
| | | private Map<String, String> buildMessagePayload(String... keyValues) { |
| | | Map<String, String> payload = new LinkedHashMap<>(); |
| | | payload.put(key, value == null ? "" : value); |
| | | if (keyValues == null || keyValues.length == 0) { |
| | | return payload; |
| | | } |
| | | if (keyValues.length % 2 != 0) { |
| | | throw new CoolException("消息载荷参数必须成对出现"); |
| | | } |
| | | for (int i = 0; i < keyValues.length; i += 2) { |
| | | payload.put(keyValues[i], keyValues[i + 1] == null ? "" : keyValues[i + 1]); |
| | | } |
| | | return payload; |
| | | } |
| | | |
| | | private void emit(SseEmitter emitter, String eventName, Object payload) { |
| | | private void emitStrict(SseEmitter emitter, String eventName, Object payload) { |
| | | try { |
| | | String data = objectMapper.writeValueAsString(payload); |
| | | emitter.send(SseEmitter.event() |
| | | .name(eventName) |
| | | .data(data, MediaType.APPLICATION_JSON)); |
| | | } catch (IOException e) { |
| | | throw new CoolException("SSE 输出失败: " + e.getMessage()); |
| | | throw buildAiException("AI_SSE_EMIT_ERROR", AiErrorCategory.STREAM, "SSE_EMIT", "SSE 输出失败: " + e.getMessage(), e); |
| | | } |
| | | } |
| | | |
| | | private void emitSafely(SseEmitter emitter, String eventName, Object payload) { |
| | | try { |
| | | emitStrict(emitter, eventName, payload); |
| | | } catch (Exception e) { |
| | | log.warn("AI SSE event emit skipped, eventName={}, message={}", eventName, e.getMessage()); |
| | | } |
| | | } |
| | | |
| | | private AiChatException buildAiException(String code, AiErrorCategory category, String stage, String message, Throwable cause) { |
| | | return new AiChatException(code, category, stage, message, cause); |
| | | } |
| | | |
| | | private boolean isClientAbortException(Throwable throwable) { |
| | | Throwable current = throwable; |
| | | while (current != null) { |
| | | String message = current.getMessage(); |
| | | if (message != null) { |
| | | String normalized = message.toLowerCase(); |
| | | if (normalized.contains("broken pipe") |
| | | || normalized.contains("connection reset") |
| | | || normalized.contains("forcibly closed") |
| | | || normalized.contains("abort")) { |
| | | return true; |
| | | } |
| | | } |
| | | current = current.getCause(); |
| | | } |
| | | return false; |
| | | } |
| | | } |