基于Spring Aop + Redis实现分布式多维度前置后置限流
作者:互联网
基于Spring Aop + Redis实现分布式多维度前置后置限流
说明
在实际场景,比如发送短信验证码、刷评论是需要一定限流控制的,其中限流又可以分为前置限流,后置限流。
所谓前置限流即为调用目标接口前校验,无论被调用的接口是否发生异常或者是否返回预期值;
后置限流是调用接口后,可以根据指定的Condtion判断是否记录次数,Condtion支持EL表达式。
本文通过Spring Aop + 自定义注解 + Redis 分布式锁 + Redi lua脚本实现前、后置限流。并且提供用户维度、IP维度和自定义EL表达式Key多维度限流
配置Maven依赖
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
配置Redis
略
定义限流注解
前面先创建枚举类型RateLimiterType,即定义支持哪几种限流模式
public enum RateLimiterType {
/**
* 客户端ip
*/
CLIENT_IP,
/**
* 用户
*/
USER,
/**
* 自定义模式,需要指定key
*/
CUSTOM
}
前置限流注解:KeyRateLimiter,其中key支持EL表达式解析,可以获取到目标方法上面的参数作为Key值;另一个type可以指定使用哪种限流维度。
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@Repeatable(KeyRateLimiters.class)
public @interface KeyRateLimiter {
/**
* 限流Key,支持Spring el
*
* @return Key
*/
String key() default "";
/**
* 每秒令牌数
*
* @return 每秒令牌数
*/
int limit() default 1;
/**
* 频率,默认1
*/
int interval() default 1;
/**
* 频率单位,默认秒
*/
TimeUnit intervalUnit() default TimeUnit.SECONDS;
/**
* 限流类型,如果为CUSTOM,需要指定key
*/
RateLimiterType type() default RateLimiterType.CUSTOM;
/**
* 限流拒绝后的消息内容
*/
String message() default "您的操作过快,请稍后再试!";
}
后置限流注解:PostKeyRateLimiter,与KeyRateLimiter不同的地方时,增加了condtion,可以根据condtion表达式bool值判断是否记录调用是否计入。另外,PostKeyRateLimiter的实现方式也不一样,由于调用计数是发生在方法调用之后,所以需要结合Redis分布式锁来串行化调用,性能自然比会KeyRateLimiter差一些。两者都是使用Redis pipeline,同时在一个方法上面叠加配置。
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@Repeatable(PostKeyRateLimiters.class)
public @interface PostKeyRateLimiter {
/**
* 限流Key,支持Spring el
*
* @return Key
*/
String key() default "";
/**
* 每秒令牌数
*
* @return 每秒令牌数
*/
int limit() default 1;
/**
* 频率,默认1
*/
int interval() default 1;
/**
* 频率单位,默认秒
*/
TimeUnit intervalUnit() default TimeUnit.SECONDS;
/**
* 限流类型,如果为CUSTOM,需要指定key
*/
RateLimiterType type() default RateLimiterType.CUSTOM;
/**
* 生效表达式(包括取返回值#rtv.code == 200)
*/
String condition() default "";
/**
* 限流拒绝后的消息内容
*/
String message() default "您的操作过快,请稍后再试!";
}
再来两个组合注解,支持多个使用限流注解同时使用
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface KeyRateLimiters {
KeyRateLimiter[] value();
}
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface PostKeyRateLimiters {
PostKeyRateLimiter[] value();
}
创建配置类,上面PostKeyRateLimiter和KeyRateLimiter最终转为RateLimitConfig实例
@Data
public class RateLimitConfig {
/**
* 限流Key
*/
private String key;
/**
* 区间令牌数
*/
private int limit;
/**
* 区间频率
*/
private int rateInterval;
/**
* 频率单位,默认秒
*/
private TimeUnit intervalUnit;
/**
* 限流触发条件,spEL表达式
*/
private String condition;
/**
* 限流类型
*/
private RateLimiterType type;
/**
* 限流拒绝后的消息内容
*/
private String message;
public RateLimitConfig(PostKeyRateLimiter keyRateLimiter) {
this.key = keyRateLimiter.key();
this.limit = keyRateLimiter.limit();
this.rateInterval = keyRateLimiter.interval();
this.intervalUnit = keyRateLimiter.intervalUnit();
this.message = keyRateLimiter.message();
this.condition = keyRateLimiter.condition();
this.type = keyRateLimiter.type();
}
public RateLimitConfig(KeyRateLimiter keyRateLimiter) {
this.key = keyRateLimiter.key();
this.limit = keyRateLimiter.limit();
this.rateInterval = keyRateLimiter.interval();
this.intervalUnit = keyRateLimiter.intervalUnit();
this.message = keyRateLimiter.message();
this.type = keyRateLimiter.type();
}
}
创建Aop类
@Slf4j
@Aspect
@RequiredArgsConstructor
public class RateLimitAspect extends AbstractAspect {
private final RedisTemplate<String, Object> redisTemplate;
private final RedisLockService redisLock;
private static RedisScript<Number> rateLuaScript;
static {
// 返回0,1形式
String luaScript = "local current = tonumber(redis.call('get',KEYS[1]) or '0')\n" +
"if current >= tonumber(ARGV[1]) then\n" +
"\treturn 0\n" +
"end\n" +
"current = redis.call('incr',KEYS[1])\n" +
"if current == 1 then\n" +
"\tredis.call('pexpire',KEYS[1],ARGV[2])\n" +
"end\n" +
"return 1";
rateLuaScript = new DefaultRedisScript<>(luaScript, Number.class);
}
/**
* 前置定义切入点
*/
@Pointcut("@annotation(com.iwork.boot.redis.rt.KeyRateLimiter) " +
"|| @annotation(com.iwork.boot.redis.rt.KeyRateLimiters) " +
"|| @annotation(com.iwork.boot.redis.rt.PostKeyRateLimiter) " +
"|| @annotation(com.iwork.boot.redis.rt.PostKeyRateLimiters)")
public void frontRateLimiter() {
}
private Set<String> validateFront(RateLimitConfig... rateLimitConfigs) {
Set<String> errorMsg = new HashSet<>(rateLimitConfigs.length);
List<Object> objects = redisTemplate.executePipelined(new SessionCallback<Number>() {
@Override
public Number execute(RedisOperations operations) throws DataAccessException {
for (RateLimitConfig limitConfig : rateLimitConfigs) {
// 这里不能使用long类型,否则越界 ERR value is not an integer or out of range
int period = (int) limitConfig.getIntervalUnit().toMillis(limitConfig.getRateInterval());
operations.execute(rateLuaScript, Collections.singletonList(limitConfig.getKey()), limitConfig.getLimit(), period);
}
return null;
}
});
for (int i = 0; i < rateLimitConfigs.length; i++) {
Number val = (Number) objects.get(i);
// 被限流
if (val.longValue() == 0L) {
errorMsg.add(rateLimitConfigs[i].getMessage());
}
}
return errorMsg;
}
private void setKey(String prefix, ProceedingJoinPoint joinPoint, List<RateLimitConfig> limitConfigs) {
for (RateLimitConfig limitConfig : limitConfigs) {
String key = limitConfig.getKey();
RateLimiterType type = limitConfig.getType();
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
String methodKey = prefix + parseElKey(joinPoint, limitConfig.getKey());
// 基于客户端ip
if (type == RateLimiterType.CLIENT_IP) {
HttpServletRequest request = Optional.ofNullable(RequestContextHolder.getRequestAttributes())
.map(ServletRequestAttributes.class::cast)
.map(ServletRequestAttributes::getRequest)
.orElseThrow(() -> new IllegalStateException("只能在Web环境中获取Request对象!"));
String clientIP = ServletUtil.getClientIP(request);
methodKey = methodKey + ":" + clientIP;
}
// 基于用户维度
else if (type == RateLimiterType.USER) {
String userId = authentication.getPrincipal().toString();
methodKey = methodKey + ":" + userId;
}
// 自定义,key不能为空
else {
Assert.hasText(key, "限流Key不能为空!");
}
limitConfig.setKey(methodKey);
}
}
@Around("frontRateLimiter()")
public Object executeFront(ProceedingJoinPoint joinPoint) throws Throwable {
MethodSignature signature = (MethodSignature) joinPoint.getSignature();
Method method = signature.getMethod();
List<RateLimitConfig> limitConfigs = new ArrayList<>(8);
List<RateLimitConfig> postLimitConfigs = new ArrayList<>(4);
KeyRateLimiter keyRateLimiter = method.getAnnotation(KeyRateLimiter.class);
KeyRateLimiters keyRateLimiters = method.getAnnotation(KeyRateLimiters.class);
PostKeyRateLimiter postKeyRateLimiter = method.getAnnotation(PostKeyRateLimiter.class);
PostKeyRateLimiters postKeyRateLimiters = method.getAnnotation(PostKeyRateLimiters.class);
if (keyRateLimiter != null) {
limitConfigs.add(new RateLimitConfig(keyRateLimiter));
}
if (keyRateLimiters != null && keyRateLimiters.value().length > 0) {
Stream.of(keyRateLimiters.value()).map(RateLimitConfig::new).forEach(limitConfigs::add);
}
if (postKeyRateLimiter != null) {
postLimitConfigs.add(new RateLimitConfig(postKeyRateLimiter));
}
if (postKeyRateLimiters != null && postKeyRateLimiters.value().length > 0) {
Stream.of(postKeyRateLimiters.value()).map(RateLimitConfig::new).forEach(postLimitConfigs::add);
}
// 前置校验
setKey("rt:front:", joinPoint, limitConfigs);
Set<String> errMsgSet = validateFront(limitConfigs.toArray(new RateLimitConfig[]{}));
if (!errMsgSet.isEmpty()) {
// 此处应该抛出特定异常,通过全局异常拦截处理
throw new BusinessException(errMsgSet.toString());
}
// 后置校验需要上锁
if (!postLimitConfigs.isEmpty()) {
// 设置Key
setKey("rt:post:", joinPoint, postLimitConfigs);
String key = "locks:" + postLimitConfigs.iterator().next().getKey();
// 获取锁后执行
return redisLock.executeWithLock(key, 10, 60, TimeUnit.SECONDS, () -> {
SessionCallback<Number> callback = new SessionCallback<Number>() {
@Override
public Number execute(RedisOperations operations) throws DataAccessException {
ValueOperations kvValueOperations = operations.opsForValue();
for (RateLimitConfig postLimitConfig : postLimitConfigs) {
String key1 = postLimitConfig.getKey();
kvValueOperations.get(key1);
}
return null;
}
};
List<Object> objects = redisTemplate.executePipelined(callback);
for (int i = 0; i < postLimitConfigs.size(); i++) {
Number val = (Number) objects.get(i);
RateLimitConfig rateLimitConfig = postLimitConfigs.get(i);
if (val != null && val.longValue() >= rateLimitConfig.getLimit()) {
errMsgSet.add(rateLimitConfig.getMessage());
}
}
if (!errMsgSet.isEmpty()) {
// 此处应该抛出特定异常,通过全局异常拦截处理
throw new BusinessException(errMsgSet.toString());
}
try {
// 执行业务方法
Object proceed = joinPoint.proceed();
// 扣减令牌
RateLimitConfig[] filterConfigs = postLimitConfigs.stream()
.filter(config -> parsePostSpEl(proceed, config))
.collect(Collectors.toList())
.toArray(new RateLimitConfig[]{});
validateFront(filterConfigs);
return proceed;
} catch (BusinessException e) {
throw e;
} catch (Throwable throwable) {
throw new BusinessException(throwable);
}
});
}
return joinPoint.proceed();
}
private boolean parsePostSpEl(Object val, RateLimitConfig limitConfig) {
String condition = limitConfig.getCondition();
if (StringUtils.isBlank(condition) || !condition.contains(EL_PREFIX)) {
return true;
}
StandardEvaluationContext context = new StandardEvaluationContext();
context.setVariable("rtv", val);
Expression expression = expressionParser.parseExpression(condition);
return Optional.ofNullable(expression.getValue(context, Boolean.class)).orElse(true);
}
}
上面使用Redis pipline、Redis Lock,逻辑不难就不做细讲了,有疑问欢迎提问!
使用
@KeyRateLimiter(type = RateLimiterType.CLIENT_IP)
基于IP的前置限流
@KeyRateLimiter(type = RateLimiterType.USER)
基于用户维度的前置限流
@KeyRateLimiter(type = RateLimiterType.CUSTOM, key = "#username" interval="60" condtion="#rtv.code ==200")
自定义Key限流,并且通过返回值code==200才标记有效访问,进行限流
@Slf4j
@RestController
@Api(tags = "系统:系统授权接口")
@RequiredArgsConstructor
public class AuthController {
@AnonymousAccess
@ApiOperation("获取验证码")
@GetMapping(value = "/code")
@KeyRateLimiter(type = RateLimiterType.CLIENT_IP)
@KeyRateLimiter(type = RateLimiterType.USER)
@KeyRateLimiter(type = RateLimiterType.CUSTOM, key = "#username" interval="60" condtion="#rtv.code ==200")
public XCloudResponse<Object> getCode(@RequestParam String username) {
// 省略代码细节
}
标签:String,Spring,多维度,keyRateLimiter,限流,key,new,RateLimitConfig 来源: https://blog.csdn.net/jjxxww12/article/details/113091238