Commit 88aba3f9 by 袁伟铭

1.0.0

parent 1448fccb
......@@ -30,6 +30,9 @@ import java.lang.annotation.Target;
@Retention(RetentionPolicy.RUNTIME)
public @interface Limit {
// 限制类型
LimitType limitType() default LimitType.IP;
// 资源名称,用于描述接口功能
String name() default "";
......@@ -40,12 +43,30 @@ public @interface Limit {
String prefix() default "";
// 时间的,单位秒
int period();
int period() default 1;
// 限制访问次数
int count();
int count() default 3;
// 限制类型
LimitType limitType() default LimitType.CUSTOMER;
/**
* 对象里的属性名,仅当仅当{@link #limitType}为{@code RateLimitTypeEnum.POJO_FIELD}时有用
*
* @return
*/
String field() default "";
/**
* 要用来作为key组成的参数索引(从0开始), 该索引对应的参数必须为string/Long/Integer/Short/Byte, 仅当{@link #limitType}为{@code RateLimitTypeEnum.PARAM}时有用
*
* @return
*/
int keyParamIndex() default 0;
/**
* 达到限流上限时的错误提示
*
* @return
*/
String errMsg() default "操作过于频繁";
}
......@@ -15,14 +15,19 @@
*/
package com.zq.common.config.limit;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.zq.common.annotation.Limit;
import com.zq.common.config.redis.BaseCacheKeys;
import com.zq.common.context.ContextUtils;
import com.zq.common.http.HttpRequestUtils;
import com.zq.common.utils.AssertUtils;
import com.zq.common.vo.ApiTokenVo;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
......@@ -31,19 +36,16 @@ import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
/**
* @author wilmiam
* @since 2021-07-09 17:51
*/
@Slf4j
@Aspect
@Component
public class LimitAspect {
......@@ -59,54 +61,25 @@ public class LimitAspect {
public void pointcut() {
}
@Around("pointcut()")
public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
HttpServletRequest request = ((ServletRequestAttributes) Objects.requireNonNull(RequestContextHolder.getRequestAttributes())).getRequest();
@Before("pointcut()")
public void limitBeforeExecute(JoinPoint joinPoint) {
MethodSignature signature = (MethodSignature) joinPoint.getSignature();
Method signatureMethod = signature.getMethod();
Limit limit = signatureMethod.getAnnotation(Limit.class);
LimitType limitType = limit.limitType();
String key = limit.key();
// 构建key
if (StringUtils.isBlank(key)) {
if (limitType == LimitType.IP) {
key = HttpRequestUtils.getClientIp(request);
} else if (limitType == LimitType.USER_ID) {
Long userId = ContextUtils.getUserUserId();
key = userId != null ? userId.toString() : HttpRequestUtils.getClientIp(request);
} else {
// 获取方法名
key = signatureMethod.getName();
}
}
key = StringUtils.join(limit.prefix(), "_", key, "_", signatureMethod.getName());
/*String obj = stringRedisTemplate.opsForValue().get(key);
int currentLimit = 0;
if (obj != null) {
currentLimit = Integer.parseInt(obj);
}
if (currentLimit + 1 > limit.count()) {
throw new BusinessException("访问次数受限制");
}
stringRedisTemplate.opsForValue().set(key, (currentLimit + 1) + "");
stringRedisTemplate.expire(key, limit.period(), TimeUnit.SECONDS);
logger.info("第{}次访问key为 {},描述为 [{}] 的接口", currentLimit + 1, key, limit.name());
return joinPoint.proceed();*/
String key = buildLimitKey(limit, joinPoint, signatureMethod);
log.debug("限流缓存KEY: {}", key);
List<String> keys = Collections.singletonList(key);
String luaScript = buildLuaScript();
RedisScript<Long> redisScript = new DefaultRedisScript<>(luaScript, Long.class);
Long count = stringRedisTemplate.execute(redisScript, keys, String.valueOf(limit.count()), String.valueOf(limit.period()));
AssertUtils.isTrue(count != null && count != 0, "访问次数受限制");
AssertUtils.isTrue(count != null && count != 0, limit.errMsg());
logger.info("第{}次访问key为 {},描述为 [{}] 的接口", count, keys, limit.name());
return joinPoint.proceed();
String name = limit.name();
name = StringUtils.isNotBlank(name) ? name : signatureMethod.getName();
logger.info("第{}次访问key为 {},描述为 [{}] 的接口", count, keys, name);
}
/**
......@@ -135,4 +108,78 @@ public class LimitAspect {
"\nend";
}
/**
* 构建key
*
* @param limit
* @param joinPoint
* @return
*/
private String buildLimitKey(Limit limit, JoinPoint joinPoint, Method signatureMethod) {
String limitKey = null;
String key = limit.key();
key = StringUtils.isBlank(key) ? signatureMethod.getName() : key;
LimitType limitType = limit.limitType();
Object[] args = joinPoint.getArgs();
switch (limitType) {
case IP:
limitKey = BaseCacheKeys.rateLimitKey(LimitType.IP, key, HttpRequestUtils.getClientIp());
break;
case USER:
// 按用户登录id限流
ApiTokenVo apiTokenVo = ContextUtils.getUserContext();
if (apiTokenVo != null) {
limitKey = BaseCacheKeys.rateLimitKey(LimitType.USER, key, String.valueOf(apiTokenVo.getUserId()));
} else {
log.warn(">> 未找到登录用户信息,限流失败: {}", joinPoint);
}
break;
case POJO_FIELD:
String field = limit.field();
if (StringUtils.isBlank(field)) {
log.warn(">> 未设置field,限流失败: {}", joinPoint);
break;
}
if (args == null || args.length == 0 || args[0] == null) {
log.warn(">> 未找到对象,限流失败: {}", joinPoint);
break;
}
String fieldValue = getPojoField(field, args[0]);
limitKey = BaseCacheKeys.rateLimitKey(LimitType.POJO_FIELD, key, fieldValue);
break;
case PARAM:
int keyIndex = limit.keyParamIndex();
if (keyIndex < 0 || args == null || args.length < (keyIndex + 1) || args[keyIndex] == null) {
log.warn(">> 未找到参数或参数值为空,限流失败: {}, keyParamIndex={}", joinPoint, keyIndex);
} else if (isValidKeyParamType(args[keyIndex])) {
limitKey = BaseCacheKeys.rateLimitKey(LimitType.PARAM, key, String.valueOf(args[keyIndex]));
} else {
log.warn(">> 设置的参数不是string/long/int/short/byte类型,限流失败: {}", joinPoint);
}
break;
case KEY:
limitKey = BaseCacheKeys.rateLimitKey(LimitType.KEY, key);
break;
default:
// nothing to do
}
return limitKey;
}
private boolean isValidKeyParamType(Object param) {
return (param instanceof String) || (param instanceof Long) || (param instanceof Integer) || (param instanceof Short)
|| (param instanceof Byte);
}
private String getPojoField(String field, Object pojo) {
try {
JSONObject object = JSON.parseObject(JSON.toJSONString(pojo));
return object.getString(field);
} catch (Exception e) {
return null;
}
}
}
......@@ -22,10 +22,24 @@ package com.zq.common.config.limit;
* @since 2021-07-09 17:51
*/
public enum LimitType {
// 默认
CUSTOMER,
// by ip USER_ID
USER_ID,
// by ip address
IP
/**
* 针对每个IP进行限流
*/
IP,
/**
* 针对每个用户进行限流
*/
USER,
/**
* 针对对象的某个属性值进行限流
*/
POJO_FIELD,
/**
* 针对某个参数进行限流
*/
PARAM,
/**
* 直接对指定的key进行限流
*/
KEY
}
package com.zq.common.config.redis;
import com.zq.common.config.limit.LimitType;
import org.apache.commons.lang3.StringUtils;
/**
* 公共缓存key
*
......@@ -14,6 +17,8 @@ public abstract class BaseCacheKeys {
private static final String ADMIN_TOKEN = PREFIX + "admin-token.";
private static final String RATE_LIMIT = PREFIX + "rate-limit.";
/**
* 构建app端用户token的缓存key
*
......@@ -34,4 +39,51 @@ public abstract class BaseCacheKeys {
return ADMIN_TOKEN + token;
}
/**
* 构建限流Key
*
* @param type
* @param key
* @return
*/
public static String rateLimitKey(LimitType type, String key) {
return rateLimitKey(type, key, null);
}
/**
* 构建限流key
*
* @param type
* @param key
* @param param
* @return
*/
public static String rateLimitKey(LimitType type, String key, String param) {
String result = RATE_LIMIT;
switch (type) {
case IP:
result += "ip.";
break;
case USER:
result += "u.";
break;
case PARAM:
result += "p.";
break;
case POJO_FIELD:
result += "f.";
break;
case KEY:
result += "k.";
break;
default:
// nothing to do
}
result += key;
if (StringUtils.isNotBlank(param)) {
result += "." + param;
}
return result;
}
}
package com.zq.user.controller.app;
import com.zq.common.annotation.Limit;
import com.zq.common.utils.AssertUtils;
import com.zq.common.utils.ValidateUtil;
import com.zq.common.vo.ApiTokenVo;
......@@ -70,6 +71,7 @@ public class UserController {
return ResultVo.success(userService.passwdLogin(vo));
}
@Limit(count = 1)
@ApiOperation("获取用户信息")
@GetMapping(value = "/getUserInfo")
public ResultVo getUserInfo(@RequestParam String userId) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment