package com.vincent.rsf.server.ai.service.impl.chat;
|
|
import com.vincent.rsf.server.ai.dto.AiChatMemoryDto;
|
import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
|
import com.vincent.rsf.server.ai.dto.AiChatModelOptionDto;
|
import com.vincent.rsf.server.ai.dto.AiChatRequest;
|
import com.vincent.rsf.server.ai.dto.AiChatStatusDto;
|
import com.vincent.rsf.server.ai.dto.AiResolvedConfig;
|
import com.vincent.rsf.server.ai.entity.AiCallLog;
|
import com.vincent.rsf.server.ai.entity.AiChatSession;
|
import com.vincent.rsf.server.ai.enums.AiErrorCategory;
|
import com.vincent.rsf.server.ai.exception.AiChatException;
|
import com.vincent.rsf.server.ai.service.AiCallLogService;
|
import com.vincent.rsf.server.ai.service.AiChatMemoryService;
|
import com.vincent.rsf.server.ai.service.AiConfigResolverService;
|
import com.vincent.rsf.server.ai.service.AiParamService;
|
import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
|
import com.vincent.rsf.server.ai.store.AiChatRateLimiter;
|
import com.vincent.rsf.server.ai.store.AiStreamStateStore;
|
import lombok.RequiredArgsConstructor;
|
import lombok.extern.slf4j.Slf4j;
|
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
import org.springframework.ai.chat.model.ChatResponse;
|
import org.springframework.ai.chat.prompt.Prompt;
|
import org.springframework.ai.openai.OpenAiChatModel;
|
import org.springframework.ai.tool.ToolCallback;
|
import org.springframework.stereotype.Component;
|
import org.springframework.util.StringUtils;
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
import reactor.core.publisher.Flux;
|
|
import java.time.Instant;
|
import java.util.List;
|
import java.util.concurrent.atomic.AtomicLong;
|
import java.util.concurrent.atomic.AtomicReference;
|
|
/**
 * Orchestrates a single AI chat round delivered over SSE: configuration
 * resolution, rate limiting, session/memory handling, MCP tool mounting,
 * model invocation (blocking or streaming), call-log accounting and
 * terminal status emission. All failures are funneled through
 * {@link AiChatFailureHandler} so the emitter always receives a terminal event.
 */
@Slf4j
@Component
@RequiredArgsConstructor
public class AiChatOrchestrator {

    // Collaborators injected via the Lombok-generated constructor.
    private final AiConfigResolverService aiConfigResolverService; // resolves prompt/model/MCP config per tenant
    private final AiChatMemoryService aiChatMemoryService;         // session resolution + short/long memory persistence
    private final AiParamService aiParamService;                   // model option listing per tenant
    private final McpMountRuntimeFactory mcpMountRuntimeFactory;   // builds the AutoCloseable MCP tool runtime
    private final AiCallLogService aiCallLogService;               // start/complete call-log records
    private final AiChatRateLimiter aiChatRateLimiter;             // per tenant/user/prompt request throttle
    private final AiStreamStateStore aiStreamStateStore;           // RUNNING/COMPLETED stream state bookkeeping
    private final AiChatRuntimeAssembler aiChatRuntimeAssembler;   // builds the "start" runtime snapshot payload
    private final AiPromptMessageBuilder aiPromptMessageBuilder;   // merges memory + request messages into a prompt
    private final AiOpenAiChatModelFactory aiOpenAiChatModelFactory; // creates the OpenAI chat model + options
    private final AiToolObservationService aiToolObservationService; // wraps tool callbacks for SSE observability
    private final AiSseEventPublisher aiSseEventPublisher;         // strict/safe SSE event emission helpers
    private final AiChatFailureHandler aiChatFailureHandler;       // exception construction + stream failure handling

