From 88b3f09a702f8f8515af43bc14242ecca2a667db Mon Sep 17 00:00:00 2001
From: zhou zhou <3272660260@qq.com>
Date: 星期四, 19 三月 2026 11:10:45 +0800
Subject: [PATCH] #AI2. 对话执行链路治理

---
 rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java |  256 ++++++++++++++++++++++++++++++++++++++++++++++-----
 1 files changed, 230 insertions(+), 26 deletions(-)

diff --git a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
index 7fc31fc..05dc09b 100644
--- a/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
+++ b/rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
@@ -5,15 +5,19 @@
 import com.vincent.rsf.framework.exception.CoolException;
 import com.vincent.rsf.server.ai.config.AiDefaults;
 import com.vincent.rsf.server.ai.dto.AiChatDoneDto;
+import com.vincent.rsf.server.ai.dto.AiChatErrorDto;
 import com.vincent.rsf.server.ai.dto.AiChatMemoryDto;
 import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
 import com.vincent.rsf.server.ai.dto.AiChatRequest;
 import com.vincent.rsf.server.ai.dto.AiChatRuntimeDto;
+import com.vincent.rsf.server.ai.dto.AiChatStatusDto;
 import com.vincent.rsf.server.ai.dto.AiChatSessionDto;
 import com.vincent.rsf.server.ai.dto.AiResolvedConfig;
 import com.vincent.rsf.server.ai.entity.AiParam;
 import com.vincent.rsf.server.ai.entity.AiPrompt;
 import com.vincent.rsf.server.ai.entity.AiChatSession;
+import com.vincent.rsf.server.ai.enums.AiErrorCategory;
+import com.vincent.rsf.server.ai.exception.AiChatException;
 import com.vincent.rsf.server.ai.service.AiChatService;
 import com.vincent.rsf.server.ai.service.AiChatMemoryService;
 import com.vincent.rsf.server.ai.service.AiConfigResolverService;
@@ -41,6 +45,7 @@
 import org.springframework.context.support.GenericApplicationContext;
 import org.springframework.http.MediaType;
 import org.springframework.http.client.SimpleClientHttpRequestFactory;
+import org.springframework.beans.factory.annotation.Qualifier;
 import org.springframework.stereotype.Service;
 import org.springframework.util.StringUtils;
 import org.springframework.web.client.RestClient;
@@ -49,6 +54,7 @@
 import reactor.core.publisher.Flux;
 
 import java.io.IOException;
+import java.time.Instant;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.LinkedHashMap;
@@ -56,6 +62,7 @@
 import java.util.Map;
 import java.util.Objects;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executor;
 import java.util.concurrent.atomic.AtomicReference;
 
 @Slf4j
@@ -69,12 +76,15 @@
     private final GenericApplicationContext applicationContext;
     private final ObservationRegistry observationRegistry;
     private final ObjectMapper objectMapper;
+    @Qualifier("aiChatTaskExecutor")
+    private final Executor aiChatTaskExecutor;
 
     @Override
     public AiChatRuntimeDto getRuntime(String promptCode, Long sessionId, Long userId, Long tenantId) {
-        AiResolvedConfig config = aiConfigResolverService.resolve(promptCode);
+        AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId);
         AiChatMemoryDto memory = aiChatMemoryService.getMemory(userId, tenantId, config.getPromptCode(), sessionId);
         return AiChatRuntimeDto.builder()
+                .requestId(null)
                 .sessionId(memory.getSessionId())
                 .promptCode(config.getPromptCode())
                 .promptName(config.getPrompt().getName())
@@ -89,7 +99,7 @@
 
     @Override
     public List<AiChatSessionDto> listSessions(String promptCode, Long userId, Long tenantId) {
-        AiResolvedConfig config = aiConfigResolverService.resolve(promptCode);
+        AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId);
         return aiChatMemoryService.listSessions(userId, tenantId, config.getPromptCode());
     }
 
