基于Redis + Lua的令牌桶限流器的实现
2022-03-10 本文已影响0人
桃子是水果
开发环境
- jdk 11.0.10
- SpringBoot 2.6.2
- Idea
主要依赖
<dependency>
<groupId>redis.clients</groupId>
<artifactId>jedis</artifactId>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-pool2</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-devtools</artifactId>
<scope>runtime</scope>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-configuration-processor</artifactId>
<optional>true</optional>
</dependency>
核心代码
自定义注解
@Documented
@Retention(RUNTIME)
@Target(value = {ElementType.METHOD})
public @interface RateLimit {
/**
* 限流接口名称
* @return 限流接口名称
*/
String interfaceName();
/**
* 最大令牌数
* @return 最大令牌数
*/
long maxPermits();
/**
* 每秒生成的令牌数
* @return
*/
long tokensPerSeconds();
}
限流器抽象类
public abstract class RateLimiter {
private static final Logger logger = LoggerFactory.getLogger(RateLimiter.class);
/**
* 是否开启限流
*/
private boolean limited = true;
/**
* 开启限流功能
*/
public void open() {
if (!this.limited) {
this.limited = true;
} else {
logger.info("the limiter has started...");
}
}
/**
* 关闭限流功能
*/
public void close() {
if (this.limited) {
this.limited = false;
} else {
logger.info("the limiter has stopped...");
}
}
/**
* 获取令牌(指定接口限流)
* @param interfaceName 需要限流的接口名
* @param maxPermits 最大令牌数
* @param tokensPerSeconds 每秒生成的令牌数
* @return boolean 是否通过限流(获取到令牌)
*/
protected abstract boolean acquire(String interfaceName, long maxPermits, long tokensPerSeconds);
/**
* 获取令牌(指定接口)
* @param interfaceName 需要限流的接口名
* @return boolean 是否通过限流(获取到令牌)
*/
public boolean tryAcquire(String interfaceName, long maxPermits, long tokensPerSeconds) {
if (this.limited) {
return this.acquire(interfaceName, maxPermits, tokensPerSeconds);
} else {
return true;
}
}
}
令牌桶实现类
public class TokenBucketRateLimiter extends RateLimiter {
private static final Logger logger = LoggerFactory.getLogger(TokenBucketRateLimiter.class);
/**
* redis的lua脚本
*/
private DefaultRedisScript<Boolean> script;
/**
* redisTemplate
*/
private RedisTemplate<String, Object> redisTemplate;
public TokenBucketRateLimiter(DefaultRedisScript<Boolean> script, RedisTemplate<String, Object> redisTemplate) {
this.script = script;
this.redisTemplate = redisTemplate;
}
/**
* 限流检测(单个接口)
* @param interfaceName 需要限流的接口名
* @param maxPermits 最大令牌数
* @param tokensPerSeconds 每秒生成的令牌数
* @return 是否通过限流 true: 通过
*/
@Override
protected boolean acquire(String interfaceName, long maxPermits, long tokensPerSeconds) {
// 错误的参数将不起作用
if (maxPermits <= 0 || tokensPerSeconds <= 0) {
logger.warn("maxPermits and tokensPerSeconds can not be less than zero...");
return true;
}
// 参数结构: KEYS = [限流的key] ARGV = [最大令牌数, 每秒生成的令牌数, 本次请求的毫秒数]
Boolean result = this.redisTemplate.execute(this.script, Collections.singletonList(interfaceName), maxPermits, tokensPerSeconds, System.currentTimeMillis());
return result!=null && result;
}
}
具体实现令牌桶的Lua脚本
-- LUA脚本会以单线程执行,不会有并发问题,一个脚本中的执行过程中如果报错,那么已执行的操作不会回滚
-- KEYS和ARGV是外部传入进来需要操作的redis数据库中的key,下标从1开始
-- 参数结构: KEYS = [限流的key] ARGV = [最大令牌数, 每秒生成的令牌数, 本次请求的毫秒数]
local info = redis.pcall('HMGET', KEYS[1], 'last_time', 'stored_token_nums')
local last_time = info[1] --最后一次通过限流的时间
local stored_token_nums = tonumber(info[2]) -- 剩余的令牌数量
local max_token = tonumber(ARGV[1])
local token_rate = tonumber(ARGV[2])
local current_time = tonumber(ARGV[3])
local past_time = 0
local rateOfperMills = token_rate/1000 -- 每毫秒生产令牌速率
if stored_token_nums == nil then
-- 第一次请求或者键已经过期
stored_token_nums = max_token --令牌恢复至最大数量
last_time = current_time --记录请求时间
else
-- 处于流量中
past_time = current_time - last_time --经过了多少时间
if past_time <= 0 then
--高并发下每个服务的时间可能不一致
past_time = 0 -- 强制变成0 此处可能会出现少量误差
end
-- 两次请求期间内应该生成多少个token
local generated_nums = math.floor(past_time * rateOfperMills) -- 向下取整,多余的认为还没生成完
stored_token_nums = math.min((stored_token_nums + generated_nums), max_token) -- 合并所有的令牌后不能超过设定的最大令牌数
end
local returnVal = 0 -- 返回值
if stored_token_nums > 0 then
returnVal = 1 -- 通过限流
stored_token_nums = stored_token_nums - 1 -- 减少令牌
-- 必须要在获得令牌后才能重新记录时间。举例: 当每隔2ms请求一次时,只要第一次没有获取到token,那么后续会无法生产token,永远只过去了2ms
last_time = last_time + past_time
end
-- 更新缓存
redis.call('HMSET', KEYS[1], 'last_time', last_time, 'stored_token_nums', stored_token_nums)
-- 设置超时时间
-- 令牌桶满额的时间(超时时间)(ms) = 空缺的令牌数 * 生成一枚令牌所需要的毫秒数(1 / 每毫秒生产令牌速率)
redis.call('PEXPIRE', KEYS[1], math.ceil((1/rateOfperMills) * (max_token - stored_token_nums)))
return returnVal
切面类
@Aspect
public class RateLimitAspect {
private static final Logger logger = LoggerFactory.getLogger(RateLimitAspect.class);
private RateLimiter rateLimiter;
public RateLimitAspect(RateLimiter rateLimiter) {
this.rateLimiter = rateLimiter;
}
/**
* 标注切点-所有标识了RateLimit注解的方法
*/
@Pointcut("@annotation(cn.t.redis.limiter.annotations.RateLimit)")
public void pointCut(){};
@Before("pointCut()")
public void before(JoinPoint joinPoint) {
Method method = ((MethodSignature)joinPoint.getSignature()).getMethod();
RateLimit a = method.getAnnotation(RateLimit.class);
if (a != null) {
String name = a.interfaceName();
long maxPermits = a.maxPermits();
long tokensPerSeconds = a.tokensPerSeconds();
// 执行限流判断
var ret = this.rateLimiter.tryAcquire(name, maxPermits, tokensPerSeconds);
if (!ret) {
throw new RateLimitException("the interface can not be accessed in the meantime...");
}
}
}
}
自定义异常
public class RateLimitException extends RuntimeException {
public RateLimitException() {}
public RateLimitException(String message) {
super(message);
}
}
自动配置类
@Configuration
@AutoConfigureBefore(RedisAutoConfiguration.class) // 高优先级,先于自动默认的自动配置生成RedisTemplate
public class LimiterAutoConfiguration {
@Autowired
private RedisConnectionFactory connectionFactory;
/**
* 配置redisTemplate
* @return redisTemplate
*/
@Bean
@ConditionalOnMissingBean(RedisTemplate.class)
public RedisTemplate<String, Object> redisTemplate() {
RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>();
redisTemplate.setConnectionFactory(this.connectionFactory);
// 定义Jackson2JsonRedisSerializer序列化对象
Jackson2JsonRedisSerializer<Object> jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer<>(Object.class);
ObjectMapper objectMapper = new ObjectMapper();
// 指定要序列化的域,ALL:field,get和set等,ANY: 可见性,会将有private修饰符的字段也序列化
objectMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
objectMapper.activateDefaultTyping(LaissezFaireSubTypeValidator.instance,ObjectMapper.DefaultTyping.NON_FINAL);
jackson2JsonRedisSerializer.setObjectMapper(objectMapper);
// 使用StringRedisSerializer来序列化和反序列化redis的key值
redisTemplate.setKeySerializer(new StringRedisSerializer());
redisTemplate.setHashKeySerializer(new StringRedisSerializer());
// 使用jackson2JsonRedisSerializer序列化和反序列化value
redisTemplate.setValueSerializer(jackson2JsonRedisSerializer);
redisTemplate.setHashValueSerializer(jackson2JsonRedisSerializer);
// 属性设置完成afterPropertiesSet就会被调用,可以对设置不成功的做一些默认处理
redisTemplate.afterPropertiesSet();
return redisTemplate;
}
/**
* redis的lua脚本对象
* @return lua脚本对象
*/
@Bean
public DefaultRedisScript<Boolean> redisScript() {
DefaultRedisScript<Boolean> redisScript = new DefaultRedisScript<>();
redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("RateLimiter.lua")));
redisScript.setResultType(Boolean.class);
return redisScript;
}
/**
* 默认限流器的实现-令牌桶
* @return 默认限流器
*/
@Bean
@ConditionalOnMissingBean(RateLimiter.class)
public RateLimiter rateLimiter() {
return new TokenBucketRateLimiter(this.redisScript(), this.redisTemplate());
}
/**
* 限流切面
* @param rateLimiter
* @return
*/
@Bean
@ConditionalOnBean(RateLimiter.class)
public RateLimitAspect rateLimitAspect(RateLimiter rateLimiter) {
return new RateLimitAspect(rateLimiter);
}
}
自动配置类指示文件(src/main/resources/META-INF/spring.factories
)
org.springframework.boot.autoconfigure.EnableAutoConfiguration=\
cn.t.redis.limiter.configuration.LimiterAutoConfiguration
打包后在需要使用限流功能的模块中引入即可
使用方法
-
引入本jar包
<dependency> <groupId>cn.t.redis.limiter</groupId> <artifactId>limiter-spring-boot-starter</artifactId> <version>1.0.0</version> </dependency>
-
配置
Redis
连接信息
spring:
redis:
#host: localhost # 单点连接ip
#port: 18379 # # 单点连接端口
timeout: 6000 # 连接超时时间
password: your password
client-type: lettuce #指定连接工厂类型
cluster:
max-redirects: 3 # 获取失败 最大重定向次数
nodes: # 集群节点
- 127.0.0.1:7001
- 127.0.0.1:7002
- 127.0.0.1:7003
- 127.0.0.1:7004
- 127.0.0.1:7005
- 127.0.0.1:7006
lettuce: # lettuce连接池
pool:
max-active: 100 # 连接池最大连接数(使用负值表示没有限制)
max-idle: 20 # 最大空闲连接数
min-idle: 10 # 最小空闲连接数
max-wait: 1500 # 连接池最大阻塞等待时间(ms)(使用负值表示没有限制)
- 在需要限流的接口处使用注解
@RequestMapping("/index")
@RateLimit(interfaceName = "limit", maxPermits = 5, tokensPerSeconds = 1)
public String ratelimit() {
return "hello world";
}
- 未通过限流的访问会抛出异常,建议在全局异常处理器中捕获处理。
例如:
@RestControllerAdvice
public class GlobalErrorController {
@ExceptionHandler(RateLimitException.class)
public String ratelimiteHanler(RateLimitException e) {
return e.getMessage();
}
}
代码地址: 基于Redis + Lua的令牌桶限流器的实现