package com.vincent.rsf.server.ai.controller;
|
|
import com.fasterxml.jackson.databind.JsonNode;
|
import com.vincent.rsf.framework.common.R;
|
import com.vincent.rsf.server.ai.constant.AiSceneCode;
|
import com.vincent.rsf.server.ai.dto.AiChatStreamRequest;
|
import com.vincent.rsf.server.ai.dto.AiSessionCreateRequest;
|
import com.vincent.rsf.server.ai.dto.AiSessionRenameRequest;
|
import com.vincent.rsf.server.ai.model.AiChatMessage;
|
import com.vincent.rsf.server.ai.model.AiChatSession;
|
import com.vincent.rsf.server.ai.model.AiPromptContext;
|
import com.vincent.rsf.server.ai.service.diagnosis.AiChatStreamOrchestrator;
|
import com.vincent.rsf.server.ai.service.diagnosis.AiDiagnosisRuntimeService;
|
import com.vincent.rsf.server.ai.service.AiModelRouteRuntimeService;
|
import com.vincent.rsf.server.ai.service.AiRuntimeConfigService;
|
import com.vincent.rsf.server.ai.service.AiSessionService;
|
import com.vincent.rsf.server.system.controller.BaseController;
|
import com.vincent.rsf.server.system.entity.AiDiagnosisRecord;
|
import org.springframework.http.MediaType;
|
import org.springframework.web.bind.annotation.GetMapping;
|
import org.springframework.web.bind.annotation.PathVariable;
|
import org.springframework.web.bind.annotation.PostMapping;
|
import org.springframework.web.bind.annotation.RequestBody;
|
import org.springframework.web.bind.annotation.RequestMapping;
|
import org.springframework.web.bind.annotation.RestController;
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
|
import javax.annotation.Resource;
|
import java.io.IOException;
|
import java.util.ArrayList;
|
import java.util.LinkedHashMap;
|
import java.util.List;
|
import java.util.Map;
|
|
@RestController
|
@RequestMapping("/ai")
|
public class AiController extends BaseController {
|
|
@Resource
|
private AiSessionService aiSessionService;
|
@Resource
|
private AiRuntimeConfigService aiRuntimeConfigService;
|
@Resource
|
private AiModelRouteRuntimeService aiModelRouteRuntimeService;
|
@Resource
|
private AiDiagnosisRuntimeService aiDiagnosisRuntimeService;
|
@Resource
|
private AiChatStreamOrchestrator aiChatStreamOrchestrator;
|
|
@GetMapping("/model/list")
|
public R modelList() {
|
List<Map<String, Object>> models = new ArrayList<>();
|
for (AiRuntimeConfigService.ModelRuntimeConfig model : aiRuntimeConfigService.listEnabledModels()) {
|
Map<String, Object> item = new LinkedHashMap<>();
|
item.put("code", model.getCode());
|
item.put("name", model.getName());
|
item.put("provider", model.getProvider());
|
item.put("enabled", model.getEnabled());
|
models.add(item);
|
}
|
return R.ok().add(models);
|
}
|
|
@GetMapping("/session/list")
|
public R sessionList() {
|
return R.ok().add(aiSessionService.listSessions(getTenantId(), getLoginUserId()));
|
}
|
|
@PostMapping("/session/create")
|
public R createSession(@RequestBody(required = false) AiSessionCreateRequest request) {
|
AiChatSession session = aiSessionService.createSession(
|
getTenantId(),
|
getLoginUserId(),
|
request == null ? null : request.getTitle(),
|
request == null ? null : request.getModelCode()
|
);
|
return R.ok().add(session);
|
}
|
|
@PostMapping("/session/{sessionId}/rename")
|
public R renameSession(@PathVariable("sessionId") String sessionId, @RequestBody AiSessionRenameRequest request) {
|
AiChatSession session = aiSessionService.renameSession(getTenantId(), getLoginUserId(), sessionId, request.getTitle());
|
return R.ok().add(session);
|
}
|
|
@PostMapping("/session/remove/{sessionId}")
|
public R removeSession(@PathVariable("sessionId") String sessionId) {
|
aiSessionService.removeSession(getTenantId(), getLoginUserId(), sessionId);
|
return R.ok();
|
}
|
|
@GetMapping("/session/{sessionId}/messages")
|
public R messageList(@PathVariable("sessionId") String sessionId) {
|
return R.ok().add(aiSessionService.listMessages(getTenantId(), getLoginUserId(), sessionId));
|
}
|
|
@PostMapping("/chat/stop")
|
public R stop(@RequestBody AiChatStreamRequest request) {
|
if (request != null && request.getSessionId() != null) {
|
aiSessionService.requestStop(request.getSessionId());
|
}
|
return R.ok();
|
}
|
|
@PostMapping(value = "/chat/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
public SseEmitter chatStream(@RequestBody AiChatStreamRequest request) {
|
return doChatStream(normalizeRequest(request));
|
}
|
|
@PostMapping(value = "/diagnose/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
public SseEmitter diagnoseStream(@RequestBody(required = false) AiChatStreamRequest request) {
|
AiChatStreamRequest diagnosisRequest = normalizeRequest(request);
|
diagnosisRequest.setSceneCode(AiSceneCode.SYSTEM_DIAGNOSE);
|
if (diagnosisRequest.getMessage() == null || diagnosisRequest.getMessage().trim().isEmpty()) {
|
diagnosisRequest.setMessage("请对当前WMS系统进行一次巡检诊断,结合库存、任务、设备站点数据识别异常并给出处理建议。");
|
}
|
return doChatStream(diagnosisRequest);
|
}
|
|
private SseEmitter doChatStream(AiChatStreamRequest request) {
|
SseEmitter emitter = new SseEmitter(0L);
|
Long tenantId = getTenantId();
|
Long userId = getLoginUserId();
|
if (tenantId == null || userId == null) {
|
completeWithError(emitter, "请先登录后再使用AI助手");
|
return emitter;
|
}
|
if (request == null || request.getMessage() == null || request.getMessage().trim().isEmpty()) {
|
completeWithError(emitter, "消息内容不能为空");
|
return emitter;
|
}
|
|
AiChatSession session = aiSessionService.ensureSession(tenantId, userId, request.getSessionId(), request.getModelCode());
|
aiSessionService.clearStopFlag(session.getId());
|
aiSessionService.appendMessage(tenantId, userId, session.getId(), "user", request.getMessage(), session.getModelCode());
|
|
AiPromptContext promptContext = new AiPromptContext()
|
.setTenantId(tenantId)
|
.setUserId(userId)
|
.setSessionId(session.getId())
|
.setModelCode(session.getModelCode())
|
.setQuestion(request.getMessage())
|
.setSceneCode(request.getSceneCode());
|
|
List<AiModelRouteRuntimeService.RouteCandidate> candidates = aiModelRouteRuntimeService.resolveCandidates(
|
tenantId,
|
request.getSceneCode(),
|
session.getModelCode()
|
);
|
if (candidates.isEmpty()) {
|
completeWithError(emitter, "未找到可用的AI模型配置");
|
return emitter;
|
}
|
int maxContextMessages = resolveContextSize(candidates);
|
List<AiChatMessage> contextMessages = aiSessionService.listContextMessages(
|
tenantId,
|
userId,
|
session.getId(),
|
maxContextMessages
|
);
|
AiDiagnosisRecord diagnosisRecord = AiSceneCode.SYSTEM_DIAGNOSE.equals(request.getSceneCode())
|
? aiDiagnosisRuntimeService.startDiagnosis(tenantId, userId, session.getId(), request.getSceneCode(), request.getMessage())
|
: null;
|
|
Thread thread = new Thread(() -> executeStream(
|
emitter, tenantId, userId, session, request, promptContext, contextMessages, diagnosisRecord, candidates
|
), "ai-chat-stream-" + session.getId());
|
thread.setDaemon(true);
|
thread.start();
|
return emitter;
|
}
|
|
private void executeStream(SseEmitter emitter,
|
Long tenantId,
|
Long userId,
|
AiChatSession session,
|
AiChatStreamRequest request,
|
AiPromptContext promptContext,
|
List<AiChatMessage> contextMessages,
|
AiDiagnosisRecord diagnosisRecord,
|
List<AiModelRouteRuntimeService.RouteCandidate> candidates) {
|
aiChatStreamOrchestrator.executeStream(
|
emitter,
|
tenantId,
|
userId,
|
session,
|
request,
|
promptContext,
|
contextMessages,
|
diagnosisRecord,
|
candidates
|
);
|
}
|
|
private int resolveContextSize(List<AiModelRouteRuntimeService.RouteCandidate> candidates) {
|
return aiChatStreamOrchestrator.resolveContextSize(candidates);
|
}
|
|
private AiChatStreamRequest normalizeRequest(AiChatStreamRequest request) {
|
AiChatStreamRequest normalized = request == null ? new AiChatStreamRequest() : request;
|
if (normalized.getSceneCode() == null || normalized.getSceneCode().trim().isEmpty()) {
|
normalized.setSceneCode(AiSceneCode.GENERAL_CHAT);
|
}
|
return normalized;
|
}
|
|
private Map<String, Object> buildErrorPayload(String message) {
|
Map<String, Object> payload = new LinkedHashMap<>();
|
payload.put("message", message == null ? "AI服务异常" : message);
|
return payload;
|
}
|
|
private void completeWithError(SseEmitter emitter, String message) {
|
try {
|
emitter.send(SseEmitter.event().name("error").data(buildErrorPayload(message), MediaType.APPLICATION_JSON));
|
} catch (IOException ignore) {
|
}
|
emitter.complete();
|
}
|
}
|