package com.zy.asrs.wms.common.security;
|
|
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.zy.asrs.framework.common.R;
import com.zy.asrs.wms.common.annotation.CacheData;
import com.zy.asrs.wms.common.constant.Constants;
import com.zy.asrs.wms.common.constant.RedisConstants;
import com.zy.asrs.wms.common.domain.CacheHitDto;
import com.zy.asrs.wms.system.entity.User;
import com.zy.asrs.wms.system.entity.UserLogin;
import com.zy.asrs.wms.system.service.UserService;
import com.zy.asrs.wms.utils.EncryptUtils;
import com.zy.asrs.wms.utils.HttpUtils;
import com.zy.asrs.wms.utils.RedisUtil;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerMapping;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import java.io.IOException;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.text.SimpleDateFormat;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
|
|
@Component
|
public class CacheFilter extends OncePerRequestFilter {
|
|
@Value("${system.enableCache}")
|
private Boolean enableCache;
|
@Autowired
|
private RequestMappingHandlerMapping handlerMapping;
|
@Autowired
|
private RedisUtil redisUtil;
|
@Autowired
|
private UserService userService;
|
|
@Override
|
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws ServletException, IOException {
|
// 获取当前请求的处理器方法
|
HandlerExecutionChain handlerChain;
|
try {
|
handlerChain = handlerMapping.getHandler(request);
|
if (handlerChain != null) {
|
Object handler = handlerChain.getHandler();
|
if (handler instanceof HandlerMethod) {
|
HandlerMethod handlerMethod = (HandlerMethod) handler;
|
Method method = handlerMethod.getMethod();
|
if (method.isAnnotationPresent(CacheData.class)) {
|
CacheData cacheData = method.getAnnotation(CacheData.class);
|
if (enableCache && cacheData.cache()) {
|
// 创建一个包装请求体的 HttpServletRequestWrapper
|
CachedBodyHttpServletRequest cachedBodyHttpServletRequest = new CachedBodyHttpServletRequest(request);
|
// 创建一个包装响应体的 HttpServletResponseWrapper
|
CachedBodyHttpServletResponse cachedBodyHttpServletResponse = new CachedBodyHttpServletResponse(response);
|
String requestParamCode = getRequestParamCode(cachedBodyHttpServletRequest);
|
|
User user = getUser();
|
ArrayList<Long> roleIds = new ArrayList<>();
|
if (user == null) {
|
roleIds.add(0L);
|
}else {
|
roleIds.addAll(Arrays.asList(user.getUserRoleIds()));
|
}
|
|
Object object = null;
|
for (Long roleId : roleIds) {
|
Object obj = redisUtil.get(RedisConstants.getCacheKey(RedisConstants.CACHE_DATA, cacheData.tableName(), request.getRequestURI(), requestParamCode, roleId));
|
if(obj != null){
|
object = obj;
|
break;
|
}
|
}
|
|
if (object == null) {
|
chain.doFilter(cachedBodyHttpServletRequest, cachedBodyHttpServletResponse);
|
|
// 获取响应内容
|
byte[] responseContent = cachedBodyHttpServletResponse.getContent();
|
String responseBody = new String(responseContent);
|
|
JSONObject result = JSON.parseObject(responseBody);
|
if (Integer.parseInt(result.get("code").toString()) == 200) {
|
for (Long roleId : roleIds) {
|
redisUtil.set(RedisConstants.getCacheKey(RedisConstants.CACHE_DATA, cacheData.tableName(), request.getRequestURI(), requestParamCode, roleId), responseBody, 60 * 60 * 24);
|
}
|
}
|
|
// 将响应内容写回原始的 HttpServletResponse
|
response.getOutputStream().write(responseContent);
|
response.setContentLength(responseContent.length);
|
}else {
|
// 将响应内容写回原始的 HttpServletResponse
|
byte[] responseContent = object.toString().getBytes();
|
response.setContentType("application/json;charset=UTF-8");
|
response.getOutputStream().write(responseContent);
|
response.setContentLength(responseContent.length);
|
}
|
|
statisticsCacheHitCount(object, cacheData.tableName(), request.getRequestURI());
|
return;
|
}
|
}
|
}
|
}
|
|
chain.doFilter(request, response);
|
} catch (Exception e) {
|
e.printStackTrace();
|
HttpUtils.responseError(response, Constants.BAD_CREDENTIALS_CODE, Constants.BAD_CREDENTIALS_MSG,
|
e.toString());
|
return;
|
}
|
}
|
|
private String getRequestParamCode(CachedBodyHttpServletRequest request) throws IOException {
|
// 获取请求方法
|
String requestMethod = request.getMethod();
|
String md5 = "";
|
// 检查请求方法并处理
|
if ("POST".equalsIgnoreCase(requestMethod)) {
|
// 检查是否为 form-data 类型
|
String contentType = request.getContentType();
|
if (contentType != null && (contentType.startsWith("application/x-www-form-urlencoded") || contentType.startsWith("multipart/form-data"))) {
|
// 处理 form-data 参数
|
Map<String, String[]> parameterMap = request.getParameterMap();
|
String jsonString = JSON.toJSONString(parameterMap);
|
md5 = EncryptUtils.md5(jsonString);
|
} else {
|
// 读取请求体中的 JSON 数据
|
String jsonRequestBody = request.getReader().lines().collect(Collectors.joining(System.lineSeparator()));
|
md5 = EncryptUtils.md5(jsonRequestBody);
|
}
|
} else if ("GET".equalsIgnoreCase(requestMethod)) {
|
Map<String, String[]> map = request.getParameterMap();
|
String jsonString = JSON.toJSONString(map);
|
md5 = EncryptUtils.md5(jsonString);
|
}
|
return md5;
|
}
|
|
private User getUser() {
|
try {
|
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
|
if (authentication != null) {
|
Object object = authentication.getPrincipal();
|
if (object instanceof User) {
|
return (User) object;
|
}
|
if(object instanceof UserLogin) {
|
UserLogin userLogin = (UserLogin) object;
|
User user = userService.superGetById(userLogin.getUserId());
|
return user;
|
}
|
}
|
} catch (Exception e) {
|
System.out.println(e.getMessage());
|
}
|
return null;
|
}
|
|
private void statisticsCacheHitCount(Object object, String[] tableNames, String requestURI) {
|
statisticsCacheSaveRedis(object, requestURI);
|
for (String tableName : tableNames) {
|
statisticsCacheSaveRedis(object, tableName);
|
}
|
}
|
|
private void statisticsCacheSaveRedis(Object object, String key) {
|
SimpleDateFormat format = new SimpleDateFormat("yyyy-MM-dd");
|
String now = format.format(new Date());
|
|
String urlKey = RedisConstants.STATISTICS_CACHE_DATA + ":" + now + ":" + key;
|
Object urlCache = redisUtil.get(urlKey);
|
CacheHitDto cacheHitDto = new CacheHitDto(0, 0);
|
if (urlCache != null) {
|
cacheHitDto = JSON.parseObject(urlCache.toString(), CacheHitDto.class);
|
}
|
|
if (object == null) {
|
cacheHitDto.setMiss(cacheHitDto.getMiss() + 1);
|
}else {
|
cacheHitDto.setHit(cacheHitDto.getHit() + 1);
|
}
|
|
redisUtil.set(urlKey, JSON.toJSONString(cacheHitDto));
|
}
|
}
|