package com.zy.asrs.wms.common.security;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.zy.asrs.framework.common.R;
import com.zy.asrs.wms.common.annotation.CacheData;
import com.zy.asrs.wms.common.constant.Constants;
import com.zy.asrs.wms.common.constant.RedisConstants;
import com.zy.asrs.wms.common.domain.CacheHitDto;
import com.zy.asrs.wms.system.entity.User;
import com.zy.asrs.wms.system.entity.UserLogin;
import com.zy.asrs.wms.system.service.UserService;
import com.zy.asrs.wms.utils.EncryptUtils;
import com.zy.asrs.wms.utils.HttpUtils;
import com.zy.asrs.wms.utils.RedisUtil;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerMapping;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.text.SimpleDateFormat;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;

/**
 * Servlet filter that serves responses of {@link CacheData}-annotated handler methods from Redis.
 *
 * <p>Caching applies only when the {@code system.enableCache} property is {@code true} AND the
 * handler's {@code @CacheData#cache()} is {@code true}. The cache key combines the annotation's
 * table names, the request URI, an MD5 of the request parameters/body, and one role id of the
 * current user (or {@code 0} for an anonymous request). On a hit, the stored body is written back
 * directly. On a miss, the request runs through the chain, and a JSON response whose {@code code}
 * is {@code 200} is stored under every role key for 24 hours. Daily hit/miss counters are kept
 * per request URI and per table name.
 *
 * <p>Any exception raised while handling a request (handler lookup, Redis, or the downstream
 * chain) is reported to the client through {@link HttpUtils#responseError} with the
 * bad-credentials code/message.
 */
@Component
public class CacheFilter extends OncePerRequestFilter {

    /** Time-to-live of a cached response, in seconds (24 hours). */
    private static final int CACHE_TTL_SECONDS = 60 * 60 * 24;

    /** Content type used when replaying a cached body. */
    private static final String CACHED_CONTENT_TYPE = "application/json;charset=UTF-8";

    @Value("${system.enableCache}")
    private Boolean enableCache;

    @Autowired
    private RequestMappingHandlerMapping handlerMapping;

    @Autowired
    private RedisUtil redisUtil;

    @Autowired
    private UserService userService;

    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
            throws ServletException, IOException {
        try {
            CacheData cacheData = resolveCacheData(request);
            // Boolean.TRUE.equals avoids an unboxing NPE if the property resolves to null.
            if (cacheData != null && Boolean.TRUE.equals(enableCache) && cacheData.cache()) {
                serveWithCache(request, response, chain, cacheData);
                return;
            }
            chain.doFilter(request, response);
        } catch (Exception e) {
            logger.error("CacheFilter failed for " + request.getRequestURI(), e);
            HttpUtils.responseError(response, Constants.BAD_CREDENTIALS_CODE, Constants.BAD_CREDENTIALS_MSG, e.toString());
        }
    }

    /**
     * Looks up the {@link CacheData} annotation of the handler method mapped to {@code request}.
     *
     * @return the annotation, or {@code null} when there is no handler, the handler is not a
     *         {@link HandlerMethod}, or the method is not annotated
     */
    private CacheData resolveCacheData(HttpServletRequest request) throws Exception {
        HandlerExecutionChain handlerChain = handlerMapping.getHandler(request);
        if (handlerChain == null) {
            return null;
        }
        Object handler = handlerChain.getHandler();
        if (!(handler instanceof HandlerMethod)) {
            return null;
        }
        Method method = ((HandlerMethod) handler).getMethod();
        return method.getAnnotation(CacheData.class);
    }

    /**
     * Answers a cacheable request from Redis or, on a miss, by running the chain and storing a
     * successful result. Hit/miss statistics are recorded afterwards on a best-effort basis.
     */
    private void serveWithCache(HttpServletRequest request, HttpServletResponse response, FilterChain chain,
                                CacheData cacheData) throws Exception {
        // Wrappers let the body be read here (for the key) and again by the handler, and capture the output.
        CachedBodyHttpServletRequest cachedRequest = new CachedBodyHttpServletRequest(request);
        CachedBodyHttpServletResponse cachedResponse = new CachedBodyHttpServletResponse(response);

        String requestURI = request.getRequestURI();
        String requestParamCode = getRequestParamCode(cachedRequest);
        List<Long> roleIds = resolveRoleIds(getUser());

        // NOTE(review): a response produced for a user holding several roles is stored under each
        // of those role keys, and a later user holding only one of them receives it on a hit. If the
        // handler's output depends on the full role set, this can expose data across roles;
        // consider keying on the sorted role set instead (requires matching eviction logic).
        Object cached = null;
        for (Long roleId : roleIds) {
            cached = redisUtil.get(cacheKey(cacheData, requestURI, requestParamCode, roleId));
            if (cached != null) {
                break;
            }
        }

        if (cached == null) {
            chain.doFilter(cachedRequest, cachedResponse);
            byte[] responseContent = cachedResponse.getContent();
            String responseBody = new String(responseContent, StandardCharsets.UTF_8);
            if (isSuccessResult(responseBody)) {
                for (Long roleId : roleIds) {
                    redisUtil.set(cacheKey(cacheData, requestURI, requestParamCode, roleId), responseBody, CACHE_TTL_SECONDS);
                }
            }
            writeBody(response, responseContent);
        } else {
            response.setContentType(CACHED_CONTENT_TYPE);
            // Encode explicitly as UTF-8 to match the declared charset (not the platform default).
            writeBody(response, cached.toString().getBytes(StandardCharsets.UTF_8));
        }

        try {
            statisticsCacheHitCount(cached, cacheData.tableName(), requestURI);
        } catch (RuntimeException e) {
            // The body is already written; a statistics failure must not turn it into an error response.
            logger.warn("Failed to record cache statistics for " + requestURI, e);
        }
    }

