基于 springboot+Redis 实现接口幂等性(随笔)

2021-04-20  本文已影响0人  简楼

前言

在开发中,我们或多或少总会碰到重复提交,造成数据错误的情况;
面对这种情况,我们一般会在页面控制提交按钮,在提交一次之后就变灰,这属于比较见效快,应用简单的一种方法,还有借助数据库实现等等;

但这总方法还是有弊端的,比如有心人会把链接单独拿出来请求等等,所以为了应对这种,接口幂等性校验就应运而生了;

什么是幂等性?

幂等性是数学与计算机学概念,常见于抽象代数中;

现在用在接口上,我们可以理解为,针对同一个接口,多次发出同一个请求,必须保证操作只执行一次;

因为在调用接口发生重复提交时,总是会造成系统所无法承受的损失,所以必须阻止这种现象的发生;

比如:付款操作,可能由于网络比较慢,在支付的时候,就重复点击了几次,如果没有幂等性校验的话,那么每点击一次就会付款一次,这就会让我们浪费钱,甚至让我们拒绝使用手机支付这种便捷支付功能;

思路

  1. 自定义注解,标注哪些方法或类是参加幂等性校验的;
  2. 自定义一个HttpServletRequest实现类,因为请求的 HttpServletRequest 默认实现类里的 getReader() 和 getInputStream() 只能调用一次;
  3. 自定义一个Filter实现类,过滤出我们需要拦截的请求;
  4. 自定义一个HandlerIntercepter的实现类,结合 Redis 实现幂等性校验;
  5. 将过滤器、拦截器注入到系统中;

自定义幂等性校验注解

@Target({ElementType.TYPE, ElementType.METHOD})
@Documented
@Retention(RetentionPolicy.RUNTIME)
public @interface Idempotent {

    /**
     * 是否把 require 数据用来计算幂等key
     *
     * @return
     */
    boolean require() default false;

    /**
     * 参与幂等性计算的字段,默认所有字段
     */
    String[] values() default {};

    /**
     * 幂等性校验失效时间(毫秒)
     */
    long expiredTime() default 60_000L;
}

作用:

  1. 标注的方法或类是否参加幂等性校验
  2. 参加幂等性校验的字段
  3. 幂等性校验的持续多长时间

HttpServletRequest实现类

实现 HttpServletRequestWrapper ,将 request 中的 @RequestBody 取出来,并重写 getReader() 和 getInputStream(),重写 getInputStream() 方法时,将标志位还原,让系统无感知我们取过值;

public class RequestWrapper extends HttpServletRequestWrapper {

    private String body;

    public RequestWrapper(HttpServletRequest request) {
        super(request);
        StringBuilder stringBuilder = new StringBuilder();
        BufferedReader bufferedReader = null;
        InputStream inputStream = null;
        try {
            inputStream = request.getInputStream();
            if (inputStream != null) {
                bufferedReader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8));
                char[] charBuffer = new char[128];
                int bytesRead = -1;
                while ((bytesRead = bufferedReader.read(charBuffer)) > 0) {
                    stringBuilder.append(charBuffer, 0, bytesRead);
                }
            } else {
                stringBuilder.append("");
            }
        } catch (IOException ex) {

        } finally {
            if (inputStream != null) {
                try {
                    inputStream.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
            if (bufferedReader != null) {
                try {
                    bufferedReader.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
        body = stringBuilder.toString();
    }

    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(this.getInputStream(), StandardCharsets.UTF_8));
    }

    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(
                body.getBytes(StandardCharsets.UTF_8));
        ServletInputStream servletInputStream = new ServletInputStream() {
            @Override
            public boolean isFinished() {
                return false;
            }

            @Override
            public boolean isReady() {
                return false;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
            }

            @Override
            public int read() throws IOException {
                return byteArrayInputStream.read();
            }
        };
        return servletInputStream;
    }

    public String getBody() {
        return this.body;
    }
}

Filter实现类

这一步的作用,就是拦截请求,将我们自定义的 Request 注入进去,避开 getReader() 和 getInputStream() 只能调用一次的情况,因为做拦截,只是取值校验,不能影响后面的实际业务;

@Slf4j
public class ResubmitFilter implements Filter {

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        Filter.super.init(filterConfig);
    }

    @Override
    public void destroy() {
        Filter.super.destroy();
    }

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        RequestWrapper requestWrapper = null;
        try {
            HttpServletRequest req = (HttpServletRequest) request;
            // 排除GET请求,不做幂等性校验
            if (!HttpMethod.GET.name().equals(req.getMethod())) {
                requestWrapper = new RequestWrapper(req);
            }
        } catch (Exception e) {
            log.warn("RequestWrapper Error:", e);
        }
        chain.doFilter((Objects.isNull(requestWrapper) ? request : requestWrapper),
                response);
    }
}

我这里只排除了 GET 请求不做幂等性校验,这里你们可以按照自己的实际需要定义;