    /**
     * Runs one end-to-end chat round over SSE.
     *
     * <p>Flow: validate identity → resolve config → rate-limit check →
     * resolve session + memory → start call log → mount MCP tools →
     * emit "start"/"status" → invoke the model (blocking call when
     * streamingEnabled is FALSE, otherwise a reactive stream) → persist the
     * round → emit "done" + terminal status → complete the call log and
     * stream state → complete the emitter. Any exception is converted into a
     * terminal failure on the emitter by {@link AiChatFailureHandler}.
     *
     * @param request  chat payload (requestId, messages, promptCode, sessionId, metadata)
     * @param userId   authenticated user id; must be non-null (validated below)
     * @param tenantId tenant id; must be non-null (validated below)
     * @param emitter  SSE channel all events are published to
     */
    public void executeStream(AiChatRequest request, Long userId, Long tenantId, SseEmitter emitter) {
        String requestId = request.getRequestId();
        long startedAt = System.currentTimeMillis();
        // Set once when the first non-blank token arrives; feeds first-token latency metrics.
        AtomicReference<Long> firstTokenAtRef = new AtomicReference<>();
        // Shared mutable counters handed to the tool-observation wrapper.
        AtomicLong toolCallSequence = new AtomicLong(0);
        AtomicLong toolSuccessCount = new AtomicLong(0);
        AtomicLong toolFailureCount = new AtomicLong(0);
        // Captured up front so the catch blocks can report something meaningful
        // even when resolution below throws before these are (re)assigned.
        Long sessionId = request.getSessionId();
        Long callLogId = null;
        String model = null;
        String resolvedPromptCode = request.getPromptCode();
        AiThinkingTraceEmitter thinkingTraceEmitter = null;
        try {
            ensureIdentity(userId, tenantId);
            AiResolvedConfig config = resolveConfig(request, tenantId);
            List<AiChatModelOptionDto> modelOptions = aiParamService.listChatModelOptions(tenantId);
            resolvedPromptCode = config.getPromptCode();
            if (!aiChatRateLimiter.allowChatRequest(tenantId, userId, config.getPromptCode())) {
                throw aiChatFailureHandler.buildAiException("AI_RATE_LIMITED", AiErrorCategory.REQUEST, "RATE_LIMIT",
                        "当前提问过于频繁,请稍后再试", null);
            }
            final String resolvedModel = config.getAiParam().getModel();
            model = resolvedModel;
            AiChatSession session = resolveSession(request, userId, tenantId, config.getPromptCode());
            sessionId = session.getId();
            aiStreamStateStore.markStreamState(requestId, tenantId, userId, sessionId, config.getPromptCode(), "RUNNING", null);
            AiChatMemoryDto memory = loadMemory(userId, tenantId, config.getPromptCode(), session.getId());
            List<AiChatMessageDto> mergedMessages = aiPromptMessageBuilder.mergeMessages(memory.getShortMemoryMessages(), request.getMessages());
            // NOTE(review): the mount count is passed twice — presumably "requested"
            // vs "actually mounted" — but both arguments use config.getMcpMounts().size();
            // the runtime's real mounted count is only known later. Confirm intended.
            AiCallLog callLog = aiCallLogService.startCallLog(
                    requestId,
                    session.getId(),
                    userId,
                    tenantId,
                    config.getPromptCode(),
                    config.getPrompt().getName(),
                    config.getAiParam().getModel(),
                    config.getMcpMounts().size(),
                    config.getMcpMounts().size(),
                    config.getMcpMounts().stream().map(item -> item.getName()).toList()
            );
            callLogId = callLog.getId();
            // The MCP runtime is AutoCloseable: tool connections are torn down on exit.
            try (McpMountRuntimeFactory.McpMountRuntime runtime = createRuntime(config, userId)) {
                // "start" is emitted strictly — if the client is gone there is no point continuing.
                aiSseEventPublisher.emitStrict(emitter, "start", aiChatRuntimeAssembler.buildRuntimeSnapshot(
                        requestId,
                        session.getId(),
                        config,
                        modelOptions,
                        runtime.getMountedCount(),
                        runtime.getMountedNames(),
                        runtime.getErrors(),
                        memory
                ));
                // Status updates are best-effort ("safely"): a failed status emit must not abort the chat.
                aiSseEventPublisher.emitSafely(emitter, "status", AiChatStatusDto.builder()
                        .requestId(requestId)
                        .sessionId(session.getId())
                        .status("STARTED")
                        .model(resolvedModel)
                        .timestamp(Instant.now().toEpochMilli())
                        .elapsedMs(0L)
                        .build());
                log.info("AI chat started, requestId={}, userId={}, tenantId={}, sessionId={}, model={}",
                        requestId, userId, tenantId, session.getId(), resolvedModel);
                thinkingTraceEmitter = new AiThinkingTraceEmitter(aiSseEventPublisher, emitter, requestId, session.getId());
                thinkingTraceEmitter.startAnalyze();
                // Effectively-final alias required so lambdas below can capture it.
                AiThinkingTraceEmitter activeThinkingTraceEmitter = thinkingTraceEmitter;

                // Wrap tool callbacks so each tool invocation is observed (SSE events,
                // success/failure counters, call-log linkage, thinking trace).
                ToolCallback[] observableToolCallbacks = aiToolObservationService.wrapToolCallbacks(
                        runtime.getToolCallbacks(), emitter, requestId, session.getId(), toolCallSequence,
                        toolSuccessCount, toolFailureCount, callLogId, userId, tenantId, activeThinkingTraceEmitter
                );
                Prompt prompt = new Prompt(
                        aiPromptMessageBuilder.buildPromptMessages(memory, mergedMessages, config.getPrompt(), request.getMetadata()),
                        aiOpenAiChatModelFactory.buildChatOptions(config.getAiParam(), observableToolCallbacks, userId, tenantId,
                                requestId, session.getId(), request.getMetadata())
                );
                OpenAiChatModel chatModel = aiOpenAiChatModelFactory.createChatModel(config.getAiParam());
                // --- Non-streaming branch: single blocking call, full content in one "delta". ---
                if (Boolean.FALSE.equals(config.getAiParam().getStreamingEnabled())) {
                    ChatResponse response = invokeChatCall(chatModel, prompt);
                    String content = extractContent(response);
                    // Memory is persisted before emission so the round survives even if the client dropped.
                    aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), content);
                    if (StringUtils.hasText(content)) {
                        aiSseEventPublisher.markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt, activeThinkingTraceEmitter);
                        aiSseEventPublisher.emitStrict(emitter, "delta", aiSseEventPublisher.buildMessagePayload("requestId", requestId, "content", content));
                    }
                    activeThinkingTraceEmitter.completeCurrentPhase();
                    aiSseEventPublisher.emitDone(emitter, requestId, response.getMetadata(), config.getAiParam().getModel(),
                            session.getId(), startedAt, firstTokenAtRef.get());
                    aiSseEventPublisher.emitSafely(emitter, "status",
                            aiSseEventPublisher.buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get()));
                    // Token usage is optional in the provider metadata; all three counts are null-guarded.
                    aiCallLogService.completeCallLog(
                            callLogId,
                            "COMPLETED",
                            System.currentTimeMillis() - startedAt,
                            aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()),
                            response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getPromptTokens(),
                            response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getCompletionTokens(),
                            response.getMetadata() == null || response.getMetadata().getUsage() == null ? null : response.getMetadata().getUsage().getTotalTokens(),
                            toolSuccessCount.get(),
                            toolFailureCount.get()
                    );
                    aiStreamStateStore.markStreamState(requestId, tenantId, userId, session.getId(), config.getPromptCode(), "COMPLETED", null);
                    log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}",
                            requestId, session.getId(), System.currentTimeMillis() - startedAt,
                            aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()));
                    emitter.complete();
                    return;
                }

                // --- Streaming branch: forward each chunk as its own "delta" event. ---
                Flux<ChatResponse> responseFlux = invokeChatStream(chatModel, prompt);
                AtomicReference<ChatResponseMetadata> lastMetadata = new AtomicReference<>();
                StringBuilder assistantContent = new StringBuilder();
                try {
                    // blockLast() deliberately keeps this request thread pinned until the
                    // stream terminates; doOnNext runs on the reactive pipeline's thread.
                    responseFlux.doOnNext(response -> {
                        lastMetadata.set(response.getMetadata());
                        String content = extractContent(response);
                        if (StringUtils.hasText(content)) {
                            aiSseEventPublisher.markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt, activeThinkingTraceEmitter);
                            assistantContent.append(content);
                            aiSseEventPublisher.emitStrict(emitter, "delta",
                                    aiSseEventPublisher.buildMessagePayload("requestId", requestId, "content", content));
                        }
                    })
                    .blockLast();
                } catch (Exception e) {
                    // NOTE(review): e is never null inside a catch block, so the fallback
                    // message here is unreachable; e.getMessage() itself may be null.
                    // Also note this re-wraps ANY failure from doOnNext (including emit
                    // failures) as a MODEL_STREAM error — confirm that is intended.
                    throw aiChatFailureHandler.buildAiException("AI_MODEL_STREAM_ERROR", AiErrorCategory.MODEL, "MODEL_STREAM",
                            e == null ? "AI 模型流式调用失败" : e.getMessage(), e);
                }
                // Same completion sequence as the non-streaming branch, using the
                // accumulated content and the metadata of the last received chunk.
                aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), assistantContent.toString());
                activeThinkingTraceEmitter.completeCurrentPhase();
                aiSseEventPublisher.emitDone(emitter, requestId, lastMetadata.get(), config.getAiParam().getModel(),
                        session.getId(), startedAt, firstTokenAtRef.get());
                aiSseEventPublisher.emitSafely(emitter, "status",
                        aiSseEventPublisher.buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get()));
                aiCallLogService.completeCallLog(
                        callLogId,
                        "COMPLETED",
                        System.currentTimeMillis() - startedAt,
                        aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()),
                        lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getPromptTokens(),
                        lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getCompletionTokens(),
                        lastMetadata.get() == null || lastMetadata.get().getUsage() == null ? null : lastMetadata.get().getUsage().getTotalTokens(),
                        toolSuccessCount.get(),
                        toolFailureCount.get()
                );
                aiStreamStateStore.markStreamState(requestId, tenantId, userId, session.getId(), config.getPromptCode(), "COMPLETED", null);
                log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}",
                        requestId, session.getId(), System.currentTimeMillis() - startedAt,
                        aiSseEventPublisher.resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()));
                emitter.complete();
            }
        } catch (AiChatException e) {
            // Domain failure already carries code/category/stage — hand it over as-is.
            aiChatFailureHandler.handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e,
                    callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter,
                    tenantId, userId, resolvedPromptCode);
        } catch (Exception e) {
            // Anything unexpected is wrapped as an INTERNAL failure.
            // NOTE(review): e == null is always false in a catch block (dead check).
            aiChatFailureHandler.handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(),
                    aiChatFailureHandler.buildAiException("AI_INTERNAL_ERROR", AiErrorCategory.INTERNAL, "INTERNAL",
                            e == null ? "AI 对话失败" : e.getMessage(), e),
                    callLogId, toolSuccessCount.get(), toolFailureCount.get(), thinkingTraceEmitter,
                    tenantId, userId, resolvedPromptCode);
        } finally {
            log.debug("AI chat stream finished, requestId={}", requestId);
        }
    }
