package com.zy.asrs.websocket; import jakarta.websocket.*; import jakarta.servlet.http.HttpServletRequest; import jakarta.websocket.server.HandshakeRequest; import jakarta.websocket.server.ServerEndpoint; import jakarta.websocket.server.ServerEndpointConfig; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.stereotype.Component; import org.apache.tomcat.websocket.server.WsHandshakeRequest; import java.io.IOException; import java.lang.reflect.Field; import java.util.*; import java.util.concurrent.ConcurrentHashMap; @ServerEndpoint(value = "/tv/socket", configurator = TvWebSocketServer.TvConfigurator.class) @Component public class TvWebSocketServer { private static final Logger log = LoggerFactory.getLogger(TvWebSocketServer.class); private static final ConcurrentHashMap SESSIONS = new ConcurrentHashMap<>(); public static class TvConfigurator extends ServerEndpointConfig.Configurator { @Override public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) { String ip = extractIp(request); sec.getUserProperties().put("ip", ip); log.info("电视机WebSocket握手完成, requestUri: {}, clientIp: {}", request.getRequestURI(), ip); } private String extractIp(HandshakeRequest request) { String headerIp = extractIpFromHeaders(request.getHeaders()); if (isValidIp(headerIp)) { return headerIp; } String tomcatRemoteIp = extractIpFromTomcatRequest(request); if (isValidIp(tomcatRemoteIp)) { return tomcatRemoteIp; } log.warn("电视机WebSocket握手未获取到客户端IP, headers: {}", request.getHeaders().keySet()); return "unknown"; } private String extractIpFromHeaders(Map> headers) { if (headers == null || headers.isEmpty()) { return null; } String[] headerNames = { "X-Forwarded-For", "X-Real-IP", "Proxy-Client-IP", "WL-Proxy-Client-IP", "HTTP_X_FORWARDED_FOR", "HTTP_X_REAL_IP" }; for (String headerName : headerNames) { String ip = firstHeaderValue(headers, headerName); if (isValidIp(ip)) { return normalizeIp(ip.split(",")[0].trim()); } } String remoteAddress = firstHeaderValue(headers, "remoteAddress"); if (isValidIp(remoteAddress)) { return normalizeIp(remoteAddress); } return null; } private String extractIpFromTomcatRequest(HandshakeRequest request) { if (!(request instanceof WsHandshakeRequest wsHandshakeRequest)) { return null; } try { Field requestField = WsHandshakeRequest.class.getDeclaredField("request"); requestField.setAccessible(true); HttpServletRequest httpServletRequest = (HttpServletRequest) requestField.get(wsHandshakeRequest); if (httpServletRequest == null) { return null; } return normalizeIp(httpServletRequest.getRemoteAddr()); } catch (Exception e) { log.warn("电视机WebSocket从Tomcat握手请求中提取IP失败: {}", e.getMessage()); return null; } } private String firstHeaderValue(Map> headers, String headerName) { List values = headers.get(headerName); if (values == null || values.isEmpty()) { return null; } return values.get(0); } private boolean isValidIp(String ip) { return ip != null && !ip.isEmpty() && !"unknown".equalsIgnoreCase(ip); } private String normalizeIp(String ip) { if (ip == null) { return null; } String normalized = ip.trim(); if (normalized.startsWith("/")) { normalized = normalized.substring(1); } if (normalized.startsWith("::ffff:")) { normalized = normalized.substring(7); } if (normalized.startsWith("[") && normalized.contains("]")) { normalized = normalized.substring(1, normalized.indexOf(']')); } else if (normalized.chars().filter(ch -> ch == ':').count() == 1) { int colonIdx = normalized.lastIndexOf(':'); if (colonIdx > 0) { normalized = normalized.substring(0, colonIdx); } } return normalized; } } @OnOpen public void onOpen(Session session) { String ip = getIp(session); SESSIONS.put(ip, session); log.info("电视机WebSocket连接建立, IP: {}, 当前在线数: {}", ip, SESSIONS.size()); } @OnClose public void onClose(Session session) { String ip = getIp(session); SESSIONS.remove(ip); log.info("电视机WebSocket连接关闭, IP: {}, 当前在线数: {}", ip, SESSIONS.size()); } @OnError public void onError(Session session, Throwable error) { String ip = getIp(session); SESSIONS.remove(ip); log.warn("电视机WebSocket传输异常, IP: {}, error: {}", ip, error.getMessage()); } @OnMessage public void onMessage(String message, Session session) { // 电视机端无需发送消息,忽略 } public void sendMessageToDevice(String ip, String message) { Session session = SESSIONS.get(ip); if (session != null && session.isOpen()) { try { session.getBasicRemote().sendText(message); } catch (IOException e) { log.error("推送消息到设备 {} 失败: {}", ip, e.getMessage()); } } } public void sendMessageToAll(String message) { for (Map.Entry entry : SESSIONS.entrySet()) { Session session = entry.getValue(); if (session.isOpen()) { try { session.getBasicRemote().sendText(message); } catch (IOException e) { log.error("广播消息到设备 {} 失败: {}", entry.getKey(), e.getMessage()); } } } } public Set getOnlineIps() { return SESSIONS.keySet(); } private String getIp(Session session) { Object ip = session.getUserProperties().get("ip"); return ip != null ? ip.toString() : "unknown"; } }