HandlerIntercepter的实现类

这一步,做的就是幂等性校验的具体步骤;

结合注解,把请求中的参数、IP,通过摘要加密的方式,生成一个唯一的key,保存到 redis 中,并设置过期时间,这样就可以在过期时间内,保证该请求指挥出现一次;

@Slf4j
@Component
public class ResponseResultInterceptor implements HandlerInterceptor {

    @Resource
    private RedisTemplate<String, String> redisTemplate;

    private final static String REQUEST_URL = "url";

    /**
     * Controller逻辑执行之前 可以幂等性校验,防止重复提交
     * <p>
     * 注意:ServletRequest 中 getReader() 和 getInputStream() 只能调用一次,也就是 request 值取了一次,就无法再取
     *
     * @param request
     * @param response
     * @param handler
     * @return boolean
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
            throws Exception {
        if (request instanceof RequestWrapper) {
            // 获取@RequestBody注解参数
            RequestWrapper requestWrapper = (RequestWrapper) request;
            String body = requestWrapper.getBody();
            if (handler instanceof HandlerMethod) {
                HandlerMethod handlerMethod = (HandlerMethod) handler;
                Class<?> clazz = handlerMethod.getBeanType();
                Method method = handlerMethod.getMethod();
                if (clazz.isAnnotationPresent(Idempotent.class)) {
                    Idempotent annotation = clazz.getAnnotation(Idempotent.class);
                    validUrl(body, request, annotation);
                } else if (method.isAnnotationPresent(Idempotent.class)) {
                    Idempotent annotation = method.getAnnotation(Idempotent.class);
                    validUrl(body, request, annotation);
                }
            }
        }
        return true;
    }

    /**
     * 校验url,重复提交
     *
     * @param body
     * @param request
     * @param annotation
     **/
    private void validUrl(String body, HttpServletRequest request, Idempotent annotation) {
        if (annotation.require()) {
            String[] values = annotation.values();
            Map<String, String[]> parameterMap = request.getParameterMap();
            JSONObject jsonObject = JSONUtil.parseObj(body);
            jsonObject.set(REQUEST_URL, request.getRequestURL());
            jsonObject.putAll(parameterMap);
            Map<String, Object> stringObjectMap = sortByKey(jsonObject, values);
            // 摘要加密
            String md5Hex = DigestUtil.md5Hex(JSONUtil.toJsonStr(stringObjectMap));
            long expiredTime = annotation.expiredTime();
            Boolean bool = redisTemplate.opsForValue().setIfAbsent(md5Hex, "1", expiredTime, TimeUnit.MILLISECONDS);
            Assert.isTrue(bool, "提交太频繁了,请稍后提交!");
        }
    }

    /**
     * map 按 key 升序排序,只取 values 字段,values为空时,代表全部字段
     *
     * @param map
     * @param values
     */
    private Map<String, Object> sortByKey(Map<String, Object> map, String[] values) {
        Boolean bool = values.length < 1;
        Map<String, Object> result = new LinkedHashMap<>(map.size());
        map.entrySet().stream()
                .sorted(Map.Entry.comparingByKey())
                .forEachOrdered(e -> {
                    if (bool || isCheckKey(e.getKey(), values)) {
                        result.put(e.getKey(), e.getValue());
                    }
                });
        return result;
    }

    /**
     * 校验 key 是否存在 keys数组中
     *
     * @param key
     * @param keys
     * @return java.lang.Boolean
     **/
    private Boolean isCheckKey(String key, String[] keys) {
        for (String value : keys) {
            if (key.equals(value) || key.equals(REQUEST_URL)) {
                return true;
            }
        }
        return false;
    }
}

将过滤器、拦截器注入到系统

@Configuration
public class WebConfigurer implements WebMvcConfigurer {

    @Resource
    private ResponseResultInterceptor responseResultInterceptor;

    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        // 添加自定义拦截器
        registry.addInterceptor(responseResultInterceptor).addPathPatterns("/**");
    }

    @Bean
    public FilterRegistrationBean servletRegistrationBean() {
        //通过FilterRegistrationBean实例设置优先级可以生效
        ResubmitFilter resubmitFilter = new ResubmitFilter();
        FilterRegistrationBean<ResubmitFilter> bean = new FilterRegistrationBean<>();
        //注册自定义过滤器
        bean.setFilter(resubmitFilter);
        //过滤器名称
        bean.setName("resubmitFilter");
        //过滤所有路径
        bean.addUrlPatterns("/*");
        //优先级,越低越优先
        bean.setOrder(Ordered.LOWEST_PRECEDENCE);
        return bean;
    }
}

保留项目

大家思考下,为什么这里不用AOP这么强大的功能,反而去用过滤器、拦截器呢?

欢迎大家留言,集思广益!!!

上一篇 下一篇

猜你喜欢

热点阅读