Springboot

dubbo rpc调用参数校验

2018-08-10  本文已影响207人  上重楼

使用spring的时候http调用参数校验还是很方便的,只是我们rpc用的比较多,然后就有了这个了。

其实实现很简单,实现一个dubbo的Filter 然后在这里根据反射获取参数的注解使用javax.validation的注解做各种操作就行了。要使用的项目需要加入自定义的Filter
类似:


image.png

加上paramValidationFilter 多个Filter用英文逗号分割。
加这个的意思是,这个项目的全局rpc调用Filter加入自定义的参数校验。 就算你的项目不是使用xml的配置方法,也可有其他类似的行为。需要注意多个Filter的优先级

我们的rpc都是返回一个固定的包装对象 ResultData<T> 所以会返回一个有错误状态码和错误信息的包装ResultData给调用方
如果返回类型不为ResultData会抛出我们自定义的异常

上实现代码:

package com.jx.common.dubbo.filter;

import com.alibaba.dubbo.common.extension.Activate;
import com.alibaba.dubbo.rpc.*;
import com.google.common.collect.Maps;
import com.jx.common.exception.BusinessException;
import com.jx.common.model.ResultData;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;

import javax.validation.ConstraintViolation;
import javax.validation.Valid;
import javax.validation.Validation;
import javax.validation.Validator;
import javax.validation.constraints.*;
import javax.validation.groups.Default;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * @author 周广
 **/
@Slf4j
@Activate(order = -100)
public class ParamValidationFilter implements Filter {
    private static Validator validator = Validation.buildDefaultValidatorFactory().getValidator();


    /**
     * 对象实体校验
     *
     * @param obj 待校验对象
     * @param <T> 待校验对象的泛型
     * @return 校验结果
     */
    private static <T> ValidationResult validateEntity(T obj) {
        Set<ConstraintViolation<T>> set = validator.validate(obj, Default.class);
        return getValidationResult(set);
    }

    /**
     * 将校验结果转换返回对象
     *
     * @param set 错误信息set
     * @param <T> 校验对象的泛型
     * @return 校验结果
     */
    private static <T> ValidationResult getValidationResult(Set<ConstraintViolation<T>> set) {
        ValidationResult result = new ValidationResult();
        if (set != null && !set.isEmpty()) {
            result.setHasErrors(true);
            Map<String, String> errorMsg = Maps.newHashMap();
            for (ConstraintViolation<T> violation : set) {
                errorMsg.put(violation.getPropertyPath().toString(), violation.getMessage());
            }
            result.setErrorMsg(errorMsg);
        }
        return result;
    }

