package com.vincent.rsf.ai.gateway.controller; import com.vincent.rsf.ai.gateway.dto.GatewayChatRequest; import com.vincent.rsf.ai.gateway.service.AiGatewayService; import com.vincent.rsf.ai.gateway.service.GatewayStreamEvent; import com.fasterxml.jackson.databind.ObjectMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.http.MediaType; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody; import javax.annotation.Resource; import java.io.IOException; import java.io.InterruptedIOException; import java.nio.charset.StandardCharsets; import java.util.concurrent.atomic.AtomicBoolean; @RestController @RequestMapping("/internal/chat") public class AiGatewayController { private static final Logger logger = LoggerFactory.getLogger(AiGatewayController.class); @Resource private AiGatewayService aiGatewayService; @Resource private ObjectMapper objectMapper; @PostMapping(value = "/stream", produces = "application/x-ndjson") public StreamingResponseBody stream(@RequestBody GatewayChatRequest request) { return outputStream -> { logger.info("AI gateway controller stream opened: sessionId={}, routeCode={}, attemptNo={}, modelCode={}", request.getSessionId(), request.getRouteCode(), request.getAttemptNo(), request.getModelCode()); AtomicBoolean streaming = new AtomicBoolean(true); Object writeLock = new Object(); Thread heartbeatThread = new Thread(() -> { while (streaming.get()) { try { Thread.sleep(10000L); if (!streaming.get()) { break; } String json = objectMapper.writeValueAsString(new GatewayStreamEvent() .setType("ping") .setModelCode(request.getModelCode()) .setResponseTime(System.currentTimeMillis())) + "\n"; synchronized (writeLock) { outputStream.write(json.getBytes(StandardCharsets.UTF_8)); outputStream.flush(); } } catch (InterruptedException e) { Thread.currentThread().interrupt(); logger.info("AI gateway heartbeat interrupted: sessionId={}, routeCode={}, attemptNo={}, modelCode={}", request.getSessionId(), request.getRouteCode(), request.getAttemptNo(), request.getModelCode()); break; } catch (Exception e) { logger.warn("AI gateway heartbeat write failed: sessionId={}, routeCode={}, attemptNo={}, modelCode={}, message={}", request.getSessionId(), request.getRouteCode(), request.getAttemptNo(), request.getModelCode(), e.getMessage()); break; } } }, "ai-gateway-heartbeat-" + (request.getSessionId() == null ? "unknown" : request.getSessionId())); heartbeatThread.setDaemon(true); heartbeatThread.start(); try { aiGatewayService.stream(request, event -> { String json = objectMapper.writeValueAsString(event) + "\n"; synchronized (writeLock) { outputStream.write(json.getBytes(StandardCharsets.UTF_8)); outputStream.flush(); } }); } catch (Exception e) { if (isInterruptedError(e)) { logger.warn("AI gateway controller stream interrupted: sessionId={}, routeCode={}, attemptNo={}, modelCode={}, message={}", request.getSessionId(), request.getRouteCode(), request.getAttemptNo(), request.getModelCode(), e.getMessage()); return; } logger.error("AI gateway controller stream failed: sessionId={}, routeCode={}, attemptNo={}, modelCode={}, message={}", request.getSessionId(), request.getRouteCode(), request.getAttemptNo(), request.getModelCode(), e.getMessage(), e); throw new IOException(e); } finally { streaming.set(false); heartbeatThread.interrupt(); logger.info("AI gateway controller stream closed: sessionId={}, routeCode={}, attemptNo={}, modelCode={}", request.getSessionId(), request.getRouteCode(), request.getAttemptNo(), request.getModelCode()); } }; } private boolean isInterruptedError(Throwable throwable) { Throwable current = throwable; while (current != null) { if (current instanceof InterruptedException || current instanceof InterruptedIOException) { return true; } String message = current.getMessage(); if (message != null) { String normalized = message.toLowerCase(); if (normalized.contains("interrupted") || normalized.contains("broken pipe") || normalized.contains("connection reset") || normalized.contains("forcibly closed")) { return true; } } current = current.getCause(); } return false; } }