基于Redis和配置中心的实时频率限制

2021-06-24  本文已影响0人  十毛tenmao

如果使用网关,一般可以在网关进行限频控制;如果使用nginx,也可以使用lua+redis实现分布式限频;但是有的底层服务提供给内网其他应用调用,有的调用方本身没有对客户请求限频,所以请求都会到达底层服务。 内部应用,就不一定走网关,所以底层服务本身需要提供限频能力。

关键特性

实现原理

实现

@Slf4j
@Component
@Order(1000)
@SuppressWarnings("UnstableApiUsage")
@WebFilter
public class RateLimiterFilter implements Filter {
    private static final Type RATE_RULE_MAP_TYPE = new TypeToken<LinkedHashMap<String, List<RateLimiterRule>>>() {
    }.getType();
    private static final Joiner DIM_JOINER = Joiner.on(":");
    private static final Joiner.MapJoiner HEADER_JOINER = Joiner.on(",").withKeyValueSeparator("=");

    @Resource
    private RedisTemplate<String, Long> redisTemplate;

    private LinkedHashMap<Pattern, List<RateLimiterRule>> rateLimiterRules = new LinkedHashMap<>();

    /**
     * 设置频率限制规则.
     * 这里使用Spring配置,结合配置中心可以实现动态配置效果。 也可以把配置信息写入数据库,再提供管理端页面,方便后续运维.
     */
    @Value("${rateLimiterRule:}")
    private void configRateLimiterRule(String rateLimiterRule) {
        log.info("start to configRateLimiterRule: {}", rateLimiterRule);
        if (!StringUtils.hasText(rateLimiterRule)) {
            return;
        }
        try {
            LinkedHashMap<String, List<RateLimiterRule>> rateLimiterRules = JsonUtil.fromJson(rateLimiterRule, RATE_RULE_MAP_TYPE);
            this.rateLimiterRules = rateLimiterRules.entrySet().stream()
                    .collect(Collectors.toMap(entry -> Pattern.compile(entry.getKey()), Map.Entry::getValue, (f, s) -> s, LinkedHashMap::new));
        } catch (RuntimeException e) {
            log.error("fail to configRateLimiterRule: [{}]", rateLimiterRule, e);
        }
    }

    /**
     * 限频开关.
     */
    @Value("${switch.rateLimiter:false}")
    private boolean rateLimiterSwitch;

    /**
     * 限频规则.
     */
    @Data
    private static class RateLimiterRule {
        /**
         * 计算频率的维度.
         */
        private List<String> dimensions;

        /**
         * 限制次数.
         */
        private Integer limit;
    }

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        Filter.super.init(filterConfig);
    }

    @Override
    public void destroy() {
        Filter.super.destroy();
    }

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        if (!rateLimiterSwitch) {
            filterChain.doFilter(servletRequest, servletResponse);
            return;
        }

        log.debug("check rate limit");
        //如果分布式限频出现故障,不能影响服务正常运行
        boolean access = true;
        try {
            access = checkAccess(request);
        } catch (RuntimeException e) {
            log.warn("fail to check access limit", e);
        }

        if (access) {
            filterChain.doFilter(servletRequest, servletResponse);
        } else {
            HttpServletResponse response = (HttpServletResponse) servletResponse;
            response.setHeader("Content-Type", "application/json;charset=UTF-8");
            response.getWriter().print(new Gson().toJson(new WebResult<>(409, "TOO_LARGE_FREQUENT")));
            response.setStatus(200);
        }
    }

    private boolean checkAccess(HttpServletRequest request) {
        String url = request.getRequestURI();
        //根据url找到对应的规则,按照顺序,找到一个就返回
        Optional<Pattern> ruleKeyOpt = rateLimiterRules.keySet().stream()
                .filter(pattern -> pattern.matcher(url).matches())
                .findFirst();

        //配置了规则才需要校验
        if (ruleKeyOpt.isPresent()) {
            List<RateLimiterRule> rules = this.rateLimiterRules.get(ruleKeyOpt.get());
            List<Boolean> rulesAllowed = rules.stream().map(rule -> checkAccessRule(request, rule)).collect(Collectors.toList());

            return rulesAllowed.stream().allMatch(c -> c);
        }
        return true;
    }

    private boolean checkAccessRule(HttpServletRequest request, RateLimiterRule rule) {
        List<String> dimensions = rule.getDimensions();

        //找到所有的限频的维度值
        Map<String, String> dimValues = Collections.emptyMap();
        if (!CollectionUtils.isEmpty(dimensions)) {
            dimValues = dimensions.stream()
                    .map(dim -> {
                        String v = request.getHeader(dim);
                        return v == null ? null : Pair.of(dim, v);
                    })
                    .filter(Objects::nonNull)
                    .collect(Collectors.toMap(Pair::getFirst, Pair::getSecond));
            //如果维度值没有找到,则该规则不限制,这么做是因为度如果没有维度分开统计,该接口调用频率会远超过预计有维度值的调用
            if (dimValues.size() < dimensions.size()) {
                return true;
            }
        }

        //每个维度值都有对应的统计信息
        String dimKey = DIM_JOINER.join(dimValues.values());
        String key = String.format("%s:%s:%s", request.getRequestURI(), dimKey, System.currentTimeMillis() / 10000);

        //访问的次数
        Long accessTimes = redisTemplate.opsForValue().increment(key);
        if (accessTimes == null || accessTimes == 1L) {
            //如果是第一次设置,则设置超时时间
            redisTemplate.expire(key, 5, TimeUnit.MINUTES);
        }
        boolean allowed = accessTimes == null || accessTimes <= rule.getLimit();
        if (!allowed) {
            log.info("NOT ALLOWED: uri[{}], dim[{}], times[{}], limit[{}]",
                    request.getRequestURI(), HEADER_JOINER.join(dimValues), accessTimes, rule.getLimit());
        }
        return allowed;
    }
}
{
    "/tenmao/api/hello":[
        {
            "dimensions":[
                "uid"
            ],
            "limit":2
        }
    ],
    "/.*":[
        {
            "dimensions":[
                "uid"
            ],
            "limit":10
        }
    ]
}
上一篇下一篇

猜你喜欢

热点阅读