数据库
首页 > 数据库> > 基于Spring Aop + Redis实现分布式多维度前置后置限流

基于Spring Aop + Redis实现分布式多维度前置后置限流

作者:互联网

基于Spring Aop + Redis实现分布式多维度前置后置限流

说明

在实际场景,比如发送短信验证码、刷评论是需要一定限流控制的,其中限流又可以分为前置限流,后置限流。

所谓前置限流即为调用目标接口前校验,无论被调用的接口是否发生异常或者是否返回预期值;

后置限流是调用接口后,可以根据指定的Condtion判断是否记录次数,Condtion支持EL表达式。

本文通过Spring Aop + 自定义注解 + Redis 分布式锁 + Redi lua脚本实现前、后置限流。并且提供用户维度、IP维度和自定义EL表达式Key多维度限流

配置Maven依赖

<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>

<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-aop</artifactId>
</dependency>

配置Redis

定义限流注解

前面先创建枚举类型RateLimiterType,即定义支持哪几种限流模式

public enum RateLimiterType {
    /**
     * 客户端ip
     */
    CLIENT_IP,

    /**
     * 用户
     */
    USER,

    /**
     * 自定义模式,需要指定key
     */
    CUSTOM
}

前置限流注解:KeyRateLimiter,其中key支持EL表达式解析,可以获取到目标方法上面的参数作为Key值;另一个type可以指定使用哪种限流维度。

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@Repeatable(KeyRateLimiters.class)
public @interface KeyRateLimiter {
    /**
     * 限流Key,支持Spring el
     *
     * @return Key
     */
    String key() default "";

    /**
     * 每秒令牌数
     *
     * @return 每秒令牌数
     */
    int limit() default 1;

    /**
     * 频率,默认1
     */
    int interval() default 1;

    /**
     * 频率单位,默认秒
     */
    TimeUnit intervalUnit() default TimeUnit.SECONDS;

    /**
     * 限流类型,如果为CUSTOM,需要指定key
     */
    RateLimiterType type() default RateLimiterType.CUSTOM;

    /**
     * 限流拒绝后的消息内容
     */
    String message() default "您的操作过快,请稍后再试!";
}

后置限流注解:PostKeyRateLimiter,与KeyRateLimiter不同的地方时,增加了condtion,可以根据condtion表达式bool值判断是否记录调用是否计入。另外,PostKeyRateLimiter的实现方式也不一样,由于调用计数是发生在方法调用之后,所以需要结合Redis分布式锁来串行化调用,性能自然比会KeyRateLimiter差一些。两者都是使用Redis pipeline,同时在一个方法上面叠加配置。

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@Repeatable(PostKeyRateLimiters.class)
public @interface PostKeyRateLimiter {
    /**
     * 限流Key,支持Spring el
     *
     * @return Key
     */
    String key() default "";

    /**
     * 每秒令牌数
     *
     * @return 每秒令牌数
     */
    int limit() default 1;

    /**
     * 频率,默认1
     */
    int interval() default 1;

    /**
     * 频率单位,默认秒
     */
    TimeUnit intervalUnit() default TimeUnit.SECONDS;

    /**
     * 限流类型,如果为CUSTOM,需要指定key
     */
    RateLimiterType type() default RateLimiterType.CUSTOM;

    /**
     * 生效表达式(包括取返回值#rtv.code == 200)
     */
    String condition() default "";

    /**
     * 限流拒绝后的消息内容
     */
    String message() default "您的操作过快,请稍后再试!";
}

再来两个组合注解,支持多个使用限流注解同时使用

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface KeyRateLimiters {
    KeyRateLimiter[] value();
}
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface PostKeyRateLimiters {
    PostKeyRateLimiter[] value();
}

创建配置类,上面PostKeyRateLimiter和KeyRateLimiter最终转为RateLimitConfig实例

@Data
public class RateLimitConfig {
    /**
     * 限流Key
     */
    private String key;

    /**
     * 区间令牌数
     */
    private int limit;

    /**
     * 区间频率
     */
    private int rateInterval;

    /**
     * 频率单位,默认秒
     */
    private TimeUnit intervalUnit;

    /**
     * 限流触发条件,spEL表达式
     */
    private String condition;

    /**
     * 限流类型
     */
    private RateLimiterType type;

    /**
     * 限流拒绝后的消息内容
     */
    private String message;

    public RateLimitConfig(PostKeyRateLimiter keyRateLimiter) {
        this.key = keyRateLimiter.key();
        this.limit = keyRateLimiter.limit();
        this.rateInterval = keyRateLimiter.interval();
        this.intervalUnit = keyRateLimiter.intervalUnit();
        this.message = keyRateLimiter.message();
        this.condition = keyRateLimiter.condition();
        this.type = keyRateLimiter.type();
    }