@@ -101,18 +111,28 @@
     @Override
     public SseEmitter stream(AiChatRequest request, Long userId, Long tenantId) {
         SseEmitter emitter = new SseEmitter(AiDefaults.SSE_TIMEOUT_MS);
-        CompletableFuture.runAsync(() -> doStream(request, userId, tenantId, emitter));
+        CompletableFuture.runAsync(() -> doStream(request, userId, tenantId, emitter), aiChatTaskExecutor);
         return emitter;
     }
 
     private void doStream(AiChatRequest request, Long userId, Long tenantId, SseEmitter emitter) {
+        String requestId = request.getRequestId();
+        long startedAt = System.currentTimeMillis();
+        AtomicReference<Long> firstTokenAtRef = new AtomicReference<>();
+        Long sessionId = request.getSessionId();
+        String model = null;
         try {
-            AiResolvedConfig config = aiConfigResolverService.resolve(request.getPromptCode());
-            AiChatSession session = aiChatMemoryService.resolveSession(userId, tenantId, config.getPromptCode(), request.getSessionId(), resolveTitleSeed(request.getMessages()));
-            AiChatMemoryDto memory = aiChatMemoryService.getMemory(userId, tenantId, config.getPromptCode(), session.getId());
+            ensureIdentity(userId, tenantId);
+            AiResolvedConfig config = resolveConfig(request, tenantId);
+            final String resolvedModel = config.getAiParam().getModel();
+            model = resolvedModel;
+            AiChatSession session = resolveSession(request, userId, tenantId, config.getPromptCode());
+            sessionId = session.getId();
+            AiChatMemoryDto memory = loadMemory(userId, tenantId, config.getPromptCode(), session.getId());
             List<AiChatMessageDto> mergedMessages = mergeMessages(memory.getPersistedMessages(), request.getMessages());
-            try (McpMountRuntimeFactory.McpMountRuntime runtime = mcpMountRuntimeFactory.create(config.getMcpMounts(), userId)) {
-                emit(emitter, "start", AiChatRuntimeDto.builder()
+            try (McpMountRuntimeFactory.McpMountRuntime runtime = createRuntime(config, userId)) {
+                emitStrict(emitter, "start", AiChatRuntimeDto.builder()
+                        .requestId(requestId)
                         .sessionId(session.getId())
                         .promptCode(config.getPromptCode())
                         .promptName(config.getPrompt().getName())
@@ -123,6 +143,16 @@
                         .mountErrors(runtime.getErrors())
                         .persistedMessages(memory.getPersistedMessages())
                         .build());
+                emitSafely(emitter, "status", AiChatStatusDto.builder()
+                        .requestId(requestId)
+                        .sessionId(session.getId())
+                        .status("STARTED")
+                        .model(resolvedModel)
+                        .timestamp(Instant.now().toEpochMilli())
+                        .elapsedMs(0L)
+                        .build());
+                log.info("AI chat started, requestId={}, userId={}, tenantId={}, sessionId={}, model={}",
+                        requestId, userId, tenantId, session.getId(), resolvedModel);
 
                 Prompt prompt = new Prompt(
                         buildPromptMessages(mergedMessages, config.getPrompt(), request.getMetadata()),
@@ -130,39 +160,172 @@
                 );
                 OpenAiChatModel chatModel = createChatModel(config.getAiParam());
                 if (Boolean.FALSE.equals(config.getAiParam().getStreamingEnabled())) {
-                    ChatResponse response = chatModel.call(prompt);
+                    ChatResponse response = invokeChatCall(chatModel, prompt);
                     String content = extractContent(response);
                     aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), content);
                     if (StringUtils.hasText(content)) {
-                        emit(emitter, "delta", buildMessagePayload("content", content));
+                        markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt);
+                        emitStrict(emitter, "delta", buildMessagePayload("requestId", requestId, "content", content));
                     }
-                    emitDone(emitter, response.getMetadata(), config.getAiParam().getModel(), session.getId());
+                    emitDone(emitter, requestId, response.getMetadata(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get());
+                    emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get()));
+                    log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}",
+                            requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()));
                     emitter.complete();
                     return;
                 }
 
