Redis+Lua的限流方案

Lua 是一种轻量小巧的脚本语言,用标准C语言编写并以源代码形式开放, 其设计目的是为了嵌入应用程序中,从而为应用程序提供灵活的扩展和定制功能,Redis支持Lua脚本,所以通过Lua实现限流的算法。
Lua脚本实现算法对比操作Redis实现算法的优点:
减少网络开销:使用Lua脚本,无需向Redis 发送多次请求,执行一次即可,减少网络传输
原子操作:Redis 将整个Lua脚本作为一个命令执行,原子,无需担心并发
复用:Lua脚本一旦执行,会永久保存 Redis 中,其他客户端可复用

运行环境:Windows
Redis 5.0.14.1
Lua 5.X
SpringBoot 2.7.0

1.搭建SpringBoot项目,引入依赖

 <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-redis</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-aop</artifactId>
        </dependency>
        <dependency>
            <groupId>com.google.guava</groupId>
            <artifactId>guava</artifactId>
            <version>21.0</version>
        </dependency>
        <dependency>
            <groupId>org.apache.commons</groupId>
            <artifactId>commons-lang3</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
            <exclusions>
                <exclusion>
                    <groupId>org.junit.vintage</groupId>
                    <artifactId>junit-vintage-engine</artifactId>
                </exclusion>
            </exclusions>
        </dependency>
        <dependency>
            <groupId>org.springframework.data</groupId>
            <artifactId>spring-data-redis</artifactId>
            <version>2.7.2</version>
        </dependency>
    </dependencies>

2.项目整合Redis

application.yml

spring:  
    application:  
        name: redis_lua_limit  
    redis:  
        port: 6379  
        host: localhost

3.配置RedisTemplate

@Configuration
public class RedisConfig {
    
    @Bean
    public RedisTemplate<String, Serializable> limitRedisTemplate(
            LettuceConnectionFactory redisConnectionFactory) {
        RedisTemplate<String, Serializable> template = new RedisTemplate<>();
        template.setKeySerializer(new StringRedisSerializer());
        template.setValueSerializer(new GenericJackson2JsonRedisSerializer());
        template.setConnectionFactory(redisConnectionFactory);
        return template;
    }
}

4.限流类型枚举类

public enum LimitType {
    // 自定义key
    CUSTOMER,
    
    // 请求IP
    IP;
}

5.自定义@Limit注解

period表示请求限制时间段,count表示在period这个时间段内允许放行请求的次数。limitType代表限流的类型,可以根据请求的IP自定义key,如果不传limitType属性则默认用方法名作为默认key。

//表明注解可用于的地方  METHOD:方法上  TYPE:用于描述类、接口(包括注解类型) 或enum声明
@Target({ElementType.METHOD, ElementType.TYPE}) 
//存活阶段   runtime:运行期
@Retention(RetentionPolicy.RUNTIME)
//可继承
@Inherited
//作用域 javaDoc
@Documented
public @interface Limit {
 
    // key
    String key() default "";
 
    // 给定的时间范围
    int period();
 
    // 一定时间内最多访问次数
    int count();
 
    // 限流的类型  (自定义key或者请求ip)
    LimitType limitType() default LimitType.CUSTOMER;
 
}

6.定义切面类

@Aspect
@Configuration
public class LimitInterceptor {

    @Autowired
    private RedisTemplate<String, Serializable> redisTemplate;

   /**
    * @Author: Balla
    * @Description: ppt
    * @Date: 2023/11/15 15:53
   */
   @Around("execution(public * *(..)) && @annotation(limit)")
    public Object interceptor(ProceedingJoinPoint ppt,Limit limit) {

        // 获取方法对象
        MethodSignature signature = (MethodSignature) ppt.getSignature();
        Method method = signature.getMethod();

        // 获取@Limit注解对象
        Limit limitAnnotation = method.getAnnotation(Limit.class);

        // 获取key类型
        LimitType limitType = limitAnnotation.limitType();

        // 获取请求限制时间段、请求限制次数
        int limitPeriod = limitAnnotation.period();
        int limitCount = limitAnnotation.count();

        // 根据限流类型获取不同的key ,如果不传以方法名作为key
        String key;
        switch (limitType) {
            case IP:
                key = getIpAddress();
                break;
            case CUSTOMER:
                key = limitAnnotation.key();
                break;
            default:
                key = method.getName();
        }

        // 定义key参数
        List<String> keys = new ArrayList<String>();
        keys.add(key);

        try {
            // 获取Lua脚本内容
            String luaScript = buildLuaScript();

            // Reids整合Lua
            RedisScript<Long> redisScript = new DefaultRedisScript<>(
                    luaScript, Long.class);
            // 执行Lua,并返回key值
            Long count = redisTemplate.execute(redisScript, keys, limitCount,
                    limitPeriod);

            // 判断是否阻止请求
            if (count != null && count.intValue() <= limitCount) {
                return ppt.proceed();
            } else {
                throw new RuntimeException("please try again later");
            }
        } catch (Throwable e) {
            if (e instanceof RuntimeException) {
                throw new RuntimeException(e.getLocalizedMessage());
            }
            throw new RuntimeException("server error");
        }

    }

