程序员Starlette 解读 by Gascognya

Starlette 源码阅读 (九) 认证

2020-08-17  本文已影响0人  Gascognya

authentication.py

用于权限认证的模块,提供了一个装饰器,和一些简单的模型
在阅读此模块期间有重大收货

认证装饰器

笔者以前一直搞不懂,fastapi的参数检测和依赖检测,究竟是如何实现的。
今天在该函数中初窥一角,但仍旧让我很惊喜

def requires(
    scopes: typing.Union[str, typing.Sequence[str]],
    status_code: int = 403,
    redirect: str = None,
) -> typing.Callable:
    # 一个三层装饰器
    scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)

    def decorator(func: typing.Callable) -> typing.Callable:
        type = None
        sig = inspect.signature(func)
        # 重点:该函数可以检测一个函数的参数信息,并返回
        # OrderedDict([('a', <Parameter "a: str">), ('b', <Parameter "b: str">)])
        # 这样便能实现抓取闭包函数所标注的参数类型。可以获得类对象形式的结果
        # 了解fastapi的应该知道,其具有检测参数typing的功能,进行强制约束
        # 其原理正是运用了这种方法
        for idx, parameter in enumerate(sig.parameters.values()):
            # 检测参数名
            if parameter.name == "request" or parameter.name == "websocket":
                type = parameter.name
                break
                # 是http请求还是ws请求
        else:
            raise Exception(
                f'No "request" or "websocket" argument on function "{func}"'
            )

        if type == "websocket":
            # 处理ws函数. (用于是异步)
            @functools.wraps(func)
            async def websocket_wrapper(
                *args: typing.Any, **kwargs: typing.Any
            ) -> None:
                websocket = kwargs.get("websocket", args[idx])
                assert isinstance(websocket, WebSocket)
                # 从参数中找到传过来的ws对象
                if not has_required_scope(websocket, scopes_list):
                    # 缺少必须项,则直接关闭链接
                    await websocket.close()
                else:
                    await func(*args, **kwargs)

            return websocket_wrapper
        # 下面与两个逻辑基本一致
        elif asyncio.iscoroutinefunction(func):
            # 处理 异步 request/response 函数.
            @functools.wraps(func)
            async def async_wrapper(
                *args: typing.Any, **kwargs: typing.Any
            ) -> Response:
                request = kwargs.get("request", args[idx])
                assert isinstance(request, Request)

                if not has_required_scope(request, scopes_list):
                    if redirect is not None:
                        return RedirectResponse(
                            url=request.url_for(redirect), status_code=303
                        )
                    raise HTTPException(status_code=status_code)
                return await func(*args, **kwargs)

            return async_wrapper

        else:
            # 处理 同步 request/response 函数.
            @functools.wraps(func)
            def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
                request = kwargs.get("request", args[idx])
                assert isinstance(request, Request)

                if not has_required_scope(request, scopes_list):
                    if redirect is not None:
                        return RedirectResponse(
                            url=request.url_for(redirect), status_code=303
                        )
                    raise HTTPException(status_code=status_code)
                return func(*args, **kwargs)

            return sync_wrapper
            # 可以发现除了同步和异步的func执行方式不同,其他完全一致。
            # 如果我写的话,可能会强迫症合并
            # 但合并后性能并不会提高
    return decorator

对比scope的函数

def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
    for scope in scopes:
        if scope not in conn.auth.scopes:
            return False
    # 设置的必须scopes中的所有项,在实际的ws或request中必须全部存在
    return True

官方提供了一些简单的模型

class AuthenticationError(Exception):
    pass


class AuthenticationBackend:
    async def authenticate(
        self, conn: HTTPConnection
    ) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
        raise NotImplementedError()  # pragma: no cover


class AuthCredentials:
    def __init__(self, scopes: typing.Sequence[str] = None):
        self.scopes = [] if scopes is None else list(scopes)


class BaseUser:
    @property
    def is_authenticated(self) -> bool:
        raise NotImplementedError()  # pragma: no cover

    @property
    def display_name(self) -> str:
        raise NotImplementedError()  # pragma: no cover

    @property
    def identity(self) -> str:
        raise NotImplementedError()  # pragma: no cover


class SimpleUser(BaseUser):
    def __init__(self, username: str) -> None:
        self.username = username

    @property
    def is_authenticated(self) -> bool:
        return True

    @property
    def display_name(self) -> str:
        return self.username


class UnauthenticatedUser(BaseUser):
    @property
    def is_authenticated(self) -> bool:
        return False

    @property
    def display_name(self) -> str:
        return ""

上一篇 下一篇

猜你喜欢

热点阅读