    public RateLimitConfig(KeyRateLimiter keyRateLimiter) {
        this.key = keyRateLimiter.key();
        this.limit = keyRateLimiter.limit();
        this.rateInterval = keyRateLimiter.interval();
        this.intervalUnit = keyRateLimiter.intervalUnit();
        this.message = keyRateLimiter.message();
        this.type = keyRateLimiter.type();
    }
}

创建Aop类

@Slf4j
@Aspect
@RequiredArgsConstructor
public class RateLimitAspect extends AbstractAspect {
    private final RedisTemplate<String, Object> redisTemplate;
    private final RedisLockService redisLock;
    private static RedisScript<Number> rateLuaScript;

    static {
        // 返回0,1形式
        String luaScript = "local current = tonumber(redis.call('get',KEYS[1]) or '0')\n" +
                "if current >= tonumber(ARGV[1]) then\n" +
                "\treturn 0\n" +
                "end\n" +
                "current = redis.call('incr',KEYS[1])\n" +
                "if current == 1 then\n" +
                "\tredis.call('pexpire',KEYS[1],ARGV[2])\n" +
                "end\n" +
                "return 1";
        rateLuaScript = new DefaultRedisScript<>(luaScript, Number.class);
    }

    /**
     * 前置定义切入点
     */
    @Pointcut("@annotation(com.iwork.boot.redis.rt.KeyRateLimiter) " +
            "|| @annotation(com.iwork.boot.redis.rt.KeyRateLimiters)  " +
            "|| @annotation(com.iwork.boot.redis.rt.PostKeyRateLimiter) " +
            "|| @annotation(com.iwork.boot.redis.rt.PostKeyRateLimiters)")
    public void frontRateLimiter() {
    }

    private Set<String> validateFront(RateLimitConfig... rateLimitConfigs) {
        Set<String> errorMsg = new HashSet<>(rateLimitConfigs.length);
        List<Object> objects = redisTemplate.executePipelined(new SessionCallback<Number>() {
            @Override
            public Number execute(RedisOperations operations) throws DataAccessException {
                for (RateLimitConfig limitConfig : rateLimitConfigs) {
                    // 这里不能使用long类型,否则越界 ERR value is not an integer or out of range
                    int period = (int) limitConfig.getIntervalUnit().toMillis(limitConfig.getRateInterval());
                    operations.execute(rateLuaScript, Collections.singletonList(limitConfig.getKey()), limitConfig.getLimit(), period);
                }
                return null;
            }
        });

        for (int i = 0; i < rateLimitConfigs.length; i++) {
            Number val = (Number) objects.get(i);
            // 被限流
            if (val.longValue() == 0L) {
                errorMsg.add(rateLimitConfigs[i].getMessage());
            }
        }

        return errorMsg;
    }

    private void setKey(String prefix, ProceedingJoinPoint joinPoint, List<RateLimitConfig> limitConfigs) {
        for (RateLimitConfig limitConfig : limitConfigs) {
            String key = limitConfig.getKey();
            RateLimiterType type = limitConfig.getType();
            Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
            String methodKey = prefix + parseElKey(joinPoint, limitConfig.getKey());

            // 基于客户端ip
            if (type == RateLimiterType.CLIENT_IP) {
                HttpServletRequest request = Optional.ofNullable(RequestContextHolder.getRequestAttributes())
                        .map(ServletRequestAttributes.class::cast)
                        .map(ServletRequestAttributes::getRequest)
                        .orElseThrow(() -> new IllegalStateException("只能在Web环境中获取Request对象!"));
                String clientIP = ServletUtil.getClientIP(request);
                methodKey = methodKey + ":" + clientIP;
            }
            // 基于用户维度
            else if (type == RateLimiterType.USER) {
                String userId = authentication.getPrincipal().toString();
                methodKey = methodKey + ":" + userId;
            }
            // 自定义,key不能为空
            else {
                Assert.hasText(key, "限流Key不能为空!");
            }

            limitConfig.setKey(methodKey);
        }
    }

