package com.vincent.rsf.ai.gateway.controller;
|
|
import com.vincent.rsf.ai.gateway.dto.GatewayChatRequest;
|
import com.vincent.rsf.ai.gateway.service.AiGatewayService;
|
import com.vincent.rsf.ai.gateway.service.GatewayStreamEvent;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
import org.slf4j.Logger;
|
import org.slf4j.LoggerFactory;
|
import org.springframework.http.MediaType;
|
import org.springframework.web.bind.annotation.PostMapping;
|
import org.springframework.web.bind.annotation.RequestBody;
|
import org.springframework.web.bind.annotation.RequestMapping;
|
import org.springframework.web.bind.annotation.RestController;
|
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;
|
|
import javax.annotation.Resource;
|
import java.io.IOException;
|
import java.io.InterruptedIOException;
|
import java.nio.charset.StandardCharsets;
|
import java.util.concurrent.atomic.AtomicBoolean;
|
|
@RestController
|
@RequestMapping("/internal/chat")
|
public class AiGatewayController {
|
|
private static final Logger logger = LoggerFactory.getLogger(AiGatewayController.class);
|
|
@Resource
|
private AiGatewayService aiGatewayService;
|
@Resource
|
private ObjectMapper objectMapper;
|
|
@PostMapping(value = "/stream", produces = "application/x-ndjson")
|
public StreamingResponseBody stream(@RequestBody GatewayChatRequest request) {
|
return outputStream -> {
|
logger.info("AI gateway controller stream opened: sessionId={}, routeCode={}, attemptNo={}, modelCode={}",
|
request.getSessionId(),
|
request.getRouteCode(),
|
request.getAttemptNo(),
|
request.getModelCode());
|
AtomicBoolean streaming = new AtomicBoolean(true);
|
Object writeLock = new Object();
|
Thread heartbeatThread = new Thread(() -> {
|
while (streaming.get()) {
|
try {
|
Thread.sleep(10000L);
|
if (!streaming.get()) {
|
break;
|
}
|
String json = objectMapper.writeValueAsString(new GatewayStreamEvent()
|
.setType("ping")
|
.setModelCode(request.getModelCode())
|
.setResponseTime(System.currentTimeMillis())) + "\n";
|
synchronized (writeLock) {
|
outputStream.write(json.getBytes(StandardCharsets.UTF_8));
|
outputStream.flush();
|
}
|
} catch (InterruptedException e) {
|
Thread.currentThread().interrupt();
|
logger.info("AI gateway heartbeat interrupted: sessionId={}, routeCode={}, attemptNo={}, modelCode={}",
|
request.getSessionId(),
|
request.getRouteCode(),
|
request.getAttemptNo(),
|
request.getModelCode());
|
break;
|
} catch (Exception e) {
|
logger.warn("AI gateway heartbeat write failed: sessionId={}, routeCode={}, attemptNo={}, modelCode={}, message={}",
|
request.getSessionId(),
|
request.getRouteCode(),
|
request.getAttemptNo(),
|
request.getModelCode(),
|
e.getMessage());
|
break;
|
}
|
}
|
}, "ai-gateway-heartbeat-" + (request.getSessionId() == null ? "unknown" : request.getSessionId()));
|
heartbeatThread.setDaemon(true);
|
heartbeatThread.start();
|
try {
|
aiGatewayService.stream(request, event -> {
|
String json = objectMapper.writeValueAsString(event) + "\n";
|
synchronized (writeLock) {
|
outputStream.write(json.getBytes(StandardCharsets.UTF_8));
|
outputStream.flush();
|
}
|
});
|
} catch (Exception e) {
|
if (isInterruptedError(e)) {
|
logger.warn("AI gateway controller stream interrupted: sessionId={}, routeCode={}, attemptNo={}, modelCode={}, message={}",
|
request.getSessionId(),
|
request.getRouteCode(),
|
request.getAttemptNo(),
|
request.getModelCode(),
|
e.getMessage());
|
return;
|
}
|
logger.error("AI gateway controller stream failed: sessionId={}, routeCode={}, attemptNo={}, modelCode={}, message={}",
|
request.getSessionId(),
|
request.getRouteCode(),
|
request.getAttemptNo(),
|
request.getModelCode(),
|
e.getMessage(),
|
e);
|
throw new IOException(e);
|
} finally {
|
streaming.set(false);
|
heartbeatThread.interrupt();
|
logger.info("AI gateway controller stream closed: sessionId={}, routeCode={}, attemptNo={}, modelCode={}",
|
request.getSessionId(),
|
request.getRouteCode(),
|
request.getAttemptNo(),
|
request.getModelCode());
|
}
|
};
|
}
|
|
private boolean isInterruptedError(Throwable throwable) {
|
Throwable current = throwable;
|
while (current != null) {
|
if (current instanceof InterruptedException || current instanceof InterruptedIOException) {
|
return true;
|
}
|
String message = current.getMessage();
|
if (message != null) {
|
String normalized = message.toLowerCase();
|
if (normalized.contains("interrupted")
|
|| normalized.contains("broken pipe")
|
|| normalized.contains("connection reset")
|
|| normalized.contains("forcibly closed")) {
|
return true;
|
}
|
}
|
current = current.getCause();
|
}
|
return false;
|
}
|
|
}
|