    /**
     * 方法级别的参数验证 手撕o(╥﹏╥)o
     *
     * @param annotation 待验证的注解
     * @param param      待校验参数
     * @param paramName  参数名称
     * @return 校验结果
     */
    static ValidationResult validateMethod(Annotation annotation, Object param, String paramName) {
        ValidationResult result = new ValidationResult();
        Map<String, String> errorMsg = Maps.newHashMap();
        result.setErrorMsg(errorMsg);

        if (annotation instanceof DecimalMax) {

            if (param instanceof BigDecimal) {
                BigDecimal value = new BigDecimal(((DecimalMax) annotation).value());
                if (((BigDecimal) param).compareTo(value) > 0) {
                    result.setHasErrors(true);
                    if ("{javax.validation.constraints.DecimalMax.message}".equals(((DecimalMax) annotation).message())) {
                        errorMsg.put(paramName, param.getClass().getName() + " 值为:" + param + "大于" + value);
                    } else {
                        errorMsg.put(paramName, ((DecimalMax) annotation).message());
                    }
                }
            } else if (param instanceof BigInteger) {
                BigInteger value = new BigInteger(((DecimalMax) annotation).value());
                if (((BigInteger) param).compareTo(value) > 0) {
                    result.setHasErrors(true);
                    if ("{javax.validation.constraints.DecimalMax.message}".equals(((DecimalMax) annotation).message())) {
                        errorMsg.put(paramName, param.getClass().getName() + " 值为:" + param + "大于" + value);
                    } else {
                        errorMsg.put(paramName, ((DecimalMax) annotation).message());
                    }
                }
            } else {
                result.setHasErrors(true);
                errorMsg.put(paramName, "类型:" + param.getClass().getName() + "与注解:" + annotation.getClass().getName() + "不匹配");
            }

        } else if (annotation instanceof DecimalMin) {
            if (param instanceof BigDecimal) {
                BigDecimal value = new BigDecimal(((DecimalMin) annotation).value());
                if (((BigDecimal) param).compareTo(value) < 0) {
                    result.setHasErrors(true);
                    if ("{javax.validation.constraints.DecimalMin.message}".equals(((DecimalMin) annotation).message())) {
                        errorMsg.put(paramName, paramName + param.getClass().getName() + " 值为:" + param + "小于" + value);
                    } else {
                        errorMsg.put(paramName, ((DecimalMin) annotation).message());
                    }
                }
            } else if (param instanceof BigInteger) {
                BigInteger value = new BigInteger(((DecimalMin) annotation).value());
                if (((BigInteger) param).compareTo(value) < 0) {
                    result.setHasErrors(true);
                    if ("{javax.validation.constraints.DecimalMin.message}".equals(((DecimalMin) annotation).message())) {
                        errorMsg.put(paramName, param.getClass().getName() + " 值为:" + param + "小于" + value);
                    } else {
                        errorMsg.put(paramName, ((DecimalMin) annotation).message());
                    }
                }
            } else {
                result.setHasErrors(true);
                errorMsg.put(paramName, "类型:" + param.getClass().getName() + "与注解:" + annotation.getClass().getName() + "不匹配");
            }
        } else if (annotation instanceof Max) {
            long value = ((Max) annotation).value();
            if (Long.valueOf(param.toString()) > value) {
                result.setHasErrors(true);
                if ("{javax.validation.constraints.Max.message}".equals(((Max) annotation).message())) {
                    errorMsg.put(paramName, param.getClass().getName() + " 值为:" + param + "大于" + value);
                } else {
                    errorMsg.put(paramName, ((Max) annotation).message());
                }
            }
        } else if (annotation instanceof Min) {
            long value = ((Min) annotation).value();
            if (Long.valueOf(param.toString()) < value) {
                result.setHasErrors(true);
                if ("{javax.validation.constraints.Min.message}".equals(((Min) annotation).message())) {
                    errorMsg.put(paramName, param.getClass().getName() + " 值为:" + param + "小于" + value);
                } else {
                    errorMsg.put(paramName, ((Min) annotation).message());
                }
            }

        } else if (annotation instanceof NotNull) {
            if (param == null) {
                result.setHasErrors(true);
                if ("{javax.validation.constraints.NotNull.message}".equals(((NotNull) annotation).message())) {
                    errorMsg.put(paramName, "值为null");
                } else {
                    errorMsg.put(paramName, ((NotNull) annotation).message());
                }
            }
        } else if (annotation instanceof Size) {
            int value = Integer.valueOf(param.toString());
            if (value > ((Size) annotation).max()) {
                result.setHasErrors(true);
                if ("{javax.validation.constraints.Size.message}".equals(((Size) annotation).message())) {
                    errorMsg.put(paramName, param.getClass().getName() + " 值为:" + param + "大于" + ((Size) annotation).max());
                } else {
                    errorMsg.put(paramName, ((Size) annotation).message());
                }
            } else if (value < ((Size) annotation).min()) {
                result.setHasErrors(true);
                if ("{javax.validation.constraints.Size.message}".equals(((Size) annotation).message())) {
                    errorMsg.put(paramName, param.getClass().getName() + " 值为:" + param + "小于" + ((Size) annotation).min());
                } else {
                    errorMsg.put(paramName, ((Size) annotation).message());
                }
                errorMsg.put(paramName, param.getClass().getName() + " 值为:" + param + "小于" + ((Size) annotation).min());
            }
        }
        return result;
    }

    @Override
    public Result invoke(Invoker<?> invoker, Invocation invocation) throws RpcException {
        Method method = null;

        for (Method m : invoker.getInterface().getDeclaredMethods()) {
            if (m.getName().equals(invocation.getMethodName()) && invocation.getArguments().length == m.getParameterCount()) {
                Class[] invokerMethodParamClassList = invocation.getParameterTypes();
                Class[] matchMethodParamClassList = m.getParameterTypes();
                if (verifyClassMatch(invokerMethodParamClassList, matchMethodParamClassList)) {
                    method = m;
                    break;
                }
            }
        }

        //如果找不到对应的方法就跳过参数校验
        if (method == null) {
            return invoker.invoke(invocation);
        }


        //一个参数可以有多个注解
        Annotation[][] paramAnnotation = method.getParameterAnnotations();
        //参数的class
        Class<?>[] paramClass = invocation.getParameterTypes();
        Object[] paramList = invocation.getArguments();
        //获取参数名称
        List<String> paramNameList = Arrays.stream(method.getParameters()).map(Parameter::getName).collect(Collectors.toList());

        for (int i = 0; i < paramList.length; i++) {
            Object param = paramList[i];


            Annotation[] annotations = paramAnnotation[i];
            if (annotations.length == 0) {
                continue;
            }


            //循环注解处理,快速失败
            for (Annotation annotation : annotations) {
                if (isJavaxValidationAnnotation(annotation)) {
                    ValidationResult result;
                    try {
                        result = validateMethod(annotation, param, paramNameList.get(i));
                    } catch (Exception e) {
                        log.error("参数校验异常", e);
                        return invoker.invoke(invocation);
                    }
                    if (result.isHasErrors()) {
                        if (ResultData.class.equals(method.getReturnType())) {
                            return generateFailResult(result.getErrorMsg().toString());
                        } else {
                            throw new BusinessException(result.getErrorMsg().toString());
                        }
                    }
                } else if (annotation instanceof Valid) {
                    if (param == null) {
                        String errMsg = String.format("待校验对象:%s 不可为null", paramClass[i].getName());
                        if (ResultData.class.equals(method.getReturnType())) {
                            return generateFailResult(errMsg);
                        } else {
                            throw new BusinessException(errMsg);
                        }
                    }
                    ValidationResult result = validateEntity(param);
                    if (result.isHasErrors()) {
                        if (ResultData.class.equals(method.getReturnType())) {
                            return generateFailResult(result.getErrorMsg().toString());
                        } else {
                            throw new BusinessException(result.getErrorMsg().toString());
                        }
                    }
                }
            }
        }

        return invoker.invoke(invocation);
    }