    @Around("frontRateLimiter()")
    public Object executeFront(ProceedingJoinPoint joinPoint) throws Throwable {
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Method method = signature.getMethod();
        List<RateLimitConfig> limitConfigs = new ArrayList<>(8);
        List<RateLimitConfig> postLimitConfigs = new ArrayList<>(4);
        KeyRateLimiter keyRateLimiter = method.getAnnotation(KeyRateLimiter.class);
        KeyRateLimiters keyRateLimiters = method.getAnnotation(KeyRateLimiters.class);
        PostKeyRateLimiter postKeyRateLimiter = method.getAnnotation(PostKeyRateLimiter.class);
        PostKeyRateLimiters postKeyRateLimiters = method.getAnnotation(PostKeyRateLimiters.class);

        if (keyRateLimiter != null) {
            limitConfigs.add(new RateLimitConfig(keyRateLimiter));
        }
        if (keyRateLimiters != null && keyRateLimiters.value().length > 0) {
            Stream.of(keyRateLimiters.value()).map(RateLimitConfig::new).forEach(limitConfigs::add);
        }
        if (postKeyRateLimiter != null) {
            postLimitConfigs.add(new RateLimitConfig(postKeyRateLimiter));
        }
        if (postKeyRateLimiters != null && postKeyRateLimiters.value().length > 0) {
            Stream.of(postKeyRateLimiters.value()).map(RateLimitConfig::new).forEach(postLimitConfigs::add);
        }

        // 前置校验
        setKey("rt:front:", joinPoint, limitConfigs);
        Set<String> errMsgSet = validateFront(limitConfigs.toArray(new RateLimitConfig[]{}));
        if (!errMsgSet.isEmpty()) {
            // 此处应该抛出特定异常,通过全局异常拦截处理
            throw new BusinessException(errMsgSet.toString());
        }

        // 后置校验需要上锁
        if (!postLimitConfigs.isEmpty()) {
            // 设置Key
            setKey("rt:post:", joinPoint, postLimitConfigs);
            String key = "locks:" + postLimitConfigs.iterator().next().getKey();
            // 获取锁后执行
            return redisLock.executeWithLock(key, 10, 60, TimeUnit.SECONDS, () -> {
                SessionCallback<Number> callback = new SessionCallback<Number>() {
                    @Override
                    public Number execute(RedisOperations operations) throws DataAccessException {
                        ValueOperations kvValueOperations = operations.opsForValue();
                        for (RateLimitConfig postLimitConfig : postLimitConfigs) {
                            String key1 = postLimitConfig.getKey();
                            kvValueOperations.get(key1);
                        }
                        return null;
                    }
                };
                List<Object> objects = redisTemplate.executePipelined(callback);
                for (int i = 0; i < postLimitConfigs.size(); i++) {
                    Number val = (Number) objects.get(i);
                    RateLimitConfig rateLimitConfig = postLimitConfigs.get(i);
                    if (val != null && val.longValue() >= rateLimitConfig.getLimit()) {
                        errMsgSet.add(rateLimitConfig.getMessage());
                    }
                }

                if (!errMsgSet.isEmpty()) {
                    // 此处应该抛出特定异常,通过全局异常拦截处理
                    throw new BusinessException(errMsgSet.toString());
                }
                try {
                    // 执行业务方法
                    Object proceed = joinPoint.proceed();
                    // 扣减令牌
                    RateLimitConfig[] filterConfigs = postLimitConfigs.stream()
                            .filter(config -> parsePostSpEl(proceed, config))
                            .collect(Collectors.toList())
                            .toArray(new RateLimitConfig[]{});
                    validateFront(filterConfigs);
                    return proceed;
                } catch (BusinessException e) {
                    throw e;
                } catch (Throwable throwable) {
                    throw new BusinessException(throwable);
                }
            });
        }

        return joinPoint.proceed();
    }

    private boolean parsePostSpEl(Object val, RateLimitConfig limitConfig) {
        String condition = limitConfig.getCondition();
        if (StringUtils.isBlank(condition) || !condition.contains(EL_PREFIX)) {
            return true;
        }
        StandardEvaluationContext context = new StandardEvaluationContext();
        context.setVariable("rtv", val);
        Expression expression = expressionParser.parseExpression(condition);
        return Optional.ofNullable(expression.getValue(context, Boolean.class)).orElse(true);
    }
}

上面使用Redis pipline、Redis Lock,逻辑不难就不做细讲了,有疑问欢迎提问!

使用

@KeyRateLimiter(type = RateLimiterType.CLIENT_IP)
基于IP的前置限流

@KeyRateLimiter(type = RateLimiterType.USER)
基于用户维度的前置限流

@KeyRateLimiter(type = RateLimiterType.CUSTOM, key = "#username" interval="60" condtion="#rtv.code ==200")
自定义Key限流,并且通过返回值code==200才标记有效访问,进行限流
@Slf4j
@RestController
@Api(tags = "系统:系统授权接口")
@RequiredArgsConstructor
public class AuthController {

	@AnonymousAccess
    @ApiOperation("获取验证码")
    @GetMapping(value = "/code")
    @KeyRateLimiter(type = RateLimiterType.CLIENT_IP)
    @KeyRateLimiter(type = RateLimiterType.USER)
    @KeyRateLimiter(type = RateLimiterType.CUSTOM, key = "#username" interval="60" condtion="#rtv.code ==200")
    public XCloudResponse<Object> getCode(@RequestParam String username) {
        // 省略代码细节
    }

标签:String,Spring,多维度,keyRateLimiter,限流,key,new,RateLimitConfig
来源: https://blog.csdn.net/jjxxww12/article/details/113091238