package com.vincent.rsf.server.ai.service.impl;
|
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.vincent.rsf.framework.common.Cools;
|
import com.vincent.rsf.framework.exception.CoolException;
|
import com.vincent.rsf.server.ai.config.AiDefaults;
|
import com.vincent.rsf.server.ai.dto.AiChatDoneDto;
|
import com.vincent.rsf.server.ai.dto.AiChatErrorDto;
|
import com.vincent.rsf.server.ai.dto.AiChatMemoryDto;
|
import com.vincent.rsf.server.ai.dto.AiChatMessageDto;
|
import com.vincent.rsf.server.ai.dto.AiChatRequest;
|
import com.vincent.rsf.server.ai.dto.AiChatRuntimeDto;
|
import com.vincent.rsf.server.ai.dto.AiChatStatusDto;
|
import com.vincent.rsf.server.ai.dto.AiChatSessionDto;
|
import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest;
|
import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest;
|
import com.vincent.rsf.server.ai.dto.AiResolvedConfig;
|
import com.vincent.rsf.server.ai.entity.AiParam;
|
import com.vincent.rsf.server.ai.entity.AiPrompt;
|
import com.vincent.rsf.server.ai.entity.AiChatSession;
|
import com.vincent.rsf.server.ai.enums.AiErrorCategory;
|
import com.vincent.rsf.server.ai.exception.AiChatException;
|
import com.vincent.rsf.server.ai.service.AiChatService;
|
import com.vincent.rsf.server.ai.service.AiChatMemoryService;
|
import com.vincent.rsf.server.ai.service.AiConfigResolverService;
|
import com.vincent.rsf.server.ai.service.McpMountRuntimeFactory;
|
import io.micrometer.observation.ObservationRegistry;
|
import lombok.RequiredArgsConstructor;
|
import lombok.extern.slf4j.Slf4j;
|
import org.springframework.ai.chat.messages.AssistantMessage;
|
import org.springframework.ai.chat.messages.Message;
|
import org.springframework.ai.chat.messages.SystemMessage;
|
import org.springframework.ai.chat.messages.UserMessage;
|
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
import org.springframework.ai.chat.metadata.Usage;
|
import org.springframework.ai.chat.model.ChatResponse;
|
import org.springframework.ai.chat.prompt.Prompt;
|
import org.springframework.ai.model.tool.DefaultToolCallingManager;
|
import org.springframework.ai.model.tool.ToolCallingManager;
|
import org.springframework.ai.openai.OpenAiChatModel;
|
import org.springframework.ai.openai.OpenAiChatOptions;
|
import org.springframework.ai.openai.api.OpenAiApi;
|
import org.springframework.ai.tool.ToolCallback;
|
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
|
import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver;
|
import org.springframework.ai.util.json.schema.SchemaType;
|
import org.springframework.context.support.GenericApplicationContext;
|
import org.springframework.http.MediaType;
|
import org.springframework.http.client.SimpleClientHttpRequestFactory;
|
import org.springframework.beans.factory.annotation.Qualifier;
|
import org.springframework.stereotype.Service;
|
import org.springframework.util.StringUtils;
|
import org.springframework.web.client.RestClient;
|
import org.springframework.web.reactive.function.client.WebClient;
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
import reactor.core.publisher.Flux;
|
|
import java.io.IOException;
|
import java.time.Instant;
|
import java.util.ArrayList;
|
import java.util.Arrays;
|
import java.util.LinkedHashMap;
|
import java.util.List;
|
import java.util.Map;
|
import java.util.Objects;
|
import java.util.concurrent.CompletableFuture;
|
import java.util.concurrent.Executor;
|
import java.util.concurrent.atomic.AtomicReference;
|
|
@Slf4j
@Service
@RequiredArgsConstructor
public class AiChatServiceImpl implements AiChatService {

    // All collaborators are final and injected through the Lombok-generated
    // constructor (@RequiredArgsConstructor).

    private final AiConfigResolverService aiConfigResolverService;

    private final AiChatMemoryService aiChatMemoryService;