    /**
     * 编写 redis Lua 限流脚本
     */
    public String buildLuaScript() {
        StringBuilder lua = new StringBuilder();
        lua.append("local c");
        lua.append("\nc = redis.call('get',KEYS[1])");
        // 调用不超过最大值,则直接返回
        lua.append("\nif c and tonumber(c) > tonumber(ARGV[1]) then");
        lua.append("\nreturn c;");
        lua.append("\nend");
        // 执行计算器自加
        lua.append("\nc = redis.call('incr',KEYS[1])");
        lua.append("\nif tonumber(c) == 1 then");
        // 从第一次调用开始限流,设置对应键值的过期
        lua.append("\nredis.call('expire',KEYS[1],ARGV[2])");
        lua.append("\nend");
        lua.append("\nreturn c;");
        return lua.toString();
    }
    

    /**
     * 获取请求ip
     */
    public String getIpAddress() {
        HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder
                .getRequestAttributes()).getRequest();
        String ip = request.getHeader("x-forwarded-for");
        if (ip == null || ip.length() == 0) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0) {
            ip = request.getRemoteAddr();
        }
        return ip;
    }
}
  1. 获取 Lua 脚本内容: 通过调用 buildLuaScript() 方法,你获得了一个包含 Lua 脚本的字符串。
  2. 创建 RedisScript 对象: 使用 DefaultRedisScript 类,你创建了一个 RedisScript<Long> 对象。这里的泛型参数 Long 表示 Lua 脚本的返回类型。在你的 Lua 脚本中,它是一个表示计数器的整数。
  3. 执行 Lua 脚本: 使用 limitRedisTemplate.execute(redisScript, scriptKeys, limitCount, limitPeriod),你调用了 RedisTemplate 的 execute 方法,该方法用于执行 Redis 命令和 Lua 脚本。

    • redisScript 是表示 Lua 脚本的对象。
    • scriptKeys 是传递给 Lua 脚本的键。
    • limitCountlimitPeriod 是作为参数传递给 Lua 脚本的值。
  4. 处理 Lua 脚本的结果: Lua 脚本的执行结果会作为方法的返回值。在你的代码中,这个返回值是一个 Long 类型,表示 Lua 脚本的执行结果。这个值会在后续的代码中使用。

7.测试Controller

@RestController
public class LimitersController {

    private static final AtomicInteger ATOMIC_INTEGER_1 = new AtomicInteger();
    private static final AtomicInteger ATOMIC_INTEGER_2 = new AtomicInteger();
    private static final AtomicInteger ATOMIC_INTEGER_3 = new AtomicInteger();

    /**
     * @Author: Balla
     * @Description: 10秒内允许请求3次,key为方法名称
     * @Date: 2023/11/15 21:07
    */
    @Limit(key = "limitTest", period = 10, count = 3)
    @GetMapping("/limitTest1")
    public int testLimiter1() {

        return ATOMIC_INTEGER_1.incrementAndGet();
    }

    /**
     * @Author: Balla
     * @Description:10秒内允许请求3次,自定义key
     * @Date: 2023/11/15 21:07
    */
    @Limit(key = "customer_limit_test", period = 10, count = 3, limitType = LimitType.CUSTOMER)
    @GetMapping("/limitTest2")
    public int testLimiter2() {

        return ATOMIC_INTEGER_2.incrementAndGet();
    }

    /**
     * @Author: Balla
     * @Description:10秒内允许请求3次,key为请求ip
     * @Date: 2023/11/15 21:07
    */
    @Limit(key = "ip_limit_test", period = 10, count = 3, limitType = LimitType.IP)
    @GetMapping("/limitTest3")
    public int testLimiter3() {

        return ATOMIC_INTEGER_3.incrementAndGet();
    }

}
tag(s): none
show comments · back · home
Edit with Markdown