算法Redis

限流算法(四)AOP+RedisLua对接口进行限流

2020-11-15  本文已影响0人  茶还是咖啡

限流流程


每次请求,获取令牌桶中的令牌,如果令牌获取成功,代表没有被限流,可以正常访问,如果获取失败代表被限流,访问失败,这时会抛出一个RateLimitException结束。

最终效果

  1. 我打算结合springboot的手动装配,制作一个限流的工具,最终可以被封装成一个jar包,其他项目需要,直接引入就可以,不用重复开发。
  2. 具体的用法是这样的
    1. 在配置类上标注@EnableRedisRateLimit注解,激活限流工具
    2. 在需要限流的接口上标注@RateLimit注解,并根据具体的场景设置限流规则
     @RateLimit(replenishRate = 3,burstCapacity = 300)
     @GetMapping("test-limit")
     public Result<Void> testLimit(){
         return Result.buildSuccess();
     }
    

核心代码介绍

  1. @RateLimit 为了方便拓展,使得使用不同的场景,这里通过实现KeyResolver接口来指定具体的限流维度
  2. 这里说一下limitProperties的作用,我们可以默认使用注解中的参数指定配置信息,但是为了方便拓展,这里提供了limitProperties,如果指定了limitProperties,那么会以limitProperties的配置为准。
  3. 上篇文章介绍的限流lua脚本只能针对秒为时间单位进行限流,我这里对它的lua脚本做了一个小小的改变,使得可以支持秒,分钟,小时,天 为时间单位的限流。
@Documented
@Inherited
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimit {

    /**
     * 限流维度,默认使用uri进行限流
     *
     * @return uri
     */
    Class<? extends KeyResolver> keyResolver() default UriKeyResolver.class;

    /**
     * 限流配置,如果实现了该接口,默认以这个为准
     *
     * @return limitProp
     */
    Class<? extends LimitProperties> limitProperties() default DefaultLimitProperties.class;

    /**
     * 令牌桶每秒填充平均速率
     *
     * @return replenishRate
     */
    int replenishRate() default 1;

    /**
     * 令牌桶总容量
     *
     * @return burstCapacity
     */
    int burstCapacity() default 3;

    /**
     * 限流时间维度,默认为秒
     * 支持秒,分钟,小时,天
     * 即,
     * {@link TimeUnit#SECONDS},
     * {@link TimeUnit#MINUTES},
     * {@link TimeUnit#HOURS},
     * {@link TimeUnit#DAYS}
     *
     * @return TimeUnit
     * @since 1.0.2
     */
    TimeUnit timeUnit() default TimeUnit.SECONDS;
}
  1. limitProperties
public interface LimitProperties {
    /**
     * 令牌桶每秒填充平均速率
     *
     * @return replenishRate
     */
    int replenishRate();

    /**
     * 令牌桶总容量
     *
     * @return burstCapacity
     */
    int burstCapacity();

    /**
     * 限流时间维度,默认为秒
     * 支持秒,分钟,小时,天
     * 即,
     * {@link TimeUnit#SECONDS},
     * {@link TimeUnit#MINUTES},
     * {@link TimeUnit#HOURS},
     * {@link TimeUnit#DAYS}
     *
     * @return TimeUnit
     * @since 1.0.2
     */
    TimeUnit timeUnit();
}
  1. lua
local tokens_key = KEYS[1]
local timestamp_key = KEYS[2]
--redis.log(redis.LOG_WARNING, "tokens_key " .. tokens_key)

local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])
local time_unit = tonumber(ARGV[5])
-- 填满令牌桶所需要的时间
local fill_time = capacity/rate
local ttl = math.floor((fill_time*time_unit)*2)

--redis.log(redis.LOG_WARNING, "rate " .. ARGV[1])
--redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2])
--redis.log(redis.LOG_WARNING, "now " .. ARGV[3])
--redis.log(redis.LOG_WARNING, "requested " .. ARGV[4])
--redis.log(redis.LOG_WARNING, "filltime " .. fill_time)
--redis.log(redis.LOG_WARNING, "ttl " .. ttl)

local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
    last_tokens = capacity
end
--redis.log(redis.LOG_WARNING, "last_tokens " .. last_tokens)

local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
    last_refreshed = 0
end
--redis.log(redis.LOG_WARNING, "last_refreshed " .. last_refreshed)