    private final McpMountRuntimeFactory mcpMountRuntimeFactory;

    // Used by createChatModel to resolve tool-callback beans from the context.
    private final GenericApplicationContext applicationContext;

    private final ObservationRegistry observationRegistry;

    // Serializes SSE payloads to JSON in emitStrict.
    private final ObjectMapper objectMapper;

    // NOTE(review): Lombok copies @Qualifier onto the generated constructor
    // parameter only when lombok.config declares
    // `lombok.copyableAnnotations += org.springframework.beans.factory.annotation.Qualifier`;
    // verify that entry exists, otherwise any Executor bean may be injected here.
    @Qualifier("aiChatTaskExecutor")
    private final Executor aiChatTaskExecutor;
|
|
@Override
|
public AiChatRuntimeDto getRuntime(String promptCode, Long sessionId, Long userId, Long tenantId) {
|
AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId);
|
AiChatMemoryDto memory = aiChatMemoryService.getMemory(userId, tenantId, config.getPromptCode(), sessionId);
|
return AiChatRuntimeDto.builder()
|
.requestId(null)
|
.sessionId(memory.getSessionId())
|
.promptCode(config.getPromptCode())
|
.promptName(config.getPrompt().getName())
|
.model(config.getAiParam().getModel())
|
.configuredMcpCount(config.getMcpMounts().size())
|
.mountedMcpCount(config.getMcpMounts().size())
|
.mountedMcpNames(config.getMcpMounts().stream().map(item -> item.getName()).toList())
|
.mountErrors(List.of())
|
.memorySummary(memory.getMemorySummary())
|
.memoryFacts(memory.getMemoryFacts())
|
.recentMessageCount(memory.getRecentMessageCount())
|
.persistedMessages(memory.getPersistedMessages())
|
.build();
|
}
|
|
@Override
|
public List<AiChatSessionDto> listSessions(String promptCode, String keyword, Long userId, Long tenantId) {
|
AiResolvedConfig config = aiConfigResolverService.resolve(promptCode, tenantId);
|
return aiChatMemoryService.listSessions(userId, tenantId, config.getPromptCode(), keyword);
|
}
|
|
    /** Delegates removal of the given session to the memory service. */
    @Override
    public void removeSession(Long sessionId, Long userId, Long tenantId) {
        aiChatMemoryService.removeSession(userId, tenantId, sessionId);
    }
|
|
    /** Delegates session renaming to the memory service and returns the updated DTO. */
    @Override
    public AiChatSessionDto renameSession(Long sessionId, AiChatSessionRenameRequest request, Long userId, Long tenantId) {
        return aiChatMemoryService.renameSession(userId, tenantId, sessionId, request);
    }
|
|
    /** Delegates session pin/unpin to the memory service and returns the updated DTO. */
    @Override
    public AiChatSessionDto pinSession(Long sessionId, AiChatSessionPinRequest request, Long userId, Long tenantId) {
        return aiChatMemoryService.pinSession(userId, tenantId, sessionId, request);
    }
|
|
    /** Delegates clearing the session's memory to the memory service. */
    @Override
    public void clearSessionMemory(Long sessionId, Long userId, Long tenantId) {
        aiChatMemoryService.clearSessionMemory(userId, tenantId, sessionId);
    }
|
|
    /** Delegates trimming the session down to its latest round to the memory service. */
    @Override
    public void retainLatestRound(Long sessionId, Long userId, Long tenantId) {
        aiChatMemoryService.retainLatestRound(userId, tenantId, sessionId);
    }
|
|
    /**
     * Entry point for SSE chat: creates the emitter (with the project-wide SSE
     * timeout) and hands the actual work to {@link #doStream} on the dedicated
     * AI executor, so the HTTP thread can return the emitter immediately.
     */
    @Override
    public SseEmitter stream(AiChatRequest request, Long userId, Long tenantId) {
        SseEmitter emitter = new SseEmitter(AiDefaults.SSE_TIMEOUT_MS);
        CompletableFuture.runAsync(() -> doStream(request, userId, tenantId, emitter), aiChatTaskExecutor);
        return emitter;
    }
