package com.vincent.rsf.server.ai.controller; import com.fasterxml.jackson.databind.JsonNode; import com.vincent.rsf.framework.common.R; import com.vincent.rsf.server.ai.constant.AiSceneCode; import com.vincent.rsf.server.ai.dto.AiChatStreamRequest; import com.vincent.rsf.server.ai.dto.AiSessionCreateRequest; import com.vincent.rsf.server.ai.dto.AiSessionRenameRequest; import com.vincent.rsf.server.ai.model.AiChatMessage; import com.vincent.rsf.server.ai.model.AiChatSession; import com.vincent.rsf.server.ai.model.AiPromptContext; import com.vincent.rsf.server.ai.service.diagnosis.AiChatStreamOrchestrator; import com.vincent.rsf.server.ai.service.diagnosis.AiDiagnosisRuntimeService; import com.vincent.rsf.server.ai.service.AiModelRouteRuntimeService; import com.vincent.rsf.server.ai.service.AiRuntimeConfigService; import com.vincent.rsf.server.ai.service.AiSessionService; import com.vincent.rsf.server.system.controller.BaseController; import com.vincent.rsf.server.system.entity.AiDiagnosisRecord; import org.springframework.http.MediaType; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import javax.annotation.Resource; import java.io.IOException; import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @RestController @RequestMapping("/ai") public class AiController extends BaseController { @Resource private AiSessionService aiSessionService; @Resource private AiRuntimeConfigService aiRuntimeConfigService; @Resource private AiModelRouteRuntimeService aiModelRouteRuntimeService; @Resource private AiDiagnosisRuntimeService aiDiagnosisRuntimeService; @Resource private AiChatStreamOrchestrator aiChatStreamOrchestrator; @GetMapping("/model/list") public R modelList() { List> models = new ArrayList<>(); for (AiRuntimeConfigService.ModelRuntimeConfig model : aiRuntimeConfigService.listEnabledModels()) { Map item = new LinkedHashMap<>(); item.put("code", model.getCode()); item.put("name", model.getName()); item.put("provider", model.getProvider()); item.put("enabled", model.getEnabled()); models.add(item); } return R.ok().add(models); } @GetMapping("/session/list") public R sessionList() { return R.ok().add(aiSessionService.listSessions(getTenantId(), getLoginUserId())); } @PostMapping("/session/create") public R createSession(@RequestBody(required = false) AiSessionCreateRequest request) { AiChatSession session = aiSessionService.createSession( getTenantId(), getLoginUserId(), request == null ? null : request.getTitle(), request == null ? null : request.getModelCode() ); return R.ok().add(session); } @PostMapping("/session/{sessionId}/rename") public R renameSession(@PathVariable("sessionId") String sessionId, @RequestBody AiSessionRenameRequest request) { AiChatSession session = aiSessionService.renameSession(getTenantId(), getLoginUserId(), sessionId, request.getTitle()); return R.ok().add(session); } @PostMapping("/session/remove/{sessionId}") public R removeSession(@PathVariable("sessionId") String sessionId) { aiSessionService.removeSession(getTenantId(), getLoginUserId(), sessionId); return R.ok(); } @GetMapping("/session/{sessionId}/messages") public R messageList(@PathVariable("sessionId") String sessionId) { return R.ok().add(aiSessionService.listMessages(getTenantId(), getLoginUserId(), sessionId)); } @PostMapping("/chat/stop") public R stop(@RequestBody AiChatStreamRequest request) { if (request != null && request.getSessionId() != null) { aiSessionService.requestStop(request.getSessionId()); } return R.ok(); } @PostMapping(value = "/chat/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) public SseEmitter chatStream(@RequestBody AiChatStreamRequest request) { return doChatStream(normalizeRequest(request)); } @PostMapping(value = "/diagnose/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) public SseEmitter diagnoseStream(@RequestBody(required = false) AiChatStreamRequest request) { AiChatStreamRequest diagnosisRequest = normalizeRequest(request); diagnosisRequest.setSceneCode(AiSceneCode.SYSTEM_DIAGNOSE); if (diagnosisRequest.getMessage() == null || diagnosisRequest.getMessage().trim().isEmpty()) { diagnosisRequest.setMessage("请对当前WMS系统进行一次巡检诊断,结合库存、任务、设备站点数据识别异常并给出处理建议。"); } return doChatStream(diagnosisRequest); } private SseEmitter doChatStream(AiChatStreamRequest request) { SseEmitter emitter = new SseEmitter(0L); Long tenantId = getTenantId(); Long userId = getLoginUserId(); if (tenantId == null || userId == null) { completeWithError(emitter, "请先登录后再使用AI助手"); return emitter; } if (request == null || request.getMessage() == null || request.getMessage().trim().isEmpty()) { completeWithError(emitter, "消息内容不能为空"); return emitter; } AiChatSession session = aiSessionService.ensureSession(tenantId, userId, request.getSessionId(), request.getModelCode()); aiSessionService.clearStopFlag(session.getId()); aiSessionService.appendMessage(tenantId, userId, session.getId(), "user", request.getMessage(), session.getModelCode()); AiPromptContext promptContext = new AiPromptContext() .setTenantId(tenantId) .setUserId(userId) .setSessionId(session.getId()) .setModelCode(session.getModelCode()) .setQuestion(request.getMessage()) .setSceneCode(request.getSceneCode()); List candidates = aiModelRouteRuntimeService.resolveCandidates( tenantId, request.getSceneCode(), session.getModelCode() ); if (candidates.isEmpty()) { completeWithError(emitter, "未找到可用的AI模型配置"); return emitter; } int maxContextMessages = resolveContextSize(candidates); List contextMessages = aiSessionService.listContextMessages( tenantId, userId, session.getId(), maxContextMessages ); AiDiagnosisRecord diagnosisRecord = AiSceneCode.SYSTEM_DIAGNOSE.equals(request.getSceneCode()) ? aiDiagnosisRuntimeService.startDiagnosis(tenantId, userId, session.getId(), request.getSceneCode(), request.getMessage()) : null; Thread thread = new Thread(() -> executeStream( emitter, tenantId, userId, session, request, promptContext, contextMessages, diagnosisRecord, candidates ), "ai-chat-stream-" + session.getId()); thread.setDaemon(true); thread.start(); return emitter; } private void executeStream(SseEmitter emitter, Long tenantId, Long userId, AiChatSession session, AiChatStreamRequest request, AiPromptContext promptContext, List contextMessages, AiDiagnosisRecord diagnosisRecord, List candidates) { aiChatStreamOrchestrator.executeStream( emitter, tenantId, userId, session, request, promptContext, contextMessages, diagnosisRecord, candidates ); } private int resolveContextSize(List candidates) { return aiChatStreamOrchestrator.resolveContextSize(candidates); } private AiChatStreamRequest normalizeRequest(AiChatStreamRequest request) { AiChatStreamRequest normalized = request == null ? new AiChatStreamRequest() : request; if (normalized.getSceneCode() == null || normalized.getSceneCode().trim().isEmpty()) { normalized.setSceneCode(AiSceneCode.GENERAL_CHAT); } return normalized; } private Map buildErrorPayload(String message) { Map payload = new LinkedHashMap<>(); payload.put("message", message == null ? "AI服务异常" : message); return payload; } private void completeWithError(SseEmitter emitter, String message) { try { emitter.send(SseEmitter.event().name("error").data(buildErrorPayload(message), MediaType.APPLICATION_JSON)); } catch (IOException ignore) { } emitter.complete(); } }