    /** Builds the Redis key of a cached response for one role. */
    private static String cacheKey(CacheData cacheData, String requestURI, String requestParamCode, Long roleId) {
        return RedisConstants.getCacheKey(RedisConstants.CACHE_DATA, cacheData.tableName(), requestURI, requestParamCode, roleId);
    }

    /** Returns the role ids to key the cache on: {@code [0]} for an anonymous request, else the user's role ids. */
    private static List<Long> resolveRoleIds(User user) {
        List<Long> roleIds = new ArrayList<>();
        if (user == null) {
            roleIds.add(0L);
        } else {
            roleIds.addAll(Arrays.asList(user.getUserRoleIds()));
        }
        return roleIds;
    }

    /**
     * Tells whether a downstream response body is a JSON object whose {@code code} field is 200.
     * Non-JSON, non-object, or code-less bodies are simply not cached rather than failing the request.
     */
    private static boolean isSuccessResult(String responseBody) {
        if (responseBody == null || responseBody.trim().isEmpty()) {
            return false;
        }
        try {
            JSONObject result = JSON.parseObject(responseBody);
            Object code = result == null ? null : result.get("code");
            return code != null && Integer.parseInt(code.toString()) == 200;
        } catch (RuntimeException e) {
            return false;
        }
    }

    /** Writes {@code content} to the real response, setting Content-Length before any byte is committed. */
    private static void writeBody(HttpServletResponse response, byte[] content) throws IOException {
        response.setContentLength(content.length);
        response.getOutputStream().write(content);
    }

    /**
     * Computes the MD5 "parameter code" that distinguishes requests to the same URI.
     *
     * <ul>
     *   <li>GET and form-encoded / multipart requests: MD5 of the parameter map as JSON, with
     *       entries sorted by name so that parameter order does not split the cache.</li>
     *   <li>Any other request: MD5 of the body (lines joined with the platform line separator),
     *       prefixed with the query string when one is present.</li>
     * </ul>
     */
    private String getRequestParamCode(CachedBodyHttpServletRequest request) throws IOException {
        String requestMethod = request.getMethod();
        String contentType = request.getContentType();
        boolean formData = contentType != null
                && (contentType.startsWith("application/x-www-form-urlencoded")
                || contentType.startsWith("multipart/form-data"));

        if ("GET".equalsIgnoreCase(requestMethod) || formData) {
            Map<String, String[]> parameters = new TreeMap<>(request.getParameterMap());
            return EncryptUtils.md5(JSON.toJSONString(parameters));
        }

        String body = request.getReader().lines().collect(Collectors.joining(System.lineSeparator()));
        String queryString = request.getQueryString();
        // Query parameters select data too; without a query string the code stays md5(body).
        String payload = (queryString == null || queryString.isEmpty()) ? body : queryString + "\n" + body;
        return EncryptUtils.md5(payload);
    }

    /**
     * Resolves the current user from the Spring Security principal.
     *
     * @return the principal itself when it is a {@link User}, the user loaded via
     *         {@code userService.superGetById} when it is a {@link UserLogin}, otherwise
     *         {@code null} (the request is then cached under role id 0)
     */
    private User getUser() {
        try {
            Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
            if (authentication == null) {
                return null;
            }
            Object principal = authentication.getPrincipal();
            if (principal instanceof User) {
                return (User) principal;
            }
            if (principal instanceof UserLogin) {
                UserLogin userLogin = (UserLogin) principal;
                return userService.superGetById(userLogin.getUserId());
            }
        } catch (Exception e) {
            logger.warn("Unable to resolve current user for cache lookup: " + e.getMessage(), e);
        }
        return null;
    }

    /** Records one hit (cached value non-null) or miss for the request URI and for every table name. */
    private void statisticsCacheHitCount(Object object, String[] tableNames, String requestURI) {
        statisticsCacheSaveRedis(object, requestURI);
        for (String tableName : tableNames) {
            statisticsCacheSaveRedis(object, tableName);
        }
    }

    /**
     * Increments today's hit or miss counter stored at {@code STATISTICS_CACHE_DATA:yyyy-MM-dd:key}.
     *
     * <p>NOTE(review): this is a non-atomic read-modify-write, so concurrent requests can lose
     * increments; the stored key is also written without a TTL.
     */
    private void statisticsCacheSaveRedis(Object object, String key) {
        // LocalDate.toString() yields ISO yyyy-MM-dd in the default time zone.
        String statKey = RedisConstants.STATISTICS_CACHE_DATA + ":" + LocalDate.now() + ":" + key;
        Object stored = redisUtil.get(statKey);

        CacheHitDto cacheHitDto = null;
        if (stored != null) {
            cacheHitDto = JSON.parseObject(stored.toString(), CacheHitDto.class);
        }
        if (cacheHitDto == null) {
            cacheHitDto = new CacheHitDto(0, 0);
        }

        if (object == null) {
            cacheHitDto.setMiss(cacheHitDto.getMiss() + 1);
        } else {
            cacheHitDto.setHit(cacheHitDto.getHit() + 1);
        }
        redisUtil.set(statKey, JSON.toJSONString(cacheHitDto));
    }
}