|
|
    /**
     * Runs one full chat round over the SSE emitter: identity check, config and
     * session resolution, memory load, MCP mount (try-with-resources), model
     * invocation (blocking or streaming), round persistence, and terminal
     * events. All failures funnel into {@link #handleStreamFailure}.
     */
    private void doStream(AiChatRequest request, Long userId, Long tenantId, SseEmitter emitter) {
        String requestId = request.getRequestId();
        long startedAt = System.currentTimeMillis();
        // Set exactly once when the first non-empty chunk arrives; drives latency metrics.
        AtomicReference<Long> firstTokenAtRef = new AtomicReference<>();
        // Tracked outside the try so the failure handler can report what is known so far.
        Long sessionId = request.getSessionId();
        String model = null;
        try {
            ensureIdentity(userId, tenantId);
            AiResolvedConfig config = resolveConfig(request, tenantId);
            final String resolvedModel = config.getAiParam().getModel();
            model = resolvedModel;
            AiChatSession session = resolveSession(request, userId, tenantId, config.getPromptCode());
            sessionId = session.getId();
            AiChatMemoryDto memory = loadMemory(userId, tenantId, config.getPromptCode(), session.getId());
            // Short-term memory messages first, then the request's new messages.
            List<AiChatMessageDto> mergedMessages = mergeMessages(memory.getShortMemoryMessages(), request.getMessages());
            // The MCP runtime is AutoCloseable; try-with-resources guarantees unmount.
            try (McpMountRuntimeFactory.McpMountRuntime runtime = createRuntime(config, userId)) {
                // "start" is mandatory for the client, so a send failure aborts the round.
                emitStrict(emitter, "start", AiChatRuntimeDto.builder()
                        .requestId(requestId)
                        .sessionId(session.getId())
                        .promptCode(config.getPromptCode())
                        .promptName(config.getPrompt().getName())
                        .model(config.getAiParam().getModel())
                        .configuredMcpCount(config.getMcpMounts().size())
                        .mountedMcpCount(runtime.getMountedCount())
                        .mountedMcpNames(runtime.getMountedNames())
                        .mountErrors(runtime.getErrors())
                        .memorySummary(memory.getMemorySummary())
                        .memoryFacts(memory.getMemoryFacts())
                        .recentMessageCount(memory.getRecentMessageCount())
                        .persistedMessages(memory.getPersistedMessages())
                        .build());
                // Status events are best-effort only.
                emitSafely(emitter, "status", AiChatStatusDto.builder()
                        .requestId(requestId)
                        .sessionId(session.getId())
                        .status("STARTED")
                        .model(resolvedModel)
                        .timestamp(Instant.now().toEpochMilli())
                        .elapsedMs(0L)
                        .build());
                log.info("AI chat started, requestId={}, userId={}, tenantId={}, sessionId={}, model={}",
                        requestId, userId, tenantId, session.getId(), resolvedModel);

                Prompt prompt = new Prompt(
                        buildPromptMessages(memory, mergedMessages, config.getPrompt(), request.getMetadata()),
                        buildChatOptions(config.getAiParam(), runtime.getToolCallbacks(), userId, request.getMetadata())
                );
                OpenAiChatModel chatModel = createChatModel(config.getAiParam());
                // Non-streaming branch: single blocking call, single delta, then done.
                if (Boolean.FALSE.equals(config.getAiParam().getStreamingEnabled())) {
                    ChatResponse response = invokeChatCall(chatModel, prompt);
                    String content = extractContent(response);
                    // Round is persisted before any delta is emitted.
                    aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), content);
                    if (StringUtils.hasText(content)) {
                        markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt);
                        emitStrict(emitter, "delta", buildMessagePayload("requestId", requestId, "content", content));
                    }
                    emitDone(emitter, requestId, response.getMetadata(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get());
                    emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get()));
                    log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}",
                            requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()));
                    emitter.complete();
                    return;
                }

                // Streaming branch: forward each non-empty chunk as a "delta" event
                // while accumulating the full assistant reply for persistence.
                Flux<ChatResponse> responseFlux = invokeChatStream(chatModel, prompt);
                AtomicReference<ChatResponseMetadata> lastMetadata = new AtomicReference<>();
                StringBuilder assistantContent = new StringBuilder();
                try {
                    responseFlux.doOnNext(response -> {
                        lastMetadata.set(response.getMetadata());
                        String content = extractContent(response);
                        if (StringUtils.hasText(content)) {
                            markFirstToken(firstTokenAtRef, emitter, requestId, session.getId(), resolvedModel, startedAt);
                            assistantContent.append(content);
                            emitStrict(emitter, "delta", buildMessagePayload("requestId", requestId, "content", content));
                        }
                    })
                    // Intentionally blocks this worker thread until the stream ends.
                    .blockLast();
                } catch (Exception e) {
                    // NOTE(review): `e` can never be null inside a catch block, so the
                    // null branch of this ternary is dead; also an emitStrict failure
                    // thrown inside doOnNext gets re-labelled MODEL_STREAM here.
                    throw buildAiException("AI_MODEL_STREAM_ERROR", AiErrorCategory.MODEL, "MODEL_STREAM",
                            e == null ? "AI 模型流式调用失败" : e.getMessage(), e);
                }
                aiChatMemoryService.saveRound(session, userId, tenantId, request.getMessages(), assistantContent.toString());
                emitDone(emitter, requestId, lastMetadata.get(), config.getAiParam().getModel(), session.getId(), startedAt, firstTokenAtRef.get());
                emitSafely(emitter, "status", buildTerminalStatus(requestId, session.getId(), "COMPLETED", resolvedModel, startedAt, firstTokenAtRef.get()));
                log.info("AI chat completed, requestId={}, sessionId={}, elapsedMs={}, firstTokenLatencyMs={}",
                        requestId, session.getId(), System.currentTimeMillis() - startedAt, resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()));
                emitter.complete();
            }
        } catch (AiChatException e) {
            handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(), e);
        } catch (Exception e) {
            // NOTE(review): dead `e == null` check (see above); kept as-is in this doc pass.
            handleStreamFailure(emitter, requestId, sessionId, model, startedAt, firstTokenAtRef.get(),
                    buildAiException("AI_INTERNAL_ERROR", AiErrorCategory.INTERNAL, "INTERNAL",
                            e == null ? "AI 对话失败" : e.getMessage(), e));
        } finally {
            log.debug("AI chat stream finished, requestId={}", requestId);
        }
    }
