数据库
首页 > 数据库 > 基于SpringDataRedis实现高性能的集群限流组件

基于SpringDataRedis实现高性能的集群限流组件

作者:互联网

特性:

1.预申请资源, 减少对redis的请求次数, 提升性能

2.预判失败, 防止在限流资源不足时高频访问redis, 提升性能

3.限流的最小时间窗口为1s, 大于1s的时间窗口会被换算为每秒限流数(限流额度 / 窗口秒数, 最小为1)来控制

 

基础依赖

 <dependency>
       <groupId>org.springframework.boot</groupId>
       <artifactId>spring-boot-starter-data-redis</artifactId>
       <version>2.x.x.RELEASE</version>
 </dependency>

 

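工具类源码 (其中 RedisUtil 文中未给出, 需自行提供; 其 getInstance() 应返回一个 RedisTemplate 或其他提供 execute(RedisScript, List, Object...) 方法的 RedisOperations 实现)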

RedisLimits

package com.idanchuang.component.redis.util;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.redis.core.script.DefaultRedisScript;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Cluster-wide rate limiting utility backed by Redis.
 *
 * <p>How {@link #require(String, int, int, int, TimeUnit)} decides a request:
 * <ol>
 *   <li>If a Redis acquisition for the same resource failed less than {@code bufferTime} ms ago,
 *       the request is rejected locally without contacting Redis.</li>
 *   <li>Otherwise permits are taken from this JVM's local pre-quota when enough are left.</li>
 *   <li>Otherwise permits are acquired from a Redis counter key that lives for one second. When
 *       the per-second pre-quota batch for the resource is larger than the request, a whole batch
 *       is acquired and the surplus is kept locally for later requests.</li>
 * </ol>
 * Time windows longer than one second are converted to an equivalent per-second limit
 * ({@code limit / windowSeconds}, at least 1); shorter windows are treated as one second.
 * If talking to Redis throws, the request is allowed (fail-open) and the error is logged.
 *
 * <p>Note: the local pre-quota never expires, so permits pre-acquired in one second may be
 * consumed in a later second.
 *
 * @author yjy
 * @date 2020/9/3 18:21
 **/
public class RedisLimits {

    private static final Logger log = LoggerFactory.getLogger(RedisLimits.class);

    /** Upper bound applied to every pre-quota factor. */
    private static final double MAX_FACTOR = 0.1D;

    /** Default pre-quota factor: by default 1% of the limit is pre-acquired (never above 0.1). */
    private static volatile double defaultFactor = 0.01D;

    /**
     * Fail-fast period in milliseconds: after a failed Redis acquisition, requests for the same
     * resource are rejected locally for this long without contacting Redis.
     */
    private static volatile long bufferTime = 10L;

    /** Prefix of the Redis counter keys. */
    private static final String REDIS_KEY_PREFIX = "limit:";

    /** Lifetime of a Redis counter key in milliseconds, i.e. the length of one limiting window. */
    private static final long WINDOW_TTL_MILLIS = 1000L;

    /** Atomic check-and-add script; built once instead of on every request. */
    private static final DefaultRedisScript<Boolean> REQUIRE_SCRIPT = buildRequireScript();

    /** Per-resource pre-quota factors; concurrent because setFactor may run on request threads. */
    private static final Map<String, Double> FACTOR_MAP = new ConcurrentHashMap<>();

    /** Local pre-quota per resource: permits already granted by Redis but not yet handed out. */
    private static final Map<String, AtomicInteger> PRE_REQUIRE = new ConcurrentHashMap<>();

    /** Most recent failed Redis acquisition per resource: resource name -> failure time (ms). */
    private static final Map<String, Long> LAST_FAILED = new ConcurrentHashMap<>();

    /**
     * Sets the default pre-quota factor used by resources without their own factor.
     *
     * @param defaultFactor fraction of the limit to pre-acquire; values above 0.1 are capped at 0.1
     */
    public static void setDefaultFactor(double defaultFactor) {
        RedisLimits.defaultFactor = Math.min(defaultFactor, MAX_FACTOR);
    }

    /**
     * Sets the pre-quota factor of a single resource.
     *
     * @param name resource name (must not be {@code null})
     * @param factor fraction of the limit to pre-acquire; values above 0.1 are capped at 0.1
     */
    public static void setFactor(String name, double factor) {
        FACTOR_MAP.put(name, Math.min(factor, MAX_FACTOR));
    }

    /**
     * Sets the fail-fast period used after a failed Redis acquisition.
     *
     * @param bufferTime period in milliseconds; values below 1 are raised to 1
     */
    public static void setBufferTime(long bufferTime) {
        RedisLimits.bufferTime = Math.max(bufferTime, 1L);
    }

    /** Requests 1 permit of a resource limited to {@code limit} permits per second. */
    public static boolean require(String name, int limit) {
        return require(name, limit, 1, 1, TimeUnit.SECONDS);
    }

    /** Requests {@code require} permits of a resource limited to {@code limit} permits per second. */
    public static boolean require(String name, int limit, int require) {
        return require(name, limit, require, 1, TimeUnit.SECONDS);
    }

    /** Requests 1 permit of a resource limited to {@code limit} permits per time window. */
    public static boolean require(String name, int limit, int timeWindow, TimeUnit timeUnit) {
        return require(name, limit, 1, timeWindow, timeUnit);
    }

    /**
     * Requests permits of a rate-limited resource.
     *
     * <p>The limit is enforced per second: a window of N seconds allows {@code limit / N}
     * (at least 1) permits per second, so a single request larger than that per-second share is
     * rejected.
     *
     * @param name resource name (must not be {@code null})
     * @param limit permits allowed per time window
     * @param require permits requested, default 1; values below 1 are always granted
     * @param timeWindow time window length, default 1
     * @param timeUnit unit of {@code timeWindow}, default seconds
     * @return {@code true} if the permits were granted (or Redis failed and the call fails open)
     */
    public static boolean require(String name, int limit, int require, int timeWindow, TimeUnit timeUnit) {
        if (require < 1) {
            return true;
        }
        if (require > limit) {
            return false;
        }
        // Fail fast while a recent Redis acquisition for this resource has failed.
        if (predictFailed(name)) {
            log.debug("require predictFailed, name: {}", name);
            return false;
        }
        // Serve from the local pre-quota when possible, avoiding a Redis round trip.
        if (requireLocal(name, require)) {
            log.debug("require requireLocal success, name: {}, require: {}", name, require);
            return true;
        }
        try {
            // Redis counters cover one second, so express the window as a per-second limit.
            long secondWindow = Math.max(timeUnit.toSeconds(timeWindow), 1L);
            int secondLimit = (int) Math.max(limit / secondWindow, 1L);
            // Size of one pre-quota batch per second.
            int preRequire = (int) (limit * FACTOR_MAP.getOrDefault(name, defaultFactor));
            int secondPreRequire = (int) (preRequire / secondWindow);
            // Batch-acquire only when the batch exceeds this request; keep the surplus locally.
            if (secondPreRequire > require && requireRedis(name, secondLimit, secondPreRequire)) {
                updateLocal(name, secondPreRequire - require);
                return true;
            }
            // Batch not applicable or not granted: ask Redis for exactly what was requested.
            return requireRedis(name, secondLimit, require);
        } catch (Exception e) {
            log.error("require error", e);
            // Fail open: let the request through when Redis cannot be consulted.
            return true;
        }
    }

    /**
     * Takes permits from the local pre-quota.
     *
     * @param name resource name
     * @param require permits requested
     * @return whether enough local permits were available and have been taken
     */
    private static boolean requireLocal(String name, int require) {
        AtomicInteger local = getLocal(name);
        // CAS loop: concurrent callers can never take more permits than are available locally.
        while (true) {
            int available = local.get();
            if (available < require) {
                return false;
            }
            if (local.compareAndSet(available, available - require)) {
                return true;
            }
        }
    }

    /**
     * Adds {@code delta} (may be negative) to the local pre-quota of a resource.
     *
     * @param name resource name
     * @param delta amount to add
     */
    private static void updateLocal(String name, int delta) {
        getLocal(name).addAndGet(delta);
    }

    /**
     * Returns the local pre-quota counter of a resource, atomically creating it at 0 if absent.
     *
     * @param name resource name
     * @return the shared counter for the resource
     */
    private static AtomicInteger getLocal(String name) {
        return PRE_REQUIRE.computeIfAbsent(name, key -> new AtomicInteger(0));
    }

    /**
     * Requests permits from Redis and records the outcome for fail-fast prediction.
     *
     * @param name resource name
     * @param limit permits allowed in the current one-second window
     * @param require permits requested
     * @return whether Redis granted the permits
     */
    private static boolean requireRedis(String name, int limit, int require) {
        log.debug("requireRedis name: {}, limit: {}, require: {}, timeWindow: {}, timeUnit: {}",
                name, limit, require, 1, TimeUnit.SECONDS);
        if (doRequire(name, limit, require)) {
            clearLastFailed(name);
            return true;
        }
        setLastFailed(name);
        return false;
    }

    /**
     * Runs the check-and-add script against the resource's counter key.
     *
     * @throws IllegalStateException if Redis returns no result; {@code require} turns this into
     *         a logged fail-open pass instead of an unboxing NPE
     */
    private static boolean doRequire(String name, int limit, int require) {
        String key = getKey(name);
        List<String> keys = new ArrayList<>(1);
        keys.add(key);
        Boolean granted = RedisUtil.getInstance().execute(REQUIRE_SCRIPT, keys,
                String.valueOf(require), String.valueOf(limit), String.valueOf(WINDOW_TTL_MILLIS));
        if (granted == null) {
            throw new IllegalStateException("Rate limit script returned no result for key: " + key);
        }
        return granted;
    }

    /**
     * Builds the Lua script: add ARGV[1] permits to KEYS[1] if the total stays within ARGV[2],
     * creating the key with a TTL of ARGV[3] ms when absent. The limit is also checked when the
     * key is created. Integer replies (1 granted / 0 rejected) are used because Redis converts a
     * Lua {@code false} into a nil reply.
     */
    private static DefaultRedisScript<Boolean> buildRequireScript() {
        String script = "local apply = tonumber(ARGV[1]) "
                + "local limit = tonumber(ARGV[2]) "
                + "local current = redis.call('GET', KEYS[1]) "
                + "if current then "
                + "if tonumber(current) + apply > limit then return 0 end "
                + "redis.call('INCRBY', KEYS[1], ARGV[1]) "
                + "return 1 "
                + "end "
                + "if apply > limit then return 0 end "
                + "redis.call('SET', KEYS[1], ARGV[1], 'PX', ARGV[3]) "
                + "return 1";
        DefaultRedisScript<Boolean> redisScript = new DefaultRedisScript<>();
        redisScript.setResultType(Boolean.class);
        redisScript.setScriptText(script);
        return redisScript;
    }

    /**
     * Predicts that a request will fail because a Redis acquisition failed very recently.
     *
     * @param name resource name
     * @return {@code true} while still within {@code bufferTime} of the last recorded failure
     */
    private static boolean predictFailed(String name) {
        long last = LAST_FAILED.getOrDefault(name, 0L);
        return System.currentTimeMillis() < last + bufferTime;
    }

    /**
     * Records "now" as the last failure time of a resource.
     *
     * @param name resource name
     */
    private static void setLastFailed(String name) {
        LAST_FAILED.put(name, System.currentTimeMillis());
    }

    /**
     * Forgets the last failure time of a resource.
     *
     * @param name resource name
     */
    private static void clearLastFailed(String name) {
        LAST_FAILED.remove(name);
    }

    /**
     * @param name resource name
     * @return the Redis counter key of the resource
     */
    private static String getKey(String name) {
        return REDIS_KEY_PREFIX + name;
    }

}

 

使用方式

// 申请1个资源 (限制1s资源总数为10)
if (RedisLimits.require("yjy", 10, 1, 1, TimeUnit.SECONDS)) {
    // 申请成功, 执行业务逻辑...
}

 

使用起来不够优雅? 想用注解? 请继续看

RedisLimit

package com.idanchuang.component.redis.annotation;

import java.lang.annotation.*;

/**
 * Marks a method as rate limited through {@code RedisLimits}.
 *
 * <p>Each invocation of the annotated method requests one permit of the resource. When the
 * permit is denied, the method body is not executed and a {@link RuntimeException} is thrown.
 *
 * @author yjy
 * @date 2020/9/3 18:05
 **/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
// NOTE: @Inherited only affects annotations on classes, so it has no effect here; kept for compatibility.
@Inherited
public @interface RedisLimit {

    /**
     * Permits allowed per {@link #timeWindow()} (this equals QPS only when the window is 1 second).
     * Windows longer than one second are enforced as {@code value / timeWindow} (at least 1)
     * permits per second.
     */
    int value();

    /** Resource name; when empty, {@code SimpleClassName:methodName} is used. */
    String name() default "";

    /** Time window in seconds. */
    int timeWindow() default 1;

    /**
     * SpEL expressions evaluated against the method arguments, e.g. {@code #appId}; their values
     * are appended to the resource name, so each distinct value gets its own limit.
     */
    String[] appendKeys() default {};

    /** Pre-quota factor: fraction of the limit pre-acquired per batch; values above 0.1 are capped at 0.1. */
    double factor() default 0.01D;

    /** Message for the exception thrown when the call is rejected; empty by default. */
    String errMessage() default "";

}
RedisLimit.java

RedisLimitAspect

package com.idanchuang.component.redis.aspect;

import com.idanchuang.component.redis.annotation.RedisLimit;
import com.idanchuang.component.redis.helper.BusinessKeyHelper;
import com.idanchuang.component.redis.util.RedisLimits;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;

import java.lang.reflect.Method;
import java.util.concurrent.TimeUnit;

/**
 * Aspect for methods with {@link RedisLimit} annotation.
 *
 * <p>For every call of an annotated method it builds the resource name, registers the
 * annotation's pre-quota factor with {@link RedisLimits}, and requests one permit for
 * {@code timeWindow} seconds. The method runs only if the permit is granted; otherwise a
 * {@link RuntimeException} is thrown.
 *
 * @author yjy
 */
@Aspect
@Component
public class RedisLimitAspect {

    /** Pointcut for executions of methods annotated with {@link RedisLimit}. */
    @Pointcut("@annotation(com.idanchuang.component.redis.annotation.RedisLimit)")
    public void redisLimitAnnotationPointcut() {
    }

    /**
     * Applies the rate limit around an annotated method.
     *
     * @param pjp the intercepted invocation
     * @return whatever the intercepted method returns
     * @throws RuntimeException if the call is rate limited; its message is {@code errMessage()},
     *         or a message naming the resource when {@code errMessage()} is empty
     * @throws Throwable anything thrown by the intercepted method itself
     */
    @Around("redisLimitAnnotationPointcut()")
    public Object invokeWithRedisLimit(ProceedingJoinPoint pjp) throws Throwable {
        // Resolve the target method once; it is reused for the default resource name.
        Method originMethod = resolveMethod(pjp);
        RedisLimit annotation = originMethod.getAnnotation(RedisLimit.class);
        if (annotation == null) {
            // Should not go through here.
            throw new IllegalStateException("Wrong state for RedisLimit annotation");
        }
        String sourceName = getName(annotation, originMethod, pjp);
        // Register the pre-quota factor for this resource.
        RedisLimits.setFactor(sourceName, annotation.factor());
        if (RedisLimits.require(sourceName, annotation.value(), annotation.timeWindow(), TimeUnit.SECONDS)) {
            return pjp.proceed();
        }
        throw new RuntimeException(rejectionMessage(annotation, sourceName));
    }

    /**
     * Chooses the message of the rejection exception.
     *
     * @param annotation annotation of the rejected method
     * @param sourceName resolved resource name
     * @return {@code errMessage()} when set, otherwise a message naming the resource
     */
    private String rejectionMessage(RedisLimit annotation, String sourceName) {
        String message = annotation.errMessage();
        if (StringUtils.hasLength(message)) {
            return message;
        }
        return "Rate limit exceeded for resource: " + sourceName;
    }

    /**
     * Builds the resource name.
     *
     * @param annotation annotation of the intercepted method
     * @param originMethod the resolved target method
     * @param pjp the intercepted invocation (used to evaluate {@code appendKeys})
     * @return {@code name()} (or {@code SimpleClassName:methodName} when empty) followed by the
     *         business-key suffix
     */
    private String getName(RedisLimit annotation, Method originMethod, ProceedingJoinPoint pjp) {
        String sourceName = annotation.name();
        if (!StringUtils.hasLength(sourceName)) {
            sourceName = originMethod.getDeclaringClass().getSimpleName() + ":" + originMethod.getName();
        }
        return sourceName + BusinessKeyHelper.getKeyName(pjp, annotation.appendKeys());
    }

    /**
     * Finds the method on the target object's class (or its superclasses) matching the signature.
     *
     * @param joinPoint the intercepted invocation
     * @return the resolved method
     * @throws IllegalStateException if no matching method is found
     */
    private Method resolveMethod(ProceedingJoinPoint joinPoint) {
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Class<?> targetClass = joinPoint.getTarget().getClass();

        Method method = getDeclaredMethodFor(targetClass, signature.getName(),
                signature.getMethod().getParameterTypes());
        if (method == null) {
            throw new IllegalStateException("Cannot resolve target method: " + signature.getMethod().getName());
        }
        return method;
    }

    /**
     * Get declared method with provided name and parameterTypes in given class and its super classes.
     * All parameters should be valid.
     *
     * @param clazz          class where the method is located
     * @param name           method name
     * @param parameterTypes method parameter type list
     * @return resolved method, null if not found
     */
    private Method getDeclaredMethodFor(Class<?> clazz, String name, Class<?>... parameterTypes) {
        try {
            return clazz.getDeclaredMethod(name, parameterTypes);
        } catch (NoSuchMethodException e) {
            Class<?> superClass = clazz.getSuperclass();
            if (superClass != null) {
                return getDeclaredMethodFor(superClass, name, parameterTypes);
            }
        }
        return null;
    }

}
RedisLimitAspect.java

BusinessKeyHelper

package com.idanchuang.component.redis.helper;

import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.context.expression.MethodBasedEvaluationContext;
import org.springframework.core.DefaultParameterNameDiscoverer;
import org.springframework.core.ParameterNameDiscoverer;
import org.springframework.expression.EvaluationContext;
import org.springframework.expression.ExpressionParser;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.util.StringUtils;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;

/**
* 获取用户定义业务key
* @author sxp
* @return
* @date 2020/7/3 10:54
*/
public class BusinessKeyHelper {

    private static ParameterNameDiscoverer nameDiscoverer = new DefaultParameterNameDiscoverer();

    private static ExpressionParser parser = new SpelExpressionParser();

    public static String getKeyName(ProceedingJoinPoint joinPoint, String[] keys) {
        if (keys == null || keys.length == 0) {
            return "";
        }
        Method method = getMethod(joinPoint);
        List<String> definitionKeys = getSpelDefinitionKey(keys, method, joinPoint.getArgs());
        List<String> keyList = new ArrayList<>(definitionKeys);
        return StringUtils.collectionToDelimitedString(keyList, "", "-", "");
    }

    private static Method getMethod(ProceedingJoinPoint joinPoint) {
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Method method = signature.getMethod();
        if (method.getDeclaringClass().isInterface()) {
            try {
                method = joinPoint.getTarget().getClass().getDeclaredMethod(signature.getName(),
                        method.getParameterTypes());
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
        return method;
    }

    private static List<String> getSpelDefinitionKey(String[] definitionKeys, Method method, Object[] parameterValues) {
        List<String> definitionKeyList = new ArrayList<>();
        for (String definitionKey : definitionKeys) {
            if (definitionKey != null && !definitionKey.isEmpty()) {
                EvaluationContext context = new MethodBasedEvaluationContext(null, method, parameterValues, nameDiscoverer);
                Object obj = parser.parseExpression(definitionKey).getValue(context);
                if (obj == null) {
                    definitionKeyList.add("");
                } else {
                    definitionKeyList.add(obj.toString());
                }
            }
        }
        return definitionKeyList;
    }

}
BusinessKeyHelper.java

 

添加上面三个类(并引入 spring-boot-starter-aop 依赖, 否则 RedisLimitAspect 中的 org.aspectj 类无法使用, 切面也不会生效)后, 我们就可以使用注解啦, 如下

/**
 * Sample bean showing {@link RedisLimit} usage: resource "source_name", 10 permits per
 * 1-second window, with a separate limit for each distinct {@code from} value.
 *
 * @author yjy
 * @date 2020/5/8 10:21
 **/
@Component
public class BusinessService {

    /** Number of calls that got past the rate limiter so far. */
    private static final AtomicInteger PASSED = new AtomicInteger(0);

    @RedisLimit(value = 10, name = "source_name", appendKeys = {"#from"}, timeWindow = 1, factor = 0.01D, errMessage = "惨被限流")
    public void doSomething(String from) {
        int passedSoFar = PASSED.incrementAndGet();
        System.out.println(from + " 通过: " + passedSoFar);
    }
}

 

好了, 基于Redis实现的集群限流工具就介绍到这里了, 下回见啦, 拜拜~

标签:return,name,require,限流,static,组件,import,SpringDataRedis,String
来源: https://www.cnblogs.com/imyjy/p/15808989.html