Python参数校验类实现

2018-08-24  本文已影响637人  大富帅

我们在开发接口或者后台提交数据的时候,经常要验证参数,如果把验证参数写成一个通用库调用,就可以减少很多重复工作了。
如果我们可以把所有的验证参数逻辑都交给统一的地方处理,就不用在业务逻辑上写太多dirty的代码,而且如果参数错误,直接给予一个错误的handler来处理,或返回error的json,或返回错误消息页面。

看下面例子,我们在view前面加个装饰器,专门封装处理参数的,把所有需要的参数,和参数的要求,包括类型,是否必须传过来、是否不能为空、默认值、参数名字、字符串长度,数值的大小max,min值, 等等等的属性。
还要加上错误处理的Handler,处理错误,假设都通过验证了,我们可以直接从safe_vars里面安全放心的获取参数,而且参数也已经经过处理成我们要的类型,比如传过来的是str的数字,会转成int的数值,这真的很省心。开发也会舒心很多

@form_check({
    "valid_time": F_str(u"有效期") & "required" & "strict",
    "title": 0 < F_str(u"标题") <= 13 & "required" & "strict",
    "content": 0 < F_str(u"内容") <= 40 & "required" & "strict",
    "msg_type": F_int(u"消息类型") & "required" & "strict",
    "msg_id": util.F_object_id(u"消息id") & "optional"
}, error_handler=functools.partial(JsonInterfaceResponse, code=cfg.ERROR_CODE.PARAM))
def create_message(safe_vars):
    title = safe_vars["title"]
    content = safe_vars["content"]
    ...

这里涉及到两个重点:

我们先来实现表单验证类, 我们先要有大概的构思, 首先验证不同的字段会有不同格式,不同验证规则,所以可以定义不同验证字段,我们可以为每一种不同的验证字段定义一个类,比如F_str(), F_int(),F_boolean(),F_float(),F_email(),F_mobile(), F_datetime(),F_file() 等等需要验证的字段。

我们想想这些类都有什么共性,他们必须有一个验证参数值的方法,我们定义为_check()方法。
我们可以为这些类先作出一个父类Input(),因为他们都是同一类的东西,所以可以把他们共同的东西抽象出来,写到父类。
这里涉及到面向对象编程的特性--接口,我们的验证字段考虑成可以根据以后开发的需要,自己扩展不同的验证类F_xx(),但是我们自己扩展的时候,就要遵循一些规则,比如,一定需要有一个_check()方法,用于检验参数的。然后如果验证成功,返回什么数据,错误,返回什么数据,都需要有一个明确的规则,这样才能按规范扩展。

所以我们定义的父类Input的时候,一定要定义一个抽象方法,让子类去继承实现。

_default_encoding = 'utf8'
_default_messages = { 
    'max': u'%(name)s 的最大值为%(max)s',
    'min': u'%(name)s 最小值为%(min)s',
    'max_len': u'%(name)s 最大长度为%(max_len)s个字符',
    'min_len': u'%(name)s 最小长度为%(min_len)s个字符',
    'blank': u'%(name)s 不能为空',
    'callback': u'%(name)s 输入出错了',
    'format': u'%(name)s 格式有错',
    'default': u'%(name)s 格式有错',
}


class Input(object):
    def __init__(self, field_name=None, default_value=None):
        # 初始化错误消息类型列表
        self._messages = {}
        self._messages.update(_default_messages)

        # 错误消息的参数, 比如 标题 不能为空
        self._message_vars = {
            "name": field_name
        }
        # 设置默认参数
        if defalut_value is not None:
            self._default_value = default_value

    # 抽象方法,接口,用于让子类继承,子类没有继承重写这个方法,就会报错
    def _check(self, value):
        raise NotImplementedError("you should use sub-class for checking")

    # 对外调用的验证方法
    def check(self, name, value):
    ''' 
        验证方法, 返回4个参数,
        1、是否验证成功,
        2、原输入的数据raw_data
        3、如果验证成功,就是处理后的数据valid_data
        4、如果验证失败,就是错误信息
        @param name: field name of the form
        @param value: field value
        @return: a tuple of 4 items:
            the first indicates if the value is valid
            the second is the raw data of user input
            the third is the valid data or, if not valid, is None
            the last is error message or, if valid, is None
    '''
        if self._message_vars['name'] is None:
            self._message_vars['name'] = name

        if self._multiple:
            return self.check_multi(value)
        else:
            return self.check_value(value)

    # 验证多参数方法
    def check_multi(value):
        pass

     # 验证单参数方法
    def check_value(value):
        pass

