| | |
| | | |
| | | import com.fasterxml.jackson.databind.JsonNode; |
| | | import com.vincent.rsf.framework.common.R; |
| | | import com.vincent.rsf.server.ai.config.AiProperties; |
| | | import com.vincent.rsf.server.ai.constant.AiSceneCode; |
| | | import com.vincent.rsf.server.ai.dto.AiChatStreamRequest; |
| | | import com.vincent.rsf.server.ai.dto.AiSessionCreateRequest; |
| | | import com.vincent.rsf.server.ai.dto.AiSessionRenameRequest; |
| | | import com.vincent.rsf.server.ai.dto.GatewayChatMessage; |
| | | import com.vincent.rsf.server.ai.dto.GatewayChatRequest; |
| | | import com.vincent.rsf.server.ai.model.AiChatMessage; |
| | | import com.vincent.rsf.server.ai.model.AiChatSession; |
| | | import com.vincent.rsf.server.ai.model.AiPromptContext; |
| | | import com.vincent.rsf.server.ai.service.AiGatewayClient; |
| | | import com.vincent.rsf.server.ai.service.AiPromptContextService; |
| | | import com.vincent.rsf.server.ai.service.diagnosis.AiChatStreamOrchestrator; |
| | | import com.vincent.rsf.server.ai.service.diagnosis.AiDiagnosisRuntimeService; |
| | | import com.vincent.rsf.server.ai.service.AiModelRouteRuntimeService; |
| | | import com.vincent.rsf.server.ai.service.AiRuntimeConfigService; |
| | | import com.vincent.rsf.server.ai.service.AiSessionService; |
| | | import com.vincent.rsf.server.system.controller.BaseController; |
| | | import com.vincent.rsf.server.system.entity.AiDiagnosisRecord; |
| | | import org.springframework.http.MediaType; |
| | | import org.springframework.web.bind.annotation.*; |
| | | import org.springframework.web.bind.annotation.GetMapping; |
| | | import org.springframework.web.bind.annotation.PathVariable; |
| | | import org.springframework.web.bind.annotation.PostMapping; |
| | | import org.springframework.web.bind.annotation.RequestBody; |
| | | import org.springframework.web.bind.annotation.RequestMapping; |
| | | import org.springframework.web.bind.annotation.RestController; |
| | | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; |
| | | |
| | | import javax.annotation.Resource; |
| | | import java.io.IOException; |
| | | import java.util.Date; |
| | | import java.util.ArrayList; |
| | | import java.util.LinkedHashMap; |
| | | import java.util.List; |
| | | import java.util.Map; |
| | |
| | | // Collaborators of this controller; each one is annotated with @Resource for injection. |
| | | @Resource |
| | | private AiSessionService aiSessionService; |
| | | @Resource |
| | | private AiProperties aiProperties; |
| | | @Resource |
| | | private AiGatewayClient aiGatewayClient; |
| | | @Resource |
| | | private AiRuntimeConfigService aiRuntimeConfigService; |
| | | @Resource |
| | | private AiPromptContextService aiPromptContextService; |
| | | // This field previously had no @Resource; nothing visible here assigns it, yet doChatStream |
| | | // calls resolveCandidates on it, so it must be injected like its siblings. |
| | | @Resource |
| | | private AiModelRouteRuntimeService aiModelRouteRuntimeService; |
| | | @Resource |
| | | private AiDiagnosisRuntimeService aiDiagnosisRuntimeService; |
| | | @Resource |
| | | private AiChatStreamOrchestrator aiChatStreamOrchestrator; |
| | | |
| | | @GetMapping("/model/list") |
| | | public R modelList() { |
| | | List<Map<String, Object>> models = new java.util.ArrayList<>(); |
| | | List<Map<String, Object>> models = new ArrayList<>(); |
| | | for (AiRuntimeConfigService.ModelRuntimeConfig model : aiRuntimeConfigService.listEnabledModels()) { |
| | | Map<String, Object> item = new LinkedHashMap<>(); |
| | | item.put("code", model.getCode()); |
| | |
| | | |
| | | @PostMapping(value = "/chat/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) |
| | | public SseEmitter chatStream(@RequestBody AiChatStreamRequest request) { |
| | | return doChatStream(normalizeRequest(request)); |
| | | } |
| | | |
| | | @PostMapping(value = "/diagnose/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) |
| | | public SseEmitter diagnoseStream(@RequestBody(required = false) AiChatStreamRequest request) { |
| | | AiChatStreamRequest diagnosisRequest = normalizeRequest(request); |
| | | diagnosisRequest.setSceneCode(AiSceneCode.SYSTEM_DIAGNOSE); |
| | | if (diagnosisRequest.getMessage() == null || diagnosisRequest.getMessage().trim().isEmpty()) { |
| | | diagnosisRequest.setMessage("请对当前WMS系统进行一次巡检诊断,结合库存、任务、设备站点数据识别异常并给出处理建议。"); |
| | | } |
| | | return doChatStream(diagnosisRequest); |
| | | } |
| | | |
| | | /** |
| | |  * Validates the request, prepares the session and model routing, then runs the stream on a |
| | |  * daemon thread while the emitter is returned to the caller right away. |
| | |  * <p> |
| | |  * Steps, in order: reject an empty message; ensure the session, clear its stop flag and store |
| | |  * the user message; resolve route candidates (an empty list ends the emitter with an error); |
| | |  * load the context messages; start a diagnosis record when the scene is |
| | |  * {@code AiSceneCode.SYSTEM_DIAGNOSE}; finally delegate to {@code executeStream}. |
| | |  * |
| | |  * @param request normalized chat request; must not be null |
| | |  * @return the emitter the background thread writes to |
| | |  */ |
| | | private SseEmitter doChatStream(AiChatStreamRequest request) { |
| | |     // NOTE(review): a timeout of 0L presumably disables the async timeout - confirm against the Spring docs. |
| | |     SseEmitter emitter = new SseEmitter(0L); |
| | |     Long tenantId = getTenantId(); |
| | |     Long userId = getLoginUserId(); |
| | |     // NOTE(review): guard condition reconstructed from the error text and the same check in diagnoseStream - confirm. |
| | |     if (request.getMessage() == null || request.getMessage().trim().isEmpty()) { |
| | |         completeWithError(emitter, "消息内容不能为空"); |
| | |         return emitter; |
| | |     } |
| | | |
| | |     AiChatSession session = aiSessionService.ensureSession(tenantId, userId, request.getSessionId(), request.getModelCode()); |
| | |     aiSessionService.clearStopFlag(session.getId()); |
| | |     aiSessionService.appendMessage(tenantId, userId, session.getId(), "user", request.getMessage(), session.getModelCode()); |
| | | |
| | |     AiPromptContext promptContext = new AiPromptContext() |
| | |             .setTenantId(tenantId) |
| | |             .setUserId(userId) |
| | |             .setSessionId(session.getId()) |
| | |             .setModelCode(session.getModelCode()) |
| | |             .setQuestion(request.getMessage()) |
| | |             .setSceneCode(request.getSceneCode()); |
| | | |
| | |     List<AiModelRouteRuntimeService.RouteCandidate> candidates = aiModelRouteRuntimeService.resolveCandidates( |
| | |             tenantId, |
| | |             request.getSceneCode(), |
| | |             session.getModelCode() |
| | |     ); |
| | |     if (candidates.isEmpty()) { |
| | |         completeWithError(emitter, "未找到可用的AI模型配置"); |
| | |         return emitter; |
| | |     } |
| | |     // The context window size comes from the resolved candidates. |
| | |     int maxContextMessages = resolveContextSize(candidates); |
| | |     List<AiChatMessage> contextMessages = aiSessionService.listContextMessages( |
| | |             tenantId, |
| | |             userId, |
| | |             session.getId(), |
| | |             maxContextMessages |
| | |     ); |
| | |     AiDiagnosisRecord diagnosisRecord = AiSceneCode.SYSTEM_DIAGNOSE.equals(request.getSceneCode()) |
| | |             ? aiDiagnosisRuntimeService.startDiagnosis(tenantId, userId, session.getId(), request.getSceneCode(), request.getMessage()) |
| | |             : null; |
| | | |
| | |     Thread thread = new Thread(() -> executeStream( |
| | |             emitter, tenantId, userId, session, request, promptContext, contextMessages, diagnosisRecord, candidates |
| | |     ), "ai-chat-stream-" + session.getId()); |
| | |     thread.setDaemon(true); |
| | |     thread.start(); |
| | |     return emitter; |
| | | } |
| | | |
| | | private boolean handleGatewayEvent(SseEmitter emitter, JsonNode event, AiChatSession session, |
| | | StringBuilder assistantReply) throws Exception { |
| | | if (aiSessionService.isStopRequested(session.getId())) { |
| | | return false; |
| | | } |
| | | String type = event.path("type").asText(); |
| | | if ("delta".equals(type)) { |
| | | String content = event.path("content").asText(""); |
| | | assistantReply.append(content); |
| | | emitter.send(SseEmitter.event().name("delta").data(buildDeltaPayload(session, content), MediaType.APPLICATION_JSON)); |
| | | return true; |
| | | } |
| | | if ("error".equals(type)) { |
| | | emitter.send(SseEmitter.event().name("error").data(buildErrorPayload(event.path("message").asText("模型调用失败")), MediaType.APPLICATION_JSON)); |
| | | return false; |
| | | } |
| | | if ("done".equals(type)) { |
| | | return false; |
| | | } |
| | | return true; |
| | | /** |
| | |  * Forwards all arguments unchanged to {@code aiChatStreamOrchestrator.executeStream}. |
| | |  */ |
| | | private void executeStream(SseEmitter emitter, |
| | |                            Long tenantId, |
| | |                            Long userId, |
| | |                            AiChatSession session, |
| | |                            AiChatStreamRequest request, |
| | |                            AiPromptContext promptContext, |
| | |                            List<AiChatMessage> contextMessages, |
| | |                            AiDiagnosisRecord diagnosisRecord, |
| | |                            List<AiModelRouteRuntimeService.RouteCandidate> candidates) { |
| | |     aiChatStreamOrchestrator.executeStream( |
| | |             emitter, tenantId, userId, session, request, |
| | |             promptContext, contextMessages, diagnosisRecord, candidates); |
| | | } |
| | | |
| | | /** |
| | |  * Builds the gateway request for one chat turn. |
| | |  * <p> |
| | |  * Session id and model code come from {@code session}; chat URL, API key and model name come |
| | |  * from {@code modelRuntimeConfig}. The system prompt is built by {@code aiPromptContextService} |
| | |  * from the configured prompt plus a context holding tenant, user, session, model and question. |
| | |  * Every context message is copied (role and content) into the request's message list. |
| | |  * |
| | |  * @param tenantId           tenant placed in the prompt context |
| | |  * @param userId             user placed in the prompt context |
| | |  * @param session            session supplying id and model code |
| | |  * @param contextMessages    messages to send as conversation context |
| | |  * @param modelRuntimeConfig configuration of the model to call |
| | |  * @param latestQuestion     question placed in the prompt context |
| | |  * @return the populated gateway request |
| | |  */ |
| | | private GatewayChatRequest buildGatewayRequest(Long tenantId, Long userId, AiChatSession session, List<AiChatMessage> contextMessages, |
| | |                                                AiRuntimeConfigService.ModelRuntimeConfig modelRuntimeConfig, |
| | |                                                String latestQuestion) { |
| | |     GatewayChatRequest request = new GatewayChatRequest(); |
| | |     request.setSessionId(session.getId()); |
| | |     request.setModelCode(session.getModelCode()); |
| | |     request.setSystemPrompt(aiPromptContextService.buildSystemPrompt( |
| | |             modelRuntimeConfig.getSystemPrompt(), |
| | |             new AiPromptContext() |
| | |                     .setTenantId(tenantId) |
| | |                     .setUserId(userId) |
| | |                     .setSessionId(session.getId()) |
| | |                     .setModelCode(session.getModelCode()) |
| | |                     .setQuestion(latestQuestion) |
| | |     )); |
| | |     request.setChatUrl(modelRuntimeConfig.getChatUrl()); |
| | |     request.setApiKey(modelRuntimeConfig.getApiKey()); |
| | |     request.setModelName(modelRuntimeConfig.getModelName()); |
| | |     for (AiChatMessage contextMessage : contextMessages) { |
| | |         GatewayChatMessage item = new GatewayChatMessage(); |
| | |         item.setRole(contextMessage.getRole()); |
| | |         item.setContent(contextMessage.getContent()); |
| | |         // assumes getMessages() returns a non-null, mutable list - TODO confirm |
| | |         request.getMessages().add(item); |
| | |     } |
| | |     // The loop and the method were left unterminated; close both and return the built request. |
| | |     return request; |
| | | } |
| | | |
| | | /** |
| | |  * Asks {@code aiChatStreamOrchestrator} for the context size to use with the given candidates. |
| | |  * |
| | |  * @param candidates resolved route candidates |
| | |  * @return the value reported by the orchestrator |
| | |  */ |
| | | private int resolveContextSize(List<AiModelRouteRuntimeService.RouteCandidate> candidates) { |
| | |     int contextSize = aiChatStreamOrchestrator.resolveContextSize(candidates); |
| | |     return contextSize; |
| | | } |
| | | |
| | | /** |
| | |  * Returns a non-null request whose scene code is set. |
| | |  * <p> |
| | |  * A null argument is replaced by a new {@code AiChatStreamRequest}; a null or blank scene code |
| | |  * is replaced by {@code AiSceneCode.GENERAL_CHAT}. A non-null argument is modified in place. |
| | |  * |
| | |  * @param request request as received; may be null |
| | |  * @return {@code request} itself when non-null, otherwise the newly created instance |
| | |  */ |
| | | private AiChatStreamRequest normalizeRequest(AiChatStreamRequest request) { |
| | |     AiChatStreamRequest normalized = request == null ? new AiChatStreamRequest() : request; |
| | |     String sceneCode = normalized.getSceneCode(); |
| | |     if (sceneCode == null || sceneCode.trim().isEmpty()) { |
| | |         normalized.setSceneCode(AiSceneCode.GENERAL_CHAT); |
| | |     } |
| | |     // Return the null-safe instance, not the raw parameter: diagnoseStream accepts a missing |
| | |     // body and dereferences the result immediately. |
| | |     return normalized; |
| | | } |
| | | |
| | | /** |
| | |  * Builds the payload of the "session" SSE event. |
| | |  * |
| | |  * @param session session to describe |
| | |  * @return insertion-ordered map with keys sessionId, title and modelCode |
| | |  */ |
| | | private Map<String, Object> buildSessionPayload(AiChatSession session) { |
| | |     Map<String, Object> sessionInfo = new LinkedHashMap<>(); |
| | |     Object sessionId = session.getId(); |
| | |     sessionInfo.put("sessionId", sessionId); |
| | |     Object title = session.getTitle(); |
| | |     sessionInfo.put("title", title); |
| | |     Object modelCode = session.getModelCode(); |
| | |     sessionInfo.put("modelCode", modelCode); |
| | |     return sessionInfo; |
| | | } |
| | | |
| | | /** |
| | |  * Builds the payload of a "delta" SSE event. |
| | |  * |
| | |  * @param session session the chunk belongs to |
| | |  * @param content text chunk to forward |
| | |  * @return insertion-ordered map with keys sessionId, modelCode, content and timestamp |
| | |  *         (timestamp is the current time in epoch milliseconds) |
| | |  */ |
| | | private Map<String, Object> buildDeltaPayload(AiChatSession session, String content) { |
| | |     Map<String, Object> delta = new LinkedHashMap<>(); |
| | |     Object sessionId = session.getId(); |
| | |     delta.put("sessionId", sessionId); |
| | |     Object modelCode = session.getModelCode(); |
| | |     delta.put("modelCode", modelCode); |
| | |     delta.put("content", content); |
| | |     long emittedAt = new Date().getTime(); |
| | |     delta.put("timestamp", emittedAt); |
| | |     return delta; |
| | | } |
| | | |
| | | /** |
| | |  * Builds the payload of the "done" SSE event. |
| | |  * |
| | |  * @param session session the stream belonged to |
| | |  * @param stopped value stored under the "stopped" key |
| | |  * @return insertion-ordered map with keys sessionId, modelCode and stopped |
| | |  */ |
| | | private Map<String, Object> buildDonePayload(AiChatSession session, boolean stopped) { |
| | |     Map<String, Object> payload = new LinkedHashMap<>(); |
| | |     payload.put("sessionId", session.getId()); |
| | |     payload.put("modelCode", session.getModelCode()); |
| | |     payload.put("stopped", stopped); |
| | |     // Single exit: a second return referring to an undeclared variable used to follow this one. |
| | |     return payload; |
| | | } |
| | | |
| | | private Map<String, Object> buildErrorPayload(String message) { |
| | |
| | | } |
| | | emitter.complete(); |
| | | } |
| | | |
| | | } |
| | | |