-                Flux<ChatResponse> responseFlux = chatModel.stream(prompt);
+                Flux<ChatResponse> responseFlux = invokeChatStream(chatModel, prompt);
                 AtomicReference<ChatResponseMetadata> lastMetadata = new AtomicReference<>();
                 StringBuilder assistantContent = new StringBuilder();
-                responseFlux.doOnNext(response -> {
+                try {
+                    responseFlux.doOnNext(response -> {
                             lastMetadata.set(response.getMetadata());
                             String content = extractContent(response);
                             if (StringUtils.hasText(content)) {
+                                markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt);
                                 assistantContent.append(content);
-                                emit(emitter, "delta", buildMessagePayload("content", content));
+                                emitStrict(emitter, "delta", buildMessagePayload("requestId", requestId, "content", content));
                             }
                         })
-                        .doOnError(error -> emit(emitter, "error", buildMessagePayload("message", error == null ? "AI 对话失败" : error.getMessage())))
                         .blockLast();
+                } catch (Exception e) {
+                    throw buildAiException("AI_MODEL_STREAM_ERROR", AiErrorCategory.MODEL, "MODEL_STREAM",
+                            e == null ? "AI 模型流式调用失败" : e.getMessage(), e);
+                }
                 aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), assistantContent.toString());
-                emitDone(emitter, lastMetadata.get(), config.getAiParam().getModel(), session.getId());
+                emitDone(emitter, requestId, lastMetadata.get(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get());
+                emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get()));
+                log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}",
+                        requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()));
                 emitter.complete();
             }
+        } catch (AiChatException e) {
+            handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e);
         } catch (Exception e) {
-            log.error("AI stream error", e);
-            emit(emitter, "error", buildMessagePayload("message", e == null ? "AI 对话失败" : e.getMessage()));
-            emitter.completeWithError(e);
+            handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(),
+                    buildAiException("AI_INTERNAL_ERROR", AiErrorCategory.INTERNAL, "INTERNAL",
+                            e == null ? "AI 对话失败" : e.getMessage(), e));
+        } finally {
+            log.debug("AI chat stream finished, requestId={}", requestId);
         }