|
|
private void ensureIdentity(Long userId, Long tenantId) {
|
if (userId == null) {
|
throw aiChatFailureHandler.buildAiException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "AUTH_VALIDATE", "当前登录用户不存在", null);
|
}
|
if (tenantId == null) {
|
throw aiChatFailureHandler.buildAiException("AI_AUTH_TENANT_MISSING", AiErrorCategory.AUTH, "AUTH_VALIDATE", "当前租户不存在", null);
|
}
|
}
|
|
private AiResolvedConfig resolveConfig(AiChatRequest request, Long tenantId) {
|
try {
|
return aiConfigResolverService.resolve(request.getPromptCode(), tenantId, request.getAiParamId());
|
} catch (Exception e) {
|
throw aiChatFailureHandler.buildAiException("AI_CONFIG_RESOLVE_ERROR", AiErrorCategory.CONFIG, "CONFIG_RESOLVE",
|
e == null ? "AI 配置解析失败" : e.getMessage(), e);
|
}
|
}
|
|
private AiChatSession resolveSession(AiChatRequest request, Long userId, Long tenantId, String promptCode) {
|
try {
|
return aiChatMemoryService.resolveSession(userId, tenantId, promptCode, request.getSessionId(),
|
aiPromptMessageBuilder.resolveTitleSeed(request.getMessages()));
|
} catch (Exception e) {
|
throw aiChatFailureHandler.buildAiException("AI_SESSION_RESOLVE_ERROR", AiErrorCategory.REQUEST, "SESSION_RESOLVE",
|
e == null ? "AI 会话解析失败" : e.getMessage(), e);
|
}
|
}
|
|
private AiChatMemoryDto loadMemory(Long userId, Long tenantId, String promptCode, Long sessionId) {
|
try {
|
return aiChatMemoryService.getMemory(userId, tenantId, promptCode, sessionId);
|
} catch (Exception e) {
|
throw aiChatFailureHandler.buildAiException("AI_MEMORY_LOAD_ERROR", AiErrorCategory.REQUEST, "MEMORY_LOAD",
|
e == null ? "AI 会话记忆加载失败" : e.getMessage(), e);
|
}
|
}
|
|
private McpMountRuntimeFactory.McpMountRuntime createRuntime(AiResolvedConfig config, Long userId) {
|
try {
|
return mcpMountRuntimeFactory.create(config.getMcpMounts(), userId);
|
} catch (Exception e) {
|
throw aiChatFailureHandler.buildAiException("AI_MCP_MOUNT_ERROR", AiErrorCategory.MCP, "MCP_MOUNT",
|
e == null ? "MCP 挂载失败" : e.getMessage(), e);
|
}
|
}
|
|
private ChatResponse invokeChatCall(OpenAiChatModel chatModel, Prompt prompt) {
|
try {
|
return chatModel.call(prompt);
|
} catch (Exception e) {
|
throw aiChatFailureHandler.buildAiException("AI_MODEL_CALL_ERROR", AiErrorCategory.MODEL, "MODEL_CALL",
|
e == null ? "AI 模型调用失败" : e.getMessage(), e);
|
}
|
}
|
|
private Flux<ChatResponse> invokeChatStream(OpenAiChatModel chatModel, Prompt prompt) {
|
try {
|
return chatModel.stream(prompt);
|
} catch (Exception e) {
|
throw aiChatFailureHandler.buildAiException("AI_MODEL_STREAM_ERROR", AiErrorCategory.MODEL, "MODEL_STREAM_INIT",
|
e == null ? "AI 模型流式调用失败" : e.getMessage(), e);
|
}
|
}
|
|
private String extractContent(ChatResponse response) {
|
if (response == null || response.getResult() == null || response.getResult().getOutput() == null) {
|
return null;
|
}
|
return response.getResult().getOutput().getText();
|
}
|
}
|