package com.vincent.rsf.server.ai.controller;
|
|
import com.fasterxml.jackson.databind.JsonNode;
|
import com.vincent.rsf.framework.common.R;
|
import com.vincent.rsf.server.ai.config.AiProperties;
|
import com.vincent.rsf.server.ai.dto.AiChatStreamRequest;
|
import com.vincent.rsf.server.ai.dto.AiSessionCreateRequest;
|
import com.vincent.rsf.server.ai.dto.AiSessionRenameRequest;
|
import com.vincent.rsf.server.ai.dto.GatewayChatMessage;
|
import com.vincent.rsf.server.ai.dto.GatewayChatRequest;
|
import com.vincent.rsf.server.ai.model.AiChatMessage;
|
import com.vincent.rsf.server.ai.model.AiChatSession;
|
import com.vincent.rsf.server.ai.model.AiPromptContext;
|
import com.vincent.rsf.server.ai.service.AiGatewayClient;
|
import com.vincent.rsf.server.ai.service.AiPromptContextService;
|
import com.vincent.rsf.server.ai.service.AiRuntimeConfigService;
|
import com.vincent.rsf.server.ai.service.AiSessionService;
|
import com.vincent.rsf.server.system.controller.BaseController;
|
import org.springframework.http.MediaType;
|
import org.springframework.web.bind.annotation.*;
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
|
import javax.annotation.Resource;
|
import java.io.IOException;
|
import java.util.Date;
|
import java.util.LinkedHashMap;
|
import java.util.List;
|
import java.util.Map;
|
|
@RestController
|
@RequestMapping("/ai")
|
public class AiController extends BaseController {
|
|
@Resource
|
private AiSessionService aiSessionService;
|
@Resource
|
private AiProperties aiProperties;
|
@Resource
|
private AiGatewayClient aiGatewayClient;
|
@Resource
|
private AiRuntimeConfigService aiRuntimeConfigService;
|
@Resource
|
private AiPromptContextService aiPromptContextService;
|
|
@GetMapping("/model/list")
|
public R modelList() {
|
List<Map<String, Object>> models = new java.util.ArrayList<>();
|
for (AiRuntimeConfigService.ModelRuntimeConfig model : aiRuntimeConfigService.listEnabledModels()) {
|
Map<String, Object> item = new LinkedHashMap<>();
|
item.put("code", model.getCode());
|
item.put("name", model.getName());
|
item.put("provider", model.getProvider());
|
item.put("enabled", model.getEnabled());
|
models.add(item);
|
}
|
return R.ok().add(models);
|
}
|
|
@GetMapping("/session/list")
|
public R sessionList() {
|
return R.ok().add(aiSessionService.listSessions(getTenantId(), getLoginUserId()));
|
}
|
|
@PostMapping("/session/create")
|
public R createSession(@RequestBody(required = false) AiSessionCreateRequest request) {
|
AiChatSession session = aiSessionService.createSession(
|
getTenantId(),
|
getLoginUserId(),
|
request == null ? null : request.getTitle(),
|
request == null ? null : request.getModelCode()
|
);
|
return R.ok().add(session);
|
}
|
|
@PostMapping("/session/{sessionId}/rename")
|
public R renameSession(@PathVariable("sessionId") String sessionId, @RequestBody AiSessionRenameRequest request) {
|
AiChatSession session = aiSessionService.renameSession(getTenantId(), getLoginUserId(), sessionId, request.getTitle());
|
return R.ok().add(session);
|
}
|
|
@PostMapping("/session/remove/{sessionId}")
|
public R removeSession(@PathVariable("sessionId") String sessionId) {
|
aiSessionService.removeSession(getTenantId(), getLoginUserId(), sessionId);
|
return R.ok();
|
}
|
|
@GetMapping("/session/{sessionId}/messages")
|
public R messageList(@PathVariable("sessionId") String sessionId) {
|
return R.ok().add(aiSessionService.listMessages(getTenantId(), getLoginUserId(), sessionId));
|
}
|
|
@PostMapping("/chat/stop")
|
public R stop(@RequestBody AiChatStreamRequest request) {
|
if (request != null && request.getSessionId() != null) {
|
aiSessionService.requestStop(request.getSessionId());
|
}
|
return R.ok();
|
}
|
|
@PostMapping(value = "/chat/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
public SseEmitter chatStream(@RequestBody AiChatStreamRequest request) {
|
SseEmitter emitter = new SseEmitter(0L);
|
Long tenantId = getTenantId();
|
Long userId = getLoginUserId();
|
if (tenantId == null || userId == null) {
|
completeWithError(emitter, "请先登录后再使用AI助手");
|
return emitter;
|
}
|
if (request == null || request.getMessage() == null || request.getMessage().trim().isEmpty()) {
|
completeWithError(emitter, "消息内容不能为空");
|
return emitter;
|
}
|
AiChatSession session = aiSessionService.ensureSession(tenantId, userId, request.getSessionId(), request.getModelCode());
|
aiSessionService.clearStopFlag(session.getId());
|
aiSessionService.appendMessage(tenantId, userId, session.getId(), "user", request.getMessage(), session.getModelCode());
|
AiRuntimeConfigService.ModelRuntimeConfig modelRuntimeConfig = aiRuntimeConfigService.resolveModel(session.getModelCode());
|
List<AiChatMessage> contextMessages = aiSessionService.listContextMessages(
|
tenantId,
|
userId,
|
session.getId(),
|
modelRuntimeConfig.getMaxContextMessages()
|
);
|
|
Thread thread = new Thread(() -> {
|
StringBuilder assistantReply = new StringBuilder();
|
boolean doneSent = false;
|
try {
|
emitter.send(SseEmitter.event().name("session").data(buildSessionPayload(session), MediaType.APPLICATION_JSON));
|
GatewayChatRequest gatewayChatRequest = buildGatewayRequest(
|
tenantId,
|
userId,
|
session,
|
contextMessages,
|
modelRuntimeConfig,
|
request.getMessage()
|
);
|
aiGatewayClient.stream(gatewayChatRequest, event -> handleGatewayEvent(
|
emitter,
|
event,
|
session,
|
assistantReply
|
));
|
if (aiSessionService.isStopRequested(session.getId())) {
|
if (assistantReply.length() > 0) {
|
aiSessionService.appendMessage(tenantId, userId, session.getId(), "assistant", assistantReply.toString(), session.getModelCode());
|
}
|
emitter.send(SseEmitter.event().name("done").data(buildDonePayload(session, true), MediaType.APPLICATION_JSON));
|
doneSent = true;
|
}
|
} catch (Exception e) {
|
try {
|
emitter.send(SseEmitter.event().name("error").data(buildErrorPayload(e.getMessage()), MediaType.APPLICATION_JSON));
|
} catch (IOException ignore) {
|
}
|
} finally {
|
if (!doneSent && assistantReply.length() > 0) {
|
aiSessionService.appendMessage(tenantId, userId, session.getId(), "assistant", assistantReply.toString(), session.getModelCode());
|
try {
|
emitter.send(SseEmitter.event().name("done").data(buildDonePayload(session, false), MediaType.APPLICATION_JSON));
|
} catch (IOException ignore) {
|
}
|
}
|
emitter.complete();
|
aiSessionService.clearStopFlag(session.getId());
|
}
|
}, "ai-chat-stream-" + session.getId());
|
thread.setDaemon(true);
|
thread.start();
|
return emitter;
|
}
|
|
private boolean handleGatewayEvent(SseEmitter emitter, JsonNode event, AiChatSession session,
|
StringBuilder assistantReply) throws Exception {
|
if (aiSessionService.isStopRequested(session.getId())) {
|
return false;
|
}
|
String type = event.path("type").asText();
|
if ("delta".equals(type)) {
|
String content = event.path("content").asText("");
|
assistantReply.append(content);
|
emitter.send(SseEmitter.event().name("delta").data(buildDeltaPayload(session, content), MediaType.APPLICATION_JSON));
|
return true;
|
}
|
if ("error".equals(type)) {
|
emitter.send(SseEmitter.event().name("error").data(buildErrorPayload(event.path("message").asText("模型调用失败")), MediaType.APPLICATION_JSON));
|
return false;
|
}
|
if ("done".equals(type)) {
|
return false;
|
}
|
return true;
|
}
|
|
private GatewayChatRequest buildGatewayRequest(Long tenantId, Long userId, AiChatSession session, List<AiChatMessage> contextMessages,
|
AiRuntimeConfigService.ModelRuntimeConfig modelRuntimeConfig,
|
String latestQuestion) {
|
GatewayChatRequest request = new GatewayChatRequest();
|
request.setSessionId(session.getId());
|
request.setModelCode(session.getModelCode());
|
request.setSystemPrompt(aiPromptContextService.buildSystemPrompt(
|
modelRuntimeConfig.getSystemPrompt(),
|
new AiPromptContext()
|
.setTenantId(tenantId)
|
.setUserId(userId)
|
.setSessionId(session.getId())
|
.setModelCode(session.getModelCode())
|
.setQuestion(latestQuestion)
|
));
|
request.setChatUrl(modelRuntimeConfig.getChatUrl());
|
request.setApiKey(modelRuntimeConfig.getApiKey());
|
request.setModelName(modelRuntimeConfig.getModelName());
|
for (AiChatMessage contextMessage : contextMessages) {
|
GatewayChatMessage item = new GatewayChatMessage();
|
item.setRole(contextMessage.getRole());
|
item.setContent(contextMessage.getContent());
|
request.getMessages().add(item);
|
}
|
return request;
|
}
|
|
private Map<String, Object> buildSessionPayload(AiChatSession session) {
|
Map<String, Object> payload = new LinkedHashMap<>();
|
payload.put("sessionId", session.getId());
|
payload.put("title", session.getTitle());
|
payload.put("modelCode", session.getModelCode());
|
return payload;
|
}
|
|
private Map<String, Object> buildDeltaPayload(AiChatSession session, String content) {
|
Map<String, Object> payload = new LinkedHashMap<>();
|
payload.put("sessionId", session.getId());
|
payload.put("modelCode", session.getModelCode());
|
payload.put("content", content);
|
payload.put("timestamp", new Date().getTime());
|
return payload;
|
}
|
|
private Map<String, Object> buildDonePayload(AiChatSession session, boolean stopped) {
|
Map<String, Object> payload = new LinkedHashMap<>();
|
payload.put("sessionId", session.getId());
|
payload.put("modelCode", session.getModelCode());
|
payload.put("stopped", stopped);
|
return payload;
|
}
|
|
private Map<String, Object> buildErrorPayload(String message) {
|
Map<String, Object> payload = new LinkedHashMap<>();
|
payload.put("message", message == null ? "AI服务异常" : message);
|
return payload;
|
}
|
|
private void completeWithError(SseEmitter emitter, String message) {
|
try {
|
emitter.send(SseEmitter.event().name("error").data(buildErrorPayload(message), MediaType.APPLICATION_JSON));
|
} catch (IOException ignore) {
|
}
|
emitter.complete();
|
}
|
|
}
|