|
|
private void ensureIdentity(Long userId, Long tenantId) {
|
if (userId == null) {
|
throw buildAiException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "AUTH_VALIDATE", "当前登录用户不存在", null);
|
}
|
if (tenantId == null) {
|
throw buildAiException("AI_AUTH_TENANT_MISSING", AiErrorCategory.AUTH, "AUTH_VALIDATE", "当前租户不存在", null);
|
}
|
}
|
|
private AiResolvedConfig resolveConfig(AiChatRequest request, Long tenantId) {
|
try {
|
return aiConfigResolverService.resolve(request.getPromptCode(), tenantId);
|
} catch (Exception e) {
|
throw buildAiException("AI_CONFIG_RESOLVE_ERROR", AiErrorCategory.CONFIG, "CONFIG_RESOLVE",
|
e == null ? "AI 配置解析失败" : e.getMessage(), e);
|
}
|
}
|
|
private AiChatSession resolveSession(AiChatRequest request, Long userId, Long tenantId, String promptCode) {
|
try {
|
return aiChatMemoryService.resolveSession(userId, tenantId, promptCode, request.getSessionId(), resolveTitleSeed(request.getMessages()));
|
} catch (Exception e) {
|
throw buildAiException("AI_SESSION_RESOLVE_ERROR", AiErrorCategory.REQUEST, "SESSION_RESOLVE",
|
e == null ? "AI 会话解析失败" : e.getMessage(), e);
|
}
|
}
|
|
private AiChatMemoryDto loadMemory(Long userId, Long tenantId, String promptCode, Long sessionId) {
|
try {
|
return aiChatMemoryService.getMemory(userId, tenantId, promptCode, sessionId);
|
} catch (Exception e) {
|
throw buildAiException("AI_MEMORY_LOAD_ERROR", AiErrorCategory.REQUEST, "MEMORY_LOAD",
|
e == null ? "AI 会话记忆加载失败" : e.getMessage(), e);
|
}
|
}
|
|
private McpMountRuntimeFactory.McpMountRuntime createRuntime(AiResolvedConfig config, Long userId) {
|
try {
|
return mcpMountRuntimeFactory.create(config.getMcpMounts(), userId);
|
} catch (Exception e) {
|
throw buildAiException("AI_MCP_MOUNT_ERROR", AiErrorCategory.MCP, "MCP_MOUNT",
|
e == null ? "MCP 挂载失败" : e.getMessage(), e);
|
}
|
}
|
|
private ChatResponse invokeChatCall(OpenAiChatModel chatModel, Prompt prompt) {
|
try {
|
return chatModel.call(prompt);
|
} catch (Exception e) {
|
throw buildAiException("AI_MODEL_CALL_ERROR", AiErrorCategory.MODEL, "MODEL_CALL",
|
e == null ? "AI 模型调用失败" : e.getMessage(), e);
|
}
|
}
|
|
private Flux<ChatResponse> invokeChatStream(OpenAiChatModel chatModel, Prompt prompt) {
|
try {
|
return chatModel.stream(prompt);
|
} catch (Exception e) {
|
throw buildAiException("AI_MODEL_STREAM_ERROR", AiErrorCategory.MODEL, "MODEL_STREAM_INIT",
|
e == null ? "AI 模型流式调用失败" : e.getMessage(), e);
|
}
|
}
|
|
private void markFirstToken(AtomicReference<Long> firstTokenAtRef, SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt) {
|
if (!firstTokenAtRef.compareAndSet(null, System.currentTimeMillis())) {
|
return;
|
}
|
emitSafely(emitter, "status", AiChatStatusDto.builder()
|
.requestId(requestId)
|
.sessionId(sessionId)
|
.status("FIRST_TOKEN")
|
.model(model)
|
.timestamp(Instant.now().toEpochMilli())
|
.elapsedMs(System.currentTimeMillis() - startedAt)
|
.firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAtRef.get()))
|
.build());
|
}
|
|
private AiChatStatusDto buildTerminalStatus(String requestId, Long sessionId, String status, String model, long startedAt, Long firstTokenAt) {
|
return AiChatStatusDto.builder()
|
.requestId(requestId)
|
.sessionId(sessionId)
|
.status(status)
|
.model(model)
|
.timestamp(Instant.now().toEpochMilli())
|
.elapsedMs(System.currentTimeMillis() - startedAt)
|
.firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAt))
|
.build();
|
}
|
|
private Long resolveFirstTokenLatency(long startedAt, Long firstTokenAt) {
|
return firstTokenAt == null ? null : Math.max(0L, firstTokenAt - startedAt);
|
}
|
|
    /**
     * Terminal failure path: client aborts are logged at WARN (no stack trace,
     * no "error" event since nobody is listening); real failures are logged at
     * ERROR and emit both a FAILED status and a structured "error" payload.
     * In both cases the emitter is completed with the exception.
     */
    private void handleStreamFailure(SseEmitter emitter, String requestId, Long sessionId, String model, long startedAt, Long firstTokenAt, AiChatException exception) {
        if (isClientAbortException(exception)) {
            log.warn("AI chat aborted by client, requestId={}, sessionId={}, stage={}, message={}",
                    requestId, sessionId, exception.getStage(), exception.getMessage());
            emitSafely(emitter, "status", buildTerminalStatus(requestId, sessionId, "ABORTED", model, startedAt, firstTokenAt));
            // NOTE(review): completeWithError on an already-disconnected client is
            // presumably a no-op — confirm it does not fire error callbacks twice.
            emitter.completeWithError(exception);
            return;
        }
        log.error("AI chat failed, requestId={}, sessionId={}, category={}, stage={}, message={}",
                requestId, sessionId, exception.getCategory(), exception.getStage(), exception.getMessage(), exception);
        emitSafely(emitter, "status", buildTerminalStatus(requestId, sessionId, "FAILED", model, startedAt, firstTokenAt));
        emitSafely(emitter, "error", AiChatErrorDto.builder()
                .requestId(requestId)
                .sessionId(sessionId)
                .code(exception.getCode())
                .category(exception.getCategory().name())
                .stage(exception.getStage())
                .message(exception.getMessage())
                .timestamp(Instant.now().toEpochMilli())
                .build());
        emitter.completeWithError(exception);
    }
