package com.vincent.rsf.server.ai.controller;

import com.fasterxml.jackson.databind.JsonNode;
import com.vincent.rsf.framework.common.R;
import com.vincent.rsf.server.ai.config.AiProperties;
import com.vincent.rsf.server.ai.dto.AiChatStreamRequest;
import com.vincent.rsf.server.ai.dto.AiSessionCreateRequest;
import com.vincent.rsf.server.ai.dto.AiSessionRenameRequest;
import com.vincent.rsf.server.ai.dto.GatewayChatMessage;
import com.vincent.rsf.server.ai.dto.GatewayChatRequest;
import com.vincent.rsf.server.ai.model.AiChatMessage;
import com.vincent.rsf.server.ai.model.AiChatSession;
import com.vincent.rsf.server.ai.model.AiPromptContext;
import com.vincent.rsf.server.ai.service.AiGatewayClient;
import com.vincent.rsf.server.ai.service.AiPromptContextService;
import com.vincent.rsf.server.ai.service.AiRuntimeConfigService;
import com.vincent.rsf.server.ai.service.AiSessionService;
import com.vincent.rsf.server.system.controller.BaseController;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import javax.annotation.Resource;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * REST endpoints under {@code /ai}: lists enabled models, manages chat sessions
 * (list / create / rename / remove / message history) and streams assistant replies
 * to the client as Server-Sent Events.
 *
 * <p>SSE event names emitted by {@link #chatStream}: {@code session}, {@code delta},
 * {@code error} and {@code done}.
 */
@RestController
@RequestMapping("/ai")
public class AiController extends BaseController {

    @Resource
    private AiSessionService aiSessionService;
    @Resource
    private AiProperties aiProperties;
    @Resource
    private AiGatewayClient aiGatewayClient;
    @Resource
    private AiRuntimeConfigService aiRuntimeConfigService;
    @Resource
    private AiPromptContextService aiPromptContextService;

    /**
     * Lists the enabled models as maps with keys {@code code}, {@code name},
     * {@code provider} and {@code enabled}.
     *
     * @return a success response carrying the list of model descriptors
     */
    @GetMapping("/model/list")
    public R modelList() {
        List<Map<String, Object>> models = new ArrayList<>();
        for (AiRuntimeConfigService.ModelRuntimeConfig model : aiRuntimeConfigService.listEnabledModels()) {
            Map<String, Object> item = new LinkedHashMap<>();
            item.put("code", model.getCode());
            item.put("name", model.getName());
            item.put("provider", model.getProvider());
            item.put("enabled", model.getEnabled());
            models.add(item);
        }
        return R.ok().add(models);
    }

    /**
     * Lists the chat sessions of the current tenant and logged-in user.
     *
     * @return a success response carrying the sessions returned by the session service
     */
    @GetMapping("/session/list")
    public R sessionList() {
        return R.ok().add(aiSessionService.listSessions(getTenantId(), getLoginUserId()));
    }

    /**
     * Creates a chat session for the current tenant and user.
     *
     * @param request optional body; when absent, title and model code are passed as {@code null}
     * @return a success response carrying the created session
     */
    @PostMapping("/session/create")
    public R createSession(@RequestBody(required = false) AiSessionCreateRequest request) {
        String title = request == null ? null : request.getTitle();
        String modelCode = request == null ? null : request.getModelCode();
        AiChatSession session = aiSessionService.createSession(getTenantId(), getLoginUserId(), title, modelCode);
        return R.ok().add(session);
    }

    /**
     * Renames a session of the current tenant and user.
     *
     * @param sessionId id of the session to rename
     * @param request   body carrying the new title
     * @return a success response carrying the session returned by the session service
     */
    @PostMapping("/session/{sessionId}/rename")
    public R renameSession(@PathVariable("sessionId") String sessionId, @RequestBody AiSessionRenameRequest request) {
        AiChatSession session = aiSessionService.renameSession(
                getTenantId(), getLoginUserId(), sessionId, request.getTitle());
        return R.ok().add(session);
    }

    /**
     * Removes a session of the current tenant and user.
     *
     * @param sessionId id of the session to remove
     * @return an empty success response
     */
    @PostMapping("/session/remove/{sessionId}")
    public R removeSession(@PathVariable("sessionId") String sessionId) {
        aiSessionService.removeSession(getTenantId(), getLoginUserId(), sessionId);
        return R.ok();
    }

    /**
     * Lists the messages of a session of the current tenant and user.
     *
     * @param sessionId id of the session whose messages are requested
     * @return a success response carrying the messages returned by the session service
     */
    @GetMapping("/session/{sessionId}/messages")
    public R messageList(@PathVariable("sessionId") String sessionId) {
        return R.ok().add(aiSessionService.listMessages(getTenantId(), getLoginUserId(), sessionId));
    }

    /**
     * Requests that the running stream of a session be stopped. Does nothing when the
     * body or its session id is missing.
     *
     * @param request body carrying the session id
     * @return an empty success response
     */
    @PostMapping("/chat/stop")
    public R stop(@RequestBody AiChatStreamRequest request) {
        // NOTE(review): only the session id is passed here - no tenant/user - so ownership
        // of the session is not checked in this controller; confirm the service enforces it.
        if (request != null && request.getSessionId() != null) {
            aiSessionService.requestStop(request.getSessionId());
        }
        return R.ok();
    }

    /**
     * Streams an assistant reply for the given user message as Server-Sent Events.
     *
     * <p>Validation failures (not logged in, empty message) are reported as a single
     * {@code error} event followed by completion. Otherwise the user message is persisted,
     * the model and context are resolved on the request thread, and the gateway call runs
     * on a dedicated daemon thread that owns the emitter until it completes it.
     *
     * @param request body carrying the message and optionally a session id and model code
     * @return an emitter created without a timeout ({@code 0L})
     */
    @PostMapping(value = "/chat/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public SseEmitter chatStream(@RequestBody AiChatStreamRequest request) {
        SseEmitter emitter = new SseEmitter(0L);
        Long tenantId = getTenantId();
        Long userId = getLoginUserId();
        if (tenantId == null || userId == null) {
            completeWithError(emitter, "请先登录后再使用AI助手");
            return emitter;
        }
        if (request == null || request.getMessage() == null || request.getMessage().trim().isEmpty()) {
            completeWithError(emitter, "消息内容不能为空");
            return emitter;
        }

        AiChatSession session = aiSessionService.ensureSession(
                tenantId, userId, request.getSessionId(), request.getModelCode());
        // Reset any stop request left over from an earlier stream of the same session.
        aiSessionService.clearStopFlag(session.getId());
        aiSessionService.appendMessage(
                tenantId, userId, session.getId(), "user", request.getMessage(), session.getModelCode());

        AiRuntimeConfigService.ModelRuntimeConfig modelRuntimeConfig =
                aiRuntimeConfigService.resolveModel(session.getModelCode());
        List<AiChatMessage> contextMessages = aiSessionService.listContextMessages(
                tenantId, userId, session.getId(), modelRuntimeConfig.getMaxContextMessages());
        String question = request.getMessage();

        // NOTE(review): one unpooled thread per request; consider a bounded executor.
        Thread thread = new Thread(
                () -> streamReply(emitter, tenantId, userId, session, contextMessages, modelRuntimeConfig, question),
                "ai-chat-stream-" + session.getId());
        thread.setDaemon(true);
        thread.start();
        return emitter;
    }

    /**
     * Worker body of {@link #chatStream}: sends the {@code session} event, relays gateway
     * events, persists the accumulated assistant reply at most once, and always completes
     * the emitter and clears the stop flag.
     */
    private void streamReply(SseEmitter emitter,
                             Long tenantId,
                             Long userId,
                             AiChatSession session,
                             List<AiChatMessage> contextMessages,
                             AiRuntimeConfigService.ModelRuntimeConfig modelRuntimeConfig,
                             String question) {
        StringBuilder assistantReply = new StringBuilder();
        // Tracks persistence separately from "done was sent": if the stopped "done" event
        // fails to send after the reply was stored, the cleanup path must not store it again.
        boolean replyPersisted = false;
        try {
            emitter.send(SseEmitter.event().name("session")
                    .data(buildSessionPayload(session), MediaType.APPLICATION_JSON));
            GatewayChatRequest gatewayChatRequest = buildGatewayRequest(
                    tenantId, userId, session, contextMessages, modelRuntimeConfig, question);
            aiGatewayClient.stream(gatewayChatRequest,
                    event -> handleGatewayEvent(emitter, event, session, assistantReply));
            if (aiSessionService.isStopRequested(session.getId())) {
                if (assistantReply.length() > 0) {
                    aiSessionService.appendMessage(tenantId, userId, session.getId(),
                            "assistant", assistantReply.toString(), session.getModelCode());
                    replyPersisted = true;
                }
                emitter.send(SseEmitter.event().name("done")
                        .data(buildDonePayload(session, true), MediaType.APPLICATION_JSON));
            }
        } catch (Exception e) {
            sendQuietly(emitter, "error", buildErrorPayload(e.getMessage()));
        } finally {
            try {
                // Normal completion or failure with partial output: keep what was generated.
                if (!replyPersisted && assistantReply.length() > 0) {
                    aiSessionService.appendMessage(tenantId, userId, session.getId(),
                            "assistant", assistantReply.toString(), session.getModelCode());
                    sendQuietly(emitter, "done", buildDonePayload(session, false));
                }
            } finally {
                // The emitter has no timeout, so it must be completed even if persistence failed.
                try {
                    emitter.complete();
                } finally {
                    aiSessionService.clearStopFlag(session.getId());
                }
            }
        }
    }

    /**
     * Handles one gateway event: forwards {@code delta} content to the client and appends it
     * to the reply buffer, forwards {@code error} events, and ends on {@code done}.
     *
     * @return {@code true} for {@code delta} and unrecognised event types; {@code false} when a
     *         stop was requested or the event type is {@code error} or {@code done}
     * @throws Exception if sending to the emitter fails
     */
    private boolean handleGatewayEvent(SseEmitter emitter, JsonNode event, AiChatSession session,
                                       StringBuilder assistantReply) throws Exception {
        if (aiSessionService.isStopRequested(session.getId())) {
            return false;
        }
        String type = event.path("type").asText();
        if ("delta".equals(type)) {
            String content = event.path("content").asText("");
            assistantReply.append(content);
            emitter.send(SseEmitter.event().name("delta")
                    .data(buildDeltaPayload(session, content), MediaType.APPLICATION_JSON));
            return true;
        }
        if ("error".equals(type)) {
            emitter.send(SseEmitter.event().name("error")
                    .data(buildErrorPayload(event.path("message").asText("模型调用失败")), MediaType.APPLICATION_JSON));
            return false;
        }
        if ("done".equals(type)) {
            return false;
        }
        return true;
    }

    /**
     * Builds the gateway request from the session, the resolved model configuration, the
     * system prompt produced by the prompt-context service, and the context messages
     * (copied as role/content pairs in list order).
     */
    private GatewayChatRequest buildGatewayRequest(Long tenantId,
                                                   Long userId,
                                                   AiChatSession session,
                                                   List<AiChatMessage> contextMessages,
                                                   AiRuntimeConfigService.ModelRuntimeConfig modelRuntimeConfig,
                                                   String latestQuestion) {
        GatewayChatRequest request = new GatewayChatRequest();
        request.setSessionId(session.getId());
        request.setModelCode(session.getModelCode());
        request.setSystemPrompt(aiPromptContextService.buildSystemPrompt(
                modelRuntimeConfig.getSystemPrompt(),
                new AiPromptContext()
                        .setTenantId(tenantId)
                        .setUserId(userId)
                        .setSessionId(session.getId())
                        .setModelCode(session.getModelCode())
                        .setQuestion(latestQuestion)
        ));
        request.setChatUrl(modelRuntimeConfig.getChatUrl());
        request.setApiKey(modelRuntimeConfig.getApiKey());
        request.setModelName(modelRuntimeConfig.getModelName());
        for (AiChatMessage contextMessage : contextMessages) {
            GatewayChatMessage item = new GatewayChatMessage();
            item.setRole(contextMessage.getRole());
            item.setContent(contextMessage.getContent());
            request.getMessages().add(item);
        }
        return request;
    }

    /** Payload of the {@code session} event: {@code sessionId}, {@code title}, {@code modelCode}. */
    private Map<String, Object> buildSessionPayload(AiChatSession session) {
        Map<String, Object> payload = new LinkedHashMap<>();
        payload.put("sessionId", session.getId());
        payload.put("title", session.getTitle());
        payload.put("modelCode", session.getModelCode());
        return payload;
    }

    /**
     * Payload of a {@code delta} event: {@code sessionId}, {@code modelCode}, {@code content}
     * and {@code timestamp} (current epoch milliseconds).
     */
    private Map<String, Object> buildDeltaPayload(AiChatSession session, String content) {
        Map<String, Object> payload = new LinkedHashMap<>();
        payload.put("sessionId", session.getId());
        payload.put("modelCode", session.getModelCode());
        payload.put("content", content);
        payload.put("timestamp", System.currentTimeMillis());
        return payload;
    }

    /** Payload of the {@code done} event: {@code sessionId}, {@code modelCode}, {@code stopped}. */
    private Map<String, Object> buildDonePayload(AiChatSession session, boolean stopped) {
        Map<String, Object> payload = new LinkedHashMap<>();
        payload.put("sessionId", session.getId());
        payload.put("modelCode", session.getModelCode());
        payload.put("stopped", stopped);
        return payload;
    }

    /** Payload of an {@code error} event; a {@code null} message is replaced by a generic one. */
    private Map<String, Object> buildErrorPayload(String message) {
        Map<String, Object> payload = new LinkedHashMap<>();
        payload.put("message", message == null ? "AI服务异常" : message);
        return payload;
    }

    /** Sends a single {@code error} event (best effort) and completes the emitter. */
    private void completeWithError(SseEmitter emitter, String message) {
        sendQuietly(emitter, "error", buildErrorPayload(message));
        emitter.complete();
    }

    /**
     * Sends a named JSON event, ignoring send failures so that cleanup paths can continue.
     */
    private void sendQuietly(SseEmitter emitter, String eventName, Map<String, Object> payload) {
        try {
            emitter.send(SseEmitter.event().name(eventName).data(payload, MediaType.APPLICATION_JSON));
        } catch (IOException | IllegalStateException ignored) {
            // Best effort: the client has disconnected or the emitter can no longer accept
            // events; there is nobody left to notify.
        }
    }
}