#AI
zhou zhou
9 小时以前 51877df13075ad10ef51107f15bcd21f1661febe
rsf-server/src/main/java/com/vincent/rsf/server/ai/controller/AiController.java
@@ -2,27 +2,32 @@
import com.fasterxml.jackson.databind.JsonNode;
import com.vincent.rsf.framework.common.R;
import com.vincent.rsf.server.ai.config.AiProperties;
import com.vincent.rsf.server.ai.constant.AiSceneCode;
import com.vincent.rsf.server.ai.dto.AiChatStreamRequest;
import com.vincent.rsf.server.ai.dto.AiSessionCreateRequest;
import com.vincent.rsf.server.ai.dto.AiSessionRenameRequest;
import com.vincent.rsf.server.ai.dto.GatewayChatMessage;
import com.vincent.rsf.server.ai.dto.GatewayChatRequest;
import com.vincent.rsf.server.ai.model.AiChatMessage;
import com.vincent.rsf.server.ai.model.AiChatSession;
import com.vincent.rsf.server.ai.model.AiPromptContext;
import com.vincent.rsf.server.ai.service.AiGatewayClient;
import com.vincent.rsf.server.ai.service.AiPromptContextService;
import com.vincent.rsf.server.ai.service.diagnosis.AiChatStreamOrchestrator;
import com.vincent.rsf.server.ai.service.diagnosis.AiDiagnosisRuntimeService;
import com.vincent.rsf.server.ai.service.AiModelRouteRuntimeService;
import com.vincent.rsf.server.ai.service.AiRuntimeConfigService;
import com.vincent.rsf.server.ai.service.AiSessionService;
import com.vincent.rsf.server.system.controller.BaseController;
import com.vincent.rsf.server.system.entity.AiDiagnosisRecord;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import javax.annotation.Resource;
import java.io.IOException;
import java.util.Date;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@@ -34,17 +39,17 @@
    @Resource
    private AiSessionService aiSessionService;
    @Resource
    private AiProperties aiProperties;
    @Resource
    private AiGatewayClient aiGatewayClient;
    @Resource
    private AiRuntimeConfigService aiRuntimeConfigService;
    @Resource
    private AiPromptContextService aiPromptContextService;
    private AiModelRouteRuntimeService aiModelRouteRuntimeService;
    @Resource
    private AiDiagnosisRuntimeService aiDiagnosisRuntimeService;
    @Resource
    private AiChatStreamOrchestrator aiChatStreamOrchestrator;
    @GetMapping("/model/list")
    public R modelList() {
        List<Map<String, Object>> models = new java.util.ArrayList<>();
        List<Map<String, Object>> models = new ArrayList<>();
        for (AiRuntimeConfigService.ModelRuntimeConfig model : aiRuntimeConfigService.listEnabledModels()) {
            Map<String, Object> item = new LinkedHashMap<>();
            item.put("code", model.getCode());
@@ -99,6 +104,20 @@
    @PostMapping(value = "/chat/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public SseEmitter chatStream(@RequestBody AiChatStreamRequest request) {
        return doChatStream(normalizeRequest(request));
    }
    @PostMapping(value = "/diagnose/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public SseEmitter diagnoseStream(@RequestBody(required = false) AiChatStreamRequest request) {
        AiChatStreamRequest diagnosisRequest = normalizeRequest(request);
        diagnosisRequest.setSceneCode(AiSceneCode.SYSTEM_DIAGNOSE);
        if (diagnosisRequest.getMessage() == null || diagnosisRequest.getMessage().trim().isEmpty()) {
            diagnosisRequest.setMessage("请对当前WMS系统进行一次巡检诊断,结合库存、任务、设备站点数据识别异常并给出处理建议。");
        }
        return doChatStream(diagnosisRequest);
    }
    private SseEmitter doChatStream(AiChatStreamRequest request) {
        SseEmitter emitter = new SseEmitter(0L);
        Long tenantId = getTenantId();
        Long userId = getLoginUserId();
@@ -110,137 +129,79 @@
            completeWithError(emitter, "消息内容不能为空");
            return emitter;
        }
        AiChatSession session = aiSessionService.ensureSession(tenantId, userId, request.getSessionId(), request.getModelCode());
        aiSessionService.clearStopFlag(session.getId());
        aiSessionService.appendMessage(tenantId, userId, session.getId(), "user", request.getMessage(), session.getModelCode());
        AiRuntimeConfigService.ModelRuntimeConfig modelRuntimeConfig = aiRuntimeConfigService.resolveModel(session.getModelCode());
        AiPromptContext promptContext = new AiPromptContext()
                .setTenantId(tenantId)
                .setUserId(userId)
                .setSessionId(session.getId())
                .setModelCode(session.getModelCode())
                .setQuestion(request.getMessage())
                .setSceneCode(request.getSceneCode());
        List<AiModelRouteRuntimeService.RouteCandidate> candidates = aiModelRouteRuntimeService.resolveCandidates(
                tenantId,
                request.getSceneCode(),
                session.getModelCode()
        );
        if (candidates.isEmpty()) {
            completeWithError(emitter, "未找到可用的AI模型配置");
            return emitter;
        }
        int maxContextMessages = resolveContextSize(candidates);
        List<AiChatMessage> contextMessages = aiSessionService.listContextMessages(
                tenantId,
                userId,
                session.getId(),
                modelRuntimeConfig.getMaxContextMessages()
                maxContextMessages
        );
        AiDiagnosisRecord diagnosisRecord = AiSceneCode.SYSTEM_DIAGNOSE.equals(request.getSceneCode())
                ? aiDiagnosisRuntimeService.startDiagnosis(tenantId, userId, session.getId(), request.getSceneCode(), request.getMessage())
                : null;
        Thread thread = new Thread(() -> {
            StringBuilder assistantReply = new StringBuilder();
            boolean doneSent = false;
            try {
                emitter.send(SseEmitter.event().name("session").data(buildSessionPayload(session), MediaType.APPLICATION_JSON));
                GatewayChatRequest gatewayChatRequest = buildGatewayRequest(
                        tenantId,
                        userId,
                        session,
                        contextMessages,
                        modelRuntimeConfig,
                        request.getMessage()
                );
                aiGatewayClient.stream(gatewayChatRequest, event -> handleGatewayEvent(
                        emitter,
                        event,
                        session,
                        assistantReply
                ));
                if (aiSessionService.isStopRequested(session.getId())) {
                    if (assistantReply.length() > 0) {
                        aiSessionService.appendMessage(tenantId, userId, session.getId(), "assistant", assistantReply.toString(), session.getModelCode());
                    }
                    emitter.send(SseEmitter.event().name("done").data(buildDonePayload(session, true), MediaType.APPLICATION_JSON));
                    doneSent = true;
                }
            } catch (Exception e) {
                try {
                    emitter.send(SseEmitter.event().name("error").data(buildErrorPayload(e.getMessage()), MediaType.APPLICATION_JSON));
                } catch (IOException ignore) {
                }
            } finally {
                if (!doneSent && assistantReply.length() > 0) {
                    aiSessionService.appendMessage(tenantId, userId, session.getId(), "assistant", assistantReply.toString(), session.getModelCode());
                    try {
                        emitter.send(SseEmitter.event().name("done").data(buildDonePayload(session, false), MediaType.APPLICATION_JSON));
                    } catch (IOException ignore) {
                    }
                }
                emitter.complete();
                aiSessionService.clearStopFlag(session.getId());
            }
        }, "ai-chat-stream-" + session.getId());
        Thread thread = new Thread(() -> executeStream(
                emitter, tenantId, userId, session, request, promptContext, contextMessages, diagnosisRecord, candidates
        ), "ai-chat-stream-" + session.getId());
        thread.setDaemon(true);
        thread.start();
        return emitter;
    }
    private boolean handleGatewayEvent(SseEmitter emitter, JsonNode event, AiChatSession session,
                                       StringBuilder assistantReply) throws Exception {
        if (aiSessionService.isStopRequested(session.getId())) {
            return false;
        }
        String type = event.path("type").asText();
        if ("delta".equals(type)) {
            String content = event.path("content").asText("");
            assistantReply.append(content);
            emitter.send(SseEmitter.event().name("delta").data(buildDeltaPayload(session, content), MediaType.APPLICATION_JSON));
            return true;
        }
        if ("error".equals(type)) {
            emitter.send(SseEmitter.event().name("error").data(buildErrorPayload(event.path("message").asText("模型调用失败")), MediaType.APPLICATION_JSON));
            return false;
        }
        if ("done".equals(type)) {
            return false;
        }
        return true;
    private void executeStream(SseEmitter emitter,
                               Long tenantId,
                               Long userId,
                               AiChatSession session,
                               AiChatStreamRequest request,
                               AiPromptContext promptContext,
                               List<AiChatMessage> contextMessages,
                               AiDiagnosisRecord diagnosisRecord,
                               List<AiModelRouteRuntimeService.RouteCandidate> candidates) {
        aiChatStreamOrchestrator.executeStream(
                emitter,
                tenantId,
                userId,
                session,
                request,
                promptContext,
                contextMessages,
                diagnosisRecord,
                candidates
        );
    }
    private GatewayChatRequest buildGatewayRequest(Long tenantId, Long userId, AiChatSession session, List<AiChatMessage> contextMessages,
                                                   AiRuntimeConfigService.ModelRuntimeConfig modelRuntimeConfig,
                                                   String latestQuestion) {
        GatewayChatRequest request = new GatewayChatRequest();
        request.setSessionId(session.getId());
        request.setModelCode(session.getModelCode());
        request.setSystemPrompt(aiPromptContextService.buildSystemPrompt(
                modelRuntimeConfig.getSystemPrompt(),
                new AiPromptContext()
                        .setTenantId(tenantId)
                        .setUserId(userId)
                        .setSessionId(session.getId())
                        .setModelCode(session.getModelCode())
                        .setQuestion(latestQuestion)
        ));
        request.setChatUrl(modelRuntimeConfig.getChatUrl());
        request.setApiKey(modelRuntimeConfig.getApiKey());
        request.setModelName(modelRuntimeConfig.getModelName());
        for (AiChatMessage contextMessage : contextMessages) {
            GatewayChatMessage item = new GatewayChatMessage();
            item.setRole(contextMessage.getRole());
            item.setContent(contextMessage.getContent());
            request.getMessages().add(item);
    private int resolveContextSize(List<AiModelRouteRuntimeService.RouteCandidate> candidates) {
        return aiChatStreamOrchestrator.resolveContextSize(candidates);
    }
    private AiChatStreamRequest normalizeRequest(AiChatStreamRequest request) {
        AiChatStreamRequest normalized = request == null ? new AiChatStreamRequest() : request;
        if (normalized.getSceneCode() == null || normalized.getSceneCode().trim().isEmpty()) {
            normalized.setSceneCode(AiSceneCode.GENERAL_CHAT);
        }
        return request;
    }
    private Map<String, Object> buildSessionPayload(AiChatSession session) {
        Map<String, Object> payload = new LinkedHashMap<>();
        payload.put("sessionId", session.getId());
        payload.put("title", session.getTitle());
        payload.put("modelCode", session.getModelCode());
        return payload;
    }
    private Map<String, Object> buildDeltaPayload(AiChatSession session, String content) {
        Map<String, Object> payload = new LinkedHashMap<>();
        payload.put("sessionId", session.getId());
        payload.put("modelCode", session.getModelCode());
        payload.put("content", content);
        payload.put("timestamp", new Date().getTime());
        return payload;
    }
    private Map<String, Object> buildDonePayload(AiChatSession session, boolean stopped) {
        Map<String, Object> payload = new LinkedHashMap<>();
        payload.put("sessionId", session.getId());
        payload.put("modelCode", session.getModelCode());
        payload.put("stopped", stopped);
        return payload;
        return normalized;
    }
    private Map<String, Object> buildErrorPayload(String message) {
@@ -256,5 +217,5 @@
        }
        emitter.complete();
    }
}