+    }
+
+    private void ensureIdentity(Long userId, Long tenantId) {
+        if (userId == null) {
+            throw buildAiException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "AUTH_VALIDATE", "当前登录用户不存在", null);
+        }
+        if (tenantId == null) {
+            throw buildAiException("AI_AUTH_TENANT_MISSING", AiErrorCategory.AUTH, "AUTH_VALIDATE", "当前租户不存在", null);
+        }
+    }
+
+    private AiResolvedConfig resolveConfig(AiChatRequest request, Long tenantId) {
+        try {
+            return aiConfigResolverService.resolve(request.getPromptCode(), tenantId);
+        } catch (Exception e) {
+            throw buildAiException("AI_CONFIG_RESOLVE_ERROR", AiErrorCategory.CONFIG, "CONFIG_RESOLVE",
+                    e == null ? "AI 配置解析失败" : e.getMessage(), e);
+        }
+    }
+
+    private AiChatSession resolveSession(AiChatRequest request, Long userId, Long tenantId, String promptCode) {
+        try {
+            return aiChatMemoryService.resolveSession(userId, tenantId, promptCode, request.getSessionId(), resolveTitleSeed(request.getMessages()));
+        } catch (Exception e) {
+            throw buildAiException("AI_SESSION_RESOLVE_ERROR", AiErrorCategory.REQUEST, "SESSION_RESOLVE",
+                    e == null ? "AI 会话解析失败" : e.getMessage(), e);
+        }
+    }
+
+    private AiChatMemoryDto loadMemory(Long userId, Long tenantId, String promptCode, Long sessionId) {
+        try {
+            return aiChatMemoryService.getMemory(userId, tenantId, promptCode, sessionId);
+        } catch (Exception e) {
+            throw buildAiException("AI_MEMORY_LOAD_ERROR", AiErrorCategory.REQUEST, "MEMORY_LOAD",
+                    e == null ? "AI 会话记忆加载失败" : e.getMessage(), e);
+        }
+    }
+
+    private McpMountRuntimeFactory.McpMountRuntime createRuntime(AiResolvedConfig config, Long userId) {
+        try {
+            return mcpMountRuntimeFactory.create(config.getMcpMounts(), userId);
+        } catch (Exception e) {
+            throw buildAiException("AI_MCP_MOUNT_ERROR", AiErrorCategory.MCP, "MCP_MOUNT",
+                    e == null ? "MCP 挂载失败" : e.getMessage(), e);
+        }
+    }
+
+    private ChatResponse invokeChatCall(OpenAiChatModel chatModel, Prompt prompt) {
+        try {
+            return chatModel.call(prompt);
+        } catch (Exception e) {
+            throw buildAiException("AI_MODEL_CALL_ERROR", AiErrorCategory.MODEL, "MODEL_CALL",
+                    e == null ? "AI 模型调用失败" : e.getMessage(), e);
+        }
+    }
+
+    private Flux<ChatResponse> invokeChatStream(OpenAiChatModel chatModel, Prompt prompt) {
+        try {
+            return chatModel.stream(prompt);
+        } catch (Exception e) {
+            throw buildAiException("AI_MODEL_STREAM_ERROR", AiErrorCategory.MODEL, "MODEL_STREAM_INIT",
+                    e == null ? "AI 模型流式调用失败" : e.getMessage(), e);
+        }
+    }
+
+    private void markFirstToken(AtomicReference<Long> firstTokenAtRef, SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt) {
+        if (!firstTokenAtRef.compareAndSet(null, System.currentTimeMillis())) {
+            return;
+        }
+        emitSafely(emitter, "status", AiChatStatusDto.builder()
+                .requestId(requestId)
+                .sessionId(sessionId)
+                .status("FIRST_TOKEN")
+                .model(model)
+                .timestamp(Instant.now().toEpochMilli())
+                .elapsedMs(System.currentTimeMillis() - startedAt)
+                .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()))
+                .build());
+    }
+
+    private AiChatStatusDto buildTerminalStatus(String requestId, Long sessionId, String status, String model, long startedAt, Long firstTokenAt) {
+        return AiChatStatusDto.builder()
+                .requestId(requestId)
+                .sessionId(sessionId)
+                .status(status)
+                .model(model)
+                .timestamp(Instant.now().toEpochMilli())
+                .elapsedMs(System.currentTimeMillis() - startedAt)
+                .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAt))
+                .build();
+    }
+
+    private Long resolveFirstTokenLatency(long startedAt, Long firstTokenAt) {
+        return firstTokenAt == null ? null : Math.max(0L, firstTokenAt - startedAt);
+    }
+
+    private void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt, Long firstTokenAt, AiChatException exception) {
+        if (isClientAbortException(exception)) {
+            log.warn("AI chat aborted by client, requestId={}, sessionId={}, stage={}, message={}",
+                    requestId, sessionId, exception.getStage(), exception.getMessage());
+            emitSafely(emitter, "status", buildTerminalStatus(requestId, sessionId, "ABORTED", model, startedAt, firstTokenAt));
+            emitter.completeWithError(exception);
+            return;
+        }
+        log.error("AI chat failed, requestId={}, sessionId={}, category={}, stage={}, message={}",
+                requestId, sessionId, exception.getCategory(), exception.getStage(), exception.getMessage(), exception);
+        emitSafely(emitter, "status", buildTerminalStatus(requestId, sessionId, "FAILED", model, startedAt, firstTokenAt));
+        emitSafely(emitter, "error", AiChatErrorDto.builder()
+                .requestId(requestId)
+                .sessionId(sessionId)
+                .code(exception.getCode())
+                .category(exception.getCategory().name())
+                .stage(exception.getStage())
+                .message(exception.getMessage())
+                .timestamp(Instant.now().toEpochMilli())
+                .build());
+        emitter.completeWithError(exception);
     }
 
     private OpenAiChatModel createChatModel(AiParam aiParam) {
@@ -203,7 +366,7 @@
 
     private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Map<String, Object> metadata) {
         if (userId == null) {
-            throw new CoolException("当前登录用户不存在");
+            throw buildAiException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "OPTIONS_BUILD", "当前登录用户不存在", null);
         }
         OpenAiChatOptions.Builder builder = OpenAiChatOptions.builder()
                 .model(aiParam.getModel())