|
|
    /**
     * Builds a per-request OpenAiChatModel from the resolved parameters,
     * with Spring-bean tool resolution and retries disabled (maxAttempts = 1).
     */
    private OpenAiChatModel createChatModel(AiParam aiParam) {
        OpenAiApi openAiApi = buildOpenAiApi(aiParam);
        // Tool callbacks are resolved against Spring beans using OpenAPI-style schemas.
        ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder()
                .observationRegistry(observationRegistry)
                .toolCallbackResolver(new SpringBeanToolCallbackResolver(applicationContext, SchemaType.OPEN_API_SCHEMA))
                // NOTE(review): the `false` flag is passed straight to the processor;
                // presumably it controls rethrow-vs-convert of tool exceptions —
                // confirm against the Spring AI javadoc.
                .toolExecutionExceptionProcessor(new DefaultToolExecutionExceptionProcessor(false))
                .build();
        return new OpenAiChatModel(
                openAiApi,
                // Default options; per-request options are also set on the Prompt
                // via buildChatOptions.
                OpenAiChatOptions.builder()
                        .model(aiParam.getModel())
                        .temperature(aiParam.getTemperature())
                        .topP(aiParam.getTopP())
                        .maxTokens(aiParam.getMaxTokens())
                        .streamUsage(true)
                        .build(),
                toolCallingManager,
                // maxAttempts(1) = no retries for model invocations.
                org.springframework.retry.support.RetryTemplate.builder().maxAttempts(1).build(),
                observationRegistry
        );
    }