下面我们来分析一下这个Input类

下面看下check_value() 和 check_multi()的实现

      # 检查单参数
    def check_value(self, raw):
        valid, valid_data, message = False, None, None
        value = raw
        if value is None or (value == '' and self._strict):
            if self._default_value is not None:
                if callable(self._default_value):
                    return True, raw, self._default_value(), None
                return True, raw, self._default_value, None
            else:
                value = None

        # `None' means the client does not send the field's value,
        # this differs from empty string, although when `strict' is specified, they are the same.
        # if value:
        if value is not None:
            valid, data = self._check(value)
            if valid:
                valid, data = self._callbacks[0](data)
                if valid:
                    valid_data = data
                else:
                    message = self._messages['callback']
            else:
                message = data
                                                                                                                   
        elif self._optional:
            valid = True
            valid_data = value

        # value is empty and is not optional
        else:
            message = self._messages['blank']

        if message:
            message = message % self._message_vars

        return valid, raw, valid_data, message
   
     # 检查多参数
    def check_multi(self, values): 
        if not values and not self._optional: 
            message = self._messages.get('blank') or self._messages['default'] 
            message = message % self._message_vars 
            return False, values, None, message 
 
        valid_data = [] 
        for value in values: 
            valid, origin, valid_value, message = self.check_value(value) 
            if not valid: 
                return False, values, None, message 
            valid_data.append(valid_value) 
 
        return True, values, valid_data, None

下面我们来分析一下check_value()的方法

下面我们来看下 0 < F_str(u"标题") <= 13 & "required" & "strict" 这种形式的定义,是怎么实现参数长度限制,是否必须,是否严格限制传参的。这个就python反射机制有关了。

class Input(object):
    _type = None
    _min = None
    _max = None
    _optional = False
    _multiple = False
    _strict = False  # means empty string will treat as no input
    _attrs = ('optional', 'required', 'multiple', 'strict')

   ....

    def __and__(self, setting):
        if callable(setting):
            self._callbacks = (setting,)
        elif isinstance(setting, dict):
            self._messages.update(setting)
        elif setting in self._attrs:
            value = True
            if setting == 'required':
                setting = 'optional'
                value = False

            attr = '_%s' % setting
            if hasattr(self, attr):
                setattr(self, attr, value)
        else:
            raise NameError('%s is not support' % setting)

        return self

    def __lt__(self, max_value):
        raise NotImplementedError('"<" is not supported now')

    def __gt__(self, min_value):
        raise NotImplementedError('">" is not supported now')

    def __le__(self, max_value):
        self._max = max_value
        self._message_vars['max'] = max_value
        self._message_vars['max_len'] = max_value
        return self

    def __ge__(self, min_value):
        self._min = min_value
        self._message_vars['min'] = min_value
        self._message_vars['min_len'] = min_value
        return self

    def __eq__(self, value):
        self.__ge__(value)
        self.__le__(value)
        return self

    def _check_mm(self, value):
        if self._max is not None and value > self._max:
            self._message_key = 'max'
            return False
        if self._min is not None and value < self._min:
            self._message_key = 'min'
            return False
        return True

这些方法都是为了使对象能想其他对象一样可以做一些比较或者操作,比如iobj1 == obj2, obj1 > 200, obj1 & obj2这些的操作如果想使用在对象上,就必须定义对应符号的方法,比如== 这个判断,就会调用 __ne__()方法。obj1 >= 200 就会调用 __ge__()方法。

我们来分析一下代码:

好了,看完父类的定义, 它把所有设给子类的接口和对外的调用方法都定义好了,我们的子类,或者说我们要扩展一个自己的校验器,就可以有迹可循,有规范了。比如,看看我们定义的几个子类

