限流算法(四)AOP+RedisLua对接口进行限流
2020-11-15 本文已影响0人
茶还是咖啡
限流流程
每次请求,获取令牌桶中的令牌,如果令牌获取成功,代表没有被限流,可以正常访问,如果获取失败代表被限流,访问失败,这时会抛出一个RateLimitException结束。
最终效果
- 我打算结合springboot的手动装配,制作一个限流的工具,最终可以被封装成一个jar包,其他项目需要,直接引入就可以,不用重复开发。
- 具体的用法是这样的
- 在配置类上标注
@EnableRedisRateLimit
注解,激活限流工具 - 在需要限流的接口上标注
@RateLimit
注解,并根据具体的场景设置限流规则
@RateLimit(replenishRate = 3,burstCapacity = 300) @GetMapping("test-limit") public Result<Void> testLimit(){ return Result.buildSuccess(); }
- 在配置类上标注
核心代码介绍
-
@RateLimit
为了方便拓展,使得使用不同的场景,这里通过实现KeyResolver
接口来指定具体的限流维度 - 这里说一下limitProperties的作用,我们可以默认使用注解中的参数指定配置信息,但是为了方便拓展,这里提供了limitProperties,如果指定了limitProperties,那么会以limitProperties的配置为准。
- 上篇文章介绍的限流lua脚本只能针对秒为时间单位进行限流,我这里对它的lua脚本做了一个小小的改变,使得可以支持秒,分钟,小时,天 为时间单位的限流。
@Documented
@Inherited
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimit {
/**
* 限流维度,默认使用uri进行限流
*
* @return uri
*/
Class<? extends KeyResolver> keyResolver() default UriKeyResolver.class;
/**
* 限流配置,如果实现了该接口,默认以这个为准
*
* @return limitProp
*/
Class<? extends LimitProperties> limitProperties() default DefaultLimitProperties.class;
/**
* 令牌桶每秒填充平均速率
*
* @return replenishRate
*/
int replenishRate() default 1;
/**
* 令牌桶总容量
*
* @return burstCapacity
*/
int burstCapacity() default 3;
/**
* 限流时间维度,默认为秒
* 支持秒,分钟,小时,天
* 即,
* {@link TimeUnit#SECONDS},
* {@link TimeUnit#MINUTES},
* {@link TimeUnit#HOURS},
* {@link TimeUnit#DAYS}
*
* @return TimeUnit
* @since 1.0.2
*/
TimeUnit timeUnit() default TimeUnit.SECONDS;
}
limitProperties
public interface LimitProperties {
/**
* 令牌桶每秒填充平均速率
*
* @return replenishRate
*/
int replenishRate();
/**
* 令牌桶总容量
*
* @return burstCapacity
*/
int burstCapacity();
/**
* 限流时间维度,默认为秒
* 支持秒,分钟,小时,天
* 即,
* {@link TimeUnit#SECONDS},
* {@link TimeUnit#MINUTES},
* {@link TimeUnit#HOURS},
* {@link TimeUnit#DAYS}
*
* @return TimeUnit
* @since 1.0.2
*/
TimeUnit timeUnit();
}
- lua
local tokens_key = KEYS[1]
local timestamp_key = KEYS[2]
--redis.log(redis.LOG_WARNING, "tokens_key " .. tokens_key)
local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])
local time_unit = tonumber(ARGV[5])
-- 填满令牌桶所需要的时间
local fill_time = capacity/rate
local ttl = math.floor((fill_time*time_unit)*2)
--redis.log(redis.LOG_WARNING, "rate " .. ARGV[1])
--redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2])
--redis.log(redis.LOG_WARNING, "now " .. ARGV[3])
--redis.log(redis.LOG_WARNING, "requested " .. ARGV[4])
--redis.log(redis.LOG_WARNING, "filltime " .. fill_time)
--redis.log(redis.LOG_WARNING, "ttl " .. ttl)
local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
last_tokens = capacity
end
--redis.log(redis.LOG_WARNING, "last_tokens " .. last_tokens)
local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
last_refreshed = 0
end
--redis.log(redis.LOG_WARNING, "last_refreshed " .. last_refreshed)
local delta = math.max(0, now-last_refreshed)
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
local allowed_num = 0
if allowed then
new_tokens = filled_tokens - requested
allowed_num = 1
end
--redis.log(redis.LOG_WARNING, "delta " .. delta)
--redis.log(redis.LOG_WARNING, "filled_tokens " .. filled_tokens)
--redis.log(redis.LOG_WARNING, "allowed_num " .. allowed_num)
--redis.log(redis.LOG_WARNING, "new_tokens " .. new_tokens)
redis.call("setex", tokens_key, ttl, new_tokens)
redis.call("setex", timestamp_key, ttl, now)
return { allowed_num, new_tokens }
- 核心的AOP
@Slf4j
@Aspect
public class RateLimitInterceptor implements ApplicationContextAware {
@Resource
private RedisTemplate<String, Object> stringRedisTemplate;
@Resource
private RedisScript<List<Long>> rateLimitRedisScript;
private ApplicationContext applicationContext;
@Around("execution(public * *(..)) && @annotation(org.ywb.aoplimiter.anns.RateLimit)")
public Object interceptor(ProceedingJoinPoint pjp) throws Throwable {
MethodSignature signature = (MethodSignature) pjp.getSignature();
Method method = signature.getMethod();
RateLimit rateLimit = method.getAnnotation(RateLimit.class);
// 断言不会被限流
assertNonLimit(rateLimit, pjp);
return pjp.proceed();
}
@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
this.applicationContext = applicationContext;
}
public void assertNonLimit(RateLimit rateLimit, ProceedingJoinPoint pjp) {
Class<? extends KeyResolver> keyResolverClazz = rateLimit.keyResolver();
KeyResolver keyResolver = applicationContext.getBean(keyResolverClazz);
String resolve = keyResolver.resolve(HttpContentHelper.getCurrentRequest(), pjp);
List<String> keys = getKeys(resolve);
LimitProperties limitProperties = getLimitProperties(rateLimit);
// 根据限流时间维度计算时间
long timeLong = getCurrentTimeLong(limitProperties.timeUnit());
// The arguments to the LUA script. time() returns unixtime in seconds.
List<String> scriptArgs = Arrays.asList(limitProperties.replenishRate() + "",
limitProperties.burstCapacity() + "", (Instant.now().toEpochMilli() / timeLong) + "", "1", timeLong + "");
// 第一个参数是是否被限流,第二个参数是剩余令牌数
List<Long> rateLimitResponse = this.stringRedisTemplate.execute(this.rateLimitRedisScript, keys, scriptArgs.toArray());
Assert.notNull(rateLimitResponse, "redis execute redis lua limit failed.");
Long isAllowed = rateLimitResponse.get(0);
Long newTokens = rateLimitResponse.get(1);
log.info("rate limit key [{}] result: isAllowed [{}] new tokens [{}].", resolve, isAllowed, newTokens);
if (isAllowed <= 0) {
throw new RateLimitException(resolve);
}
}
private LimitProperties getLimitProperties(RateLimit rateLimit) {
Class<? extends LimitProperties> aClass = rateLimit.limitProperties();
if (aClass == DefaultLimitProperties.class) {
// 选取注解中的配置
return new DefaultLimitProperties(rateLimit.replenishRate(), rateLimit.burstCapacity(), rateLimit.timeUnit());
}
// 优先使用用户自己的配置类
return applicationContext.getBean(aClass);
}
private long getCurrentTimeLong(TimeUnit timeUnit) {
switch (timeUnit) {
case SECONDS:
return 1;
case MINUTES:
return 60;
case HOURS:
return 60 * 60;
case DAYS:
return 60 * 60 * 24;
default:
throw new IllegalArgumentException("timeUnit:" + timeUnit + " not support");
}
}
private List<String> getKeys(String id) {
// use `{}` around keys to use Redis Key hash tags
// this allows for using redis cluster
// Make a unique key per user.
String prefix = "request_rate_limiter.{" + id;
// You need two Redis keys for Token Bucket.
String tokenKey = prefix + "}.tokens";
String timestampKey = prefix + "}.timestamp";
return Arrays.asList(tokenKey, timestampKey);
}
}
核心代码就是这么多,完整的源码我已上传至github,
传送门