local delta = math.max(0, now-last_refreshed)
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
local allowed_num = 0
if allowed then
    new_tokens = filled_tokens - requested
    allowed_num = 1
end

--redis.log(redis.LOG_WARNING, "delta " .. delta)
--redis.log(redis.LOG_WARNING, "filled_tokens " .. filled_tokens)
--redis.log(redis.LOG_WARNING, "allowed_num " .. allowed_num)
--redis.log(redis.LOG_WARNING, "new_tokens " .. new_tokens)

redis.call("setex", tokens_key, ttl, new_tokens)
redis.call("setex", timestamp_key, ttl, now)

return { allowed_num, new_tokens }
  1. 核心的AOP
@Slf4j
@Aspect
public class RateLimitInterceptor implements ApplicationContextAware {

    @Resource
    private RedisTemplate<String, Object> stringRedisTemplate;

    @Resource
    private RedisScript<List<Long>> rateLimitRedisScript;

    private ApplicationContext applicationContext;

    @Around("execution(public * *(..)) && @annotation(org.ywb.aoplimiter.anns.RateLimit)")
    public Object interceptor(ProceedingJoinPoint pjp) throws Throwable {
        MethodSignature signature = (MethodSignature) pjp.getSignature();
        Method method = signature.getMethod();
        RateLimit rateLimit = method.getAnnotation(RateLimit.class);
        // 断言不会被限流
        assertNonLimit(rateLimit, pjp);
        return pjp.proceed();
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        this.applicationContext = applicationContext;
    }

    public void assertNonLimit(RateLimit rateLimit, ProceedingJoinPoint pjp) {
        Class<? extends KeyResolver> keyResolverClazz = rateLimit.keyResolver();
        KeyResolver keyResolver = applicationContext.getBean(keyResolverClazz);
        String resolve = keyResolver.resolve(HttpContentHelper.getCurrentRequest(), pjp);
        List<String> keys = getKeys(resolve);

        LimitProperties limitProperties = getLimitProperties(rateLimit);

        // 根据限流时间维度计算时间
        long timeLong = getCurrentTimeLong(limitProperties.timeUnit());

        // The arguments to the LUA script. time() returns unixtime in seconds.
        List<String> scriptArgs = Arrays.asList(limitProperties.replenishRate() + "",
                limitProperties.burstCapacity() + "", (Instant.now().toEpochMilli() / timeLong) + "", "1", timeLong + "");
        // 第一个参数是是否被限流,第二个参数是剩余令牌数
        List<Long> rateLimitResponse = this.stringRedisTemplate.execute(this.rateLimitRedisScript, keys, scriptArgs.toArray());
        Assert.notNull(rateLimitResponse, "redis execute redis lua limit failed.");
        Long isAllowed = rateLimitResponse.get(0);
        Long newTokens = rateLimitResponse.get(1);
        log.info("rate limit key [{}] result: isAllowed [{}] new tokens [{}].", resolve, isAllowed, newTokens);
        if (isAllowed <= 0) {
            throw new RateLimitException(resolve);
        }
    }

    private LimitProperties getLimitProperties(RateLimit rateLimit) {
        Class<? extends LimitProperties> aClass = rateLimit.limitProperties();
        if (aClass == DefaultLimitProperties.class) {
            // 选取注解中的配置
            return new DefaultLimitProperties(rateLimit.replenishRate(), rateLimit.burstCapacity(), rateLimit.timeUnit());
        }
        // 优先使用用户自己的配置类
        return applicationContext.getBean(aClass);
    }

    private long getCurrentTimeLong(TimeUnit timeUnit) {
        switch (timeUnit) {
            case SECONDS:
                return 1;
            case MINUTES:
                return 60;
            case HOURS:
                return 60 * 60;
            case DAYS:
                return 60 * 60 * 24;
            default:
                throw new IllegalArgumentException("timeUnit:" + timeUnit + " not support");
        }
    }

    private List<String> getKeys(String id) {
        // use `{}` around keys to use Redis Key hash tags
        // this allows for using redis cluster

        // Make a unique key per user.
        String prefix = "request_rate_limiter.{" + id;

        // You need two Redis keys for Token Bucket.
        String tokenKey = prefix + "}.tokens";
        String timestampKey = prefix + "}.timestamp";
        return Arrays.asList(tokenKey, timestampKey);
    }
}

核心代码就是这么多,完整的源码我已上传至github,
传送门

上一篇下一篇

猜你喜欢

热点阅读