class F_int(Input):
    _strict = True

    def _check(self, value):
        try:
            value = int(value)
        except ValueError:
            return False, self._messages['default']

        if not self._check_mm(value):
            message = self._messages[self._message_key]
            return False, message

        return True, value


class F_str(Input):

    def __init__(self, field_name=None,
                 default_value=None, format=None):
        super(F_str, self).__init__(field_name, default_value)
        self._format = format

        self._max = fstr_default_max
        self._message_vars['max'] = fstr_default_max
        self._message_vars['max_len'] = fstr_default_max

    def _check(self, value):
        if not self._check_mm(len(value)):
            message = self._messages[self._message_key + '_len']
            return False, message

        if self._format:
            import re
            if not re.match(self._format, value):
                return False, self._messages['format']
        return True, value

我们分析一下代码:

再举例看看其他的实现:

class F_datetime(Input):

    def __init__(self, field_name=None,
                 default_value=None, format='%Y-%m-%d %H:%M:%S'):
        super(F_datetime, self).__init__(field_name, default_value)
        self._format = format

    def _check(self, value):
        from datetime import datetime
        from time import strptime
        try:
            return True, datetime(*strptime(value, self._format)[0:6])
        except ValueError, e:
            return False, self._messages['format']

class F_email(F_str):

    def _check(self, value):
        if not self._check_mm(len(value)):
            message = self._messages[self._message_key + '_len']
            return False, message
        return self.is_email(value)

    def is_email(self, value):
        import re
        email = re.compile(r"^[\w.%+-]+@(?:[A-Z0-9-]+\.)+[A-Z]{2,4}$", re.I)
        if not email.match(value):
            return False, self._messages['default']
        return True, value

F_uuid = lambda: F_str(u"uuid") <= 128
F_engine = lambda: F_str(u"engine") <= 32
F_channel = lambda: F_str(u"channel") <= 32
F_package = lambda: F_str(u"package") <= 64

那么至此我们已经看完整个校验类的实现,在调用的时候就是这么执行

title_input = F_str("title") <= 50 & "required" & "strict" // 返回一个Input对象
is_valid, raw_value, valid_value, title_input.check_value("你好哇!") // 校验参数,然后返回4个值

可以看到 只需要两步:

class FormChecker(object):

    def __init__(self, source, input_form, method='both', err_msg_encoding=_default_encoding):
        """
        @param source: data source, can be object with properties GET, POST and VAR, or a dict-like object
        @param input_form: form setting
        @param method: GET or POST, or BOTH, if source is dict-like object, ignored
        @param err_msg_encoding: encoding of string of error message
        """
        self._source = source
        self._form = input_form
        self._checked = False
        self._method = method
        self._eme = err_msg_encoding
    def check(self, source=None):
        method = self._method.upper()

        if source is None:
            # if is a dict or some others with a `get' method
            if hasattr(self._source, 'get'):
                source = self._source
            # request object
            elif method == 'GET':
                source = self._source.GET
            elif method == 'POST':
                source = self._source.POST
            else:
                source = self._source.VAR

        form = self._form

        valid_data, raw_data,  messages = {}, {}, {}
        self._valid = True

        for field, checker in form.items():

            if checker.multiple and hasattr(source, 'getall'):
                value = source.getall(field)
            else:
                value = source.get(field, None)

            valid, raw_data[field], v, m = checker.check(field, value)
            if valid:
                valid_data[field] = v
            else:
                messages[field] = m

            self._valid = self._valid and valid

        self._raw_data = raw_data
        self._valid_data = valid_data
        self._messages = messages
        self._checked = True

        for field in self._messages:
            if self._messages[field]:
                self._messages[field] = self._messages[field].encode(self._eme)
            else:
                self._messages.pop(field)
 def is_valid(self):
        if not self._checked:
            self.check()
        return self._valid

    def get_error_messages(self):
        return self.err_msg

    @property
    def err_msg(self):
        if not self._checked:
            self.check()
        return self._messages

    def get_valid_data(self):
        return self.valid_data

    @property
    def valid_data(self):
        if not self._checked:
            self.check()
        return self._valid_data

    def get_raw_data(self):
        return self._raw_data

上一篇 下一篇

猜你喜欢

热点阅读