@@ -319,31 +482,72 @@
         return response.getResult().getOutput().getText();
     }
 
-    private void emitDone(SseEmitter emitter, ChatResponseMetadata metadata, String fallbackModel, Long sessionId) {
+    private void emitDone(SseEmitter emitter, String requestId, ChatResponseMetadata metadata, String fallbackModel, Long sessionId, long startedAt, Long firstTokenAt) {
         Usage usage = metadata == null ? null : metadata.getUsage();
-        emit(emitter, "done", AiChatDoneDto.builder()
+        emitStrict(emitter, "done", AiChatDoneDto.builder()
+                .requestId(requestId)
                 .sessionId(sessionId)
                 .model(metadata != null && StringUtils.hasText(metadata.getModel()) ? metadata.getModel() : fallbackModel)
+                .elapsedMs(System.currentTimeMillis() - startedAt)
+                .firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAt))
                 .promptTokens(usage == null ? null : usage.getPromptTokens())
                 .completionTokens(usage == null ? null : usage.getCompletionTokens())
                 .totalTokens(usage == null ? null : usage.getTotalTokens())
                 .build());
     }
 
-    private Map<String, String> buildMessagePayload(String key, String value) {
+    private Map<String, String> buildMessagePayload(String... keyValues) {
         Map<String, String> payload = new LinkedHashMap<>();
-        payload.put(key, value == null ? "" : value);
+        if (keyValues == null || keyValues.length == 0) {
+            return payload;
+        }
+        if (keyValues.length % 2 != 0) {
+            throw new CoolException("消息载荷参数必须成对出现");
+        }
+        for (int i = 0; i < keyValues.length; i += 2) {
+            payload.put(keyValues[i], keyValues[i + 1] == null ? "" : keyValues[i + 1]);
+        }
         return payload;
     }
 
-    private void emit(SseEmitter emitter, String eventName, Object payload) {
+    private void emitStrict(SseEmitter emitter, String eventName, Object payload) {
         try {
             String data = objectMapper.writeValueAsString(payload);
             emitter.send(SseEmitter.event()
                     .name(eventName)
                     .data(data, MediaType.APPLICATION_JSON));
         } catch (IOException e) {
-            throw new CoolException("SSE 输出失败: " + e.getMessage());
+            throw buildAiException("AI_SSE_EMIT_ERROR", AiErrorCategory.STREAM, "SSE_EMIT", "SSE 输出失败: " + e.getMessage(), e);
         }
     }
+
+    private void emitSafely(SseEmitter emitter, String eventName, Object payload) {
+        try {
+            emitStrict(emitter, eventName, payload);
+        } catch (Exception e) {
+            log.warn("AI SSE event emit skipped, eventName={}, message={}", eventName, e.getMessage());
+        }
+    }
+
+    private AiChatException buildAiException(String code, AiErrorCategory category, String stage, String message, Throwable cause) {
+        return new AiChatException(code, category, stage, message, cause);
+    }
+
+    private boolean isClientAbortException(Throwable throwable) {
+        Throwable current = throwable;
+        while (current != null) {
+            String message = current.getMessage();
+            if (message != null) {
+                String normalized = message.toLowerCase();
+                if (normalized.contains("broken pipe")
+                        || normalized.contains("connection reset")
+                        || normalized.contains("forcibly closed")
+                        || normalized.contains("abort")) {
+                    return true;
+                }
+            }
+            current = current.getCause();
+        }
+        return false;
+    }
 }

--
Gitblit v1.9.1