    private boolean verifyClassMatch(Class[] invokerMethodParamClassList, Class[] matchMethodParamClassList) {
        for (int i = 0; i < invokerMethodParamClassList.length; i++) {
            if (!invokerMethodParamClassList[i].equals(matchMethodParamClassList[i])) {
                return false;
            }
        }
        return true;
    }

    /**
     * 构建失败返回对象
     *
     * @param errorMsg 错误信息
     * @return dubbo 返回对象
     */
    private Result generateFailResult(String errorMsg) {
        //请求参数非法
        return new RpcResult(new ResultData<>("0001", errorMsg));
    }

    /**
     * 判断是否为javax.validation.constraints的注解
     *
     * @param annotation 目标注解
     */
    private boolean isJavaxValidationAnnotation(Annotation annotation) {
        if (annotation instanceof AssertFalse) {
            return true;
        } else if (annotation instanceof AssertTrue) {
            return true;
        } else if (annotation instanceof DecimalMax) {
            return true;
        } else if (annotation instanceof DecimalMin) {
            return true;
        } else if (annotation instanceof Digits) {
            return true;
        } else if (annotation instanceof Future) {
            return true;
        } else if (annotation instanceof Max) {
            return true;
        } else if (annotation instanceof Min) {
            return true;
        } else if (annotation instanceof NotNull) {
            return true;
        } else if (annotation instanceof Null) {
            return true;
        } else if (annotation instanceof Past) {
            return true;
        } else if (annotation instanceof Pattern) {
            return true;
        } else if (annotation instanceof Size) {
            return true;
        }
        return false;
    }

    @Data
    static class ValidationResult {

        /**
         * 校验结果是否有错
         */
        private boolean hasErrors;

        /**
         * 校验错误信息
         */
        private Map<String, String> errorMsg;
    }
}

使用文档直接就从公司内部我写的wiki文档复制过来的, 送佛送到西
使用方法很简单,如果需要使用参数校验需要在接口的方法上面加上一些注解,注意是接口! 实现类加了没用


image.png

这里我们要区分2类参数

第一类:自定义对象 必须使用@Valid注解,标记这个对象需要进行校验
第二类:基本类型、包装类型、Decimal类型、String 等等一般都是java自带的类型

第二类可以使用:
第一个:javax.validation.constraints.DecimalMax
用于比较BigDecimal或者BigInteger
例子:

public void test1(@DecimalMax(value = "10", message = "最大为十") BigDecimal value)

意思是这个值最大为10 message是超出范围的时候的错误信息

第二个:

javax.validation.constraints.DecimalMin

和@DecimalMax 一样的使用方法,只是这个是最低 刚好相反

第三第四个:

javax.validation.constraints.Max

javax.validation.constraints.Min

例子:

public void test3(@Max(value = 10) int value) {
public void test4(@Min(value = 10,message = "最小为10") Integer value) 

这个可以用于比较多种整形数值类型

第五个:

javax.validation.constraints.NotNull

例子:

public void test5(@NotNull int value)

意思是校验的参数不可为null

第六个:

javax.validation.constraints.Size

例子:

public void test6(@Size(min = 10, max = 100,message = "最小10 最大100 不在10到100就报错") int value) 

设置范围的。

如上所述,虽然使用的是javax.validation.constraints下的注解,但是基本只支持注解的 value 和message 2个通用参数 别的都不支持! 因为我是手撕的校验 (validateMethod方法 逻辑可以自己调整)

接下来是第一类:自定义对象的校验方法


image.png

对数组、List使用的@NotEmpty 是org.hibernate.validator.constraints.NotEmpty 用于校验数组、List不为空

而且对象字段testClass2头上有2个注解

@NotNull
@Valid
NotNull用于标记这个对象自身不可为null, Valid用于标记这个对象内部还需要校验

所以testClass2List的2个注解也很容易理解了,就是标记List不可为null,且还要检查List里面每一个对象的字段

image.png

如图所示,如果是基本类型、包装类型、String、Decimal 等内建类型 可以使用如下注解

而且是完整的功能,除了上面说的value 和message2个通用参数,还可以使用其他所有javax.validation.constraints支持的功能。因为自定义对象的校验是使用javax.validation提供的校验

上一篇下一篇

猜你喜欢

热点阅读