|
|
private OpenAiApi buildOpenAiApi(AiParam aiParam) {
|
int timeoutMs = aiParam.getTimeoutMs() == null ? AiDefaults.DEFAULT_TIMEOUT_MS : aiParam.getTimeoutMs();
|
SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
|
requestFactory.setConnectTimeout(timeoutMs);
|
requestFactory.setReadTimeout(timeoutMs);
|
|
return OpenAiApi.builder()
|
.baseUrl(aiParam.getBaseUrl())
|
.apiKey(aiParam.getApiKey())
|
.restClientBuilder(RestClient.builder().requestFactory(requestFactory))
|
.webClientBuilder(WebClient.builder())
|
.build();
|
}
|
|
private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Map<String, Object> metadata) {
|
if (userId == null) {
|
throw buildAiException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "OPTIONS_BUILD", "当前登录用户不存在", null);
|
}
|
OpenAiChatOptions.Builder builder = OpenAiChatOptions.builder()
|
.model(aiParam.getModel())
|
.temperature(aiParam.getTemperature())
|
.topP(aiParam.getTopP())
|
.maxTokens(aiParam.getMaxTokens())
|
.streamUsage(true)
|
.user(String.valueOf(userId));
|
if (!Cools.isEmpty(toolCallbacks)) {
|
builder.toolCallbacks(Arrays.asList(toolCallbacks));
|
}
|
Map<String, String> metadataMap = new LinkedHashMap<>();
|
if (metadata != null) {
|
metadata.forEach((key, value) -> metadataMap.put(key, value == null ? "" : String.valueOf(value)));
|
}
|
if (!metadataMap.isEmpty()) {
|
builder.metadata(metadataMap);
|
}
|
return builder.build();
|
}
|
|
private List<Message> buildPromptMessages(AiChatMemoryDto memory, List<AiChatMessageDto> sourceMessages, AiPrompt aiPrompt, Map<String, Object> metadata) {
|
if (Cools.isEmpty(sourceMessages)) {
|
throw new CoolException("对话消息不能为空");
|
}
|
List<Message> messages = new ArrayList<>();
|
if (StringUtils.hasText(aiPrompt.getSystemPrompt())) {
|
messages.add(new SystemMessage(aiPrompt.getSystemPrompt()));
|
}
|
if (memory != null && StringUtils.hasText(memory.getMemorySummary())) {
|
messages.add(new SystemMessage("历史摘要:\n" + memory.getMemorySummary()));
|
}
|
if (memory != null && StringUtils.hasText(memory.getMemoryFacts())) {
|
messages.add(new SystemMessage("关键事实:\n" + memory.getMemoryFacts()));
|
}
|
int lastUserIndex = -1;
|
for (int i = 0; i < sourceMessages.size(); i++) {
|
AiChatMessageDto item = sourceMessages.get(i);
|
if (item != null && "user".equalsIgnoreCase(item.getRole())) {
|
lastUserIndex = i;
|
}
|
}
|
for (int i = 0; i < sourceMessages.size(); i++) {
|
AiChatMessageDto item = sourceMessages.get(i);
|
if (item == null || !StringUtils.hasText(item.getContent())) {
|
continue;
|
}
|
String role = item.getRole() == null ? "user" : item.getRole().toLowerCase();
|
if ("system".equals(role)) {
|
continue;
|
}
|
String content = item.getContent();
|
if ("user".equals(role) && i == lastUserIndex) {
|
content = renderUserPrompt(aiPrompt.getUserPromptTemplate(), content, metadata);
|
}
|
if ("assistant".equals(role)) {
|
messages.add(new AssistantMessage(content));
|
} else {
|
messages.add(new UserMessage(content));
|
}
|
}
|
if (messages.stream().noneMatch(item -> item instanceof UserMessage)) {
|
throw new CoolException("至少需要一条用户消息");
|
}
|
return messages;
|
}
|
|
private List<AiChatMessageDto> mergeMessages(List<AiChatMessageDto> persistedMessages, List<AiChatMessageDto> memoryMessages) {
|
List<AiChatMessageDto> merged = new ArrayList<>();
|
if (!Cools.isEmpty(persistedMessages)) {
|
merged.addAll(persistedMessages);
|
}
|
if (!Cools.isEmpty(memoryMessages)) {
|
merged.addAll(memoryMessages);
|
}
|
if (merged.isEmpty()) {
|
throw new CoolException("对话消息不能为空");
|
}
|
return merged;
|
}
|
|
private String resolveTitleSeed(List<AiChatMessageDto> messages) {
|
if (Cools.isEmpty(messages)) {
|
throw new CoolException("对话消息不能为空");
|
}
|
for (int i = messages.size() - 1; i >= 0; i--) {
|
AiChatMessageDto item = messages.get(i);
|
if (item != null && "user".equalsIgnoreCase(item.getRole()) && StringUtils.hasText(item.getContent())) {
|
return item.getContent();
|
}
|
}
|
throw new CoolException("至少需要一条用户消息");
|
}
|
|
private String renderUserPrompt(String userPromptTemplate, String content, Map<String, Object> metadata) {
|
if (!StringUtils.hasText(userPromptTemplate)) {
|
return content;
|
}
|
String rendered = userPromptTemplate
|
.replace("{{input}}", content)
|
.replace("{input}", content);
|
if (metadata != null) {
|
for (Map.Entry<String, Object> entry : metadata.entrySet()) {
|
String value = entry.getValue() == null ? "" : String.valueOf(entry.getValue());
|
rendered = rendered.replace("{{" + entry.getKey() + "}}", value);
|
rendered = rendered.replace("{" + entry.getKey() + "}", value);
|
}
|
}
|
if (Objects.equals(rendered, userPromptTemplate)) {
|
return userPromptTemplate + "\n\n" + content;
|
}
|
return rendered;
|
}
|
|
private String extractContent(ChatResponse response) {
|
if (response == null || response.getResult() == null || response.getResult().getOutput() == null) {
|
return null;
|
}
|
return response.getResult().getOutput().getText();
|
}
|
|
private void emitDone(SseEmitter emitter, String requestId, ChatResponseMetadata metadata, String fallbackModel, Long sessionId, long startedAt, Long firstTokenAt) {
|
Usage usage = metadata == null ? null : metadata.getUsage();
|
emitStrict(emitter, "done", AiChatDoneDto.builder()
|
.requestId(requestId)
|
.sessionId(sessionId)
|
.model(metadata != null && StringUtils.hasText(metadata.getModel()) ? metadata.getModel() : fallbackModel)
|
.elapsedMs(System.currentTimeMillis() - startedAt)
|
.firstTokenLatencyMs(resolveFirstTokenLatency(startedAt, firstTokenAt))
|
.promptTokens(usage == null ? null : usage.getPromptTokens())
|
.completionTokens(usage == null ? null : usage.getCompletionTokens())
|
.totalTokens(usage == null ? null : usage.getTotalTokens())
|
.build());
|
}
|
|
private Map<String, String> buildMessagePayload(String... keyValues) {
|
Map<String, String> payload = new LinkedHashMap<>();
|
if (keyValues == null || keyValues.length == 0) {
|
return payload;
|
}
|
if (keyValues.length % 2 != 0) {
|
throw new CoolException("消息载荷参数必须成对出现");
|
}
|
for (int i = 0; i < keyValues.length; i += 2) {
|
payload.put(keyValues[i], keyValues[i + 1] == null ? "" : keyValues[i + 1]);
|
}
|
return payload;
|
}
|
|
    /**
     * Serializes the payload to JSON and sends it as a named SSE event. A failed
     * serialization or send is wrapped in an AiChatException (STREAM / SSE_EMIT)
     * so the caller's failure path runs; use {@link #emitSafely} for optional events.
     */
    private void emitStrict(SseEmitter emitter, String eventName, Object payload) {
        try {
            String data = objectMapper.writeValueAsString(payload);
            emitter.send(SseEmitter.event()
                    .name(eventName)
                    .data(data, MediaType.APPLICATION_JSON));
        } catch (IOException e) {
            // NOTE(review): emitter.send can also throw IllegalStateException once
            // the emitter is completed/timed out; that is not caught here — confirm
            // it is acceptable for it to surface as AI_INTERNAL_ERROR upstream.
            throw buildAiException("AI_SSE_EMIT_ERROR", AiErrorCategory.STREAM, "SSE_EMIT", "SSE 输出失败: " + e.getMessage(), e);
        }
    }
|
|
    /**
     * Best-effort variant of {@link #emitStrict}: any emit failure is logged at
     * WARN and swallowed so optional events (status updates) never break the stream.
     */
    private void emitSafely(SseEmitter emitter, String eventName, Object payload) {
        try {
            emitStrict(emitter, eventName, payload);
        } catch (Exception e) {
            log.warn("AI SSE event emit skipped, eventName={}, message={}", eventName, e.getMessage());
        }
    }
|
|
    /** Single construction point for typed chat exceptions (code + category + stage + cause). */
    private AiChatException buildAiException(String code, AiErrorCategory category, String stage, String message, Throwable cause) {
        return new AiChatException(code, category, stage, message, cause);
    }
|
|
private boolean isClientAbortException(Throwable throwable) {
|
Throwable current = throwable;
|
while (current != null) {
|
String message = current.getMessage();
|
if (message != null) {
|
String normalized = message.toLowerCase();
|
if (normalized.contains("broken pipe")
|
|| normalized.contains("connection reset")
|
|| normalized.contains("forcibly closed")
|
|| normalized.contains("abort")) {
|
return true;
|
}
|
}
|
current = current.getCause();
|
}
|
return false;
|
}
|
}
|