FastAPI 解读 by Gascognya程序员

FastAPI 完整CBV实现

2020-10-17  本文已影响0人  Gascognya
10.23更新:装饰器做了增强

关于FastAPI的CBV(基于类的视图)实现,之前写过一篇文章,给出了一个简单的实现方法。
今天逛github发现有牛人造了个FastAPI的tools包。

https://github.com/dmontagu/fastapi-utils

其中有关于CBV的实现,作者对反射的理解让我非常佩服。我抽空在其代码基础上做了些简化;由于并非自己原创,不便提交回GitHub,故在此贴出。

首先,我添加了一个CBV专用的Router类,它可以被APIRouter(或FastAPI应用)通过include_router引入。
class CBVRouter(Router):
    """Starlette ``Router`` holding the routes of a single class-based view (CBV).

    Every route registered through :meth:`method` uses the router's ``path``;
    its tags, summary and ``operation_id`` are derived from ``group_name``.
    The router is mounted with ``include_router`` (e.g. ``app.include_router``).
    """

    # Function names accepted by ``method``; the name selects the HTTP method.
    _HTTP_METHODS = frozenset(
        {"get", "post", "put", "delete", "options", "head", "patch", "trace"}
    )

    def __init__(
            self,
            app: Optional[FastAPI] = None,
            path: Optional[str] = None,
            group_name: Optional[str] = None,
            tags: Optional[List[str]] = None,
            description: Optional[str] = None,
            summary: Optional[str] = None,
            routes: Optional[List[routing.BaseRoute]] = None,
            redirect_slashes: bool = True,
            default: Optional[ASGIApp] = None,
            dependency_overrides_provider: Optional[Any] = None,
            route_class: Type[APIRoute] = APIRoute,
            default_response_class: Optional[Type[Response]] = None,
            on_startup: Optional[Sequence[Callable]] = None,
            on_shutdown: Optional[Sequence[Callable]] = None,
    ) -> None:
        """Create the router for one CBV group.

        :param app: FastAPI app to mount on (active mode). It may be omitted in
            passive mode, where the module owning the app includes this router.
        :param path: path shared by every route of the group (required).
        :param group_name: name identifying this CBV's methods (required).
        :param tags: OpenAPI tags of the group; defaults to ``[group_name]``.
        :param description: description applied to every route of the group.
        :param summary: summary applied to every route unless ``method`` is given
            its own; otherwise defaults to ``"<group_name> _ <method>"``.
        :raises TypeError: if ``path`` or ``group_name`` is missing.
        """
        # ``app`` became optional for passive mode, which forced defaults on the
        # following parameters; keep ``path``/``group_name`` effectively required.
        if path is None or group_name is None:
            raise TypeError("CBVRouter 需要 path 与 group_name 参数")

        super().__init__(
            routes=routes,
            redirect_slashes=redirect_slashes,
            default=default,
            on_startup=on_startup,
            on_shutdown=on_shutdown,
        )
        self.dependency_overrides_provider = dependency_overrides_provider
        self.route_class = route_class
        self.default_response_class = default_response_class

        self.app = app
        self.path = path
        self.name = group_name
        self.tags = tags or [group_name]
        self.description = description
        self.summary = summary

    def method(
            self,
            response_model: Optional[Type[Any]] = None,
            status_code: int = 200,
            summary: Optional[str] = None,
            tags: Optional[List[str]] = None,
            response_description: str = "Successful Response",
            dependencies: Optional[Sequence[params.Depends]] = None,
            responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
            deprecated: Optional[bool] = None,
            response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
            response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
            response_model_by_alias: bool = True,
            response_model_exclude_unset: bool = False,
            response_model_exclude_defaults: bool = False,
            response_model_exclude_none: bool = False,
            include_in_schema: bool = True,
            response_class: Optional[Type[Response]] = None,
            name: Optional[str] = None,
            callbacks: Optional[List[APIRoute]] = None,
    ) -> Callable:
        """Return a decorator that registers a CBV method as a route on ``self.path``.

        The decorated function's ``__name__`` must be an HTTP method name
        (``get``, ``post``, ...); it becomes the route's HTTP method. All other
        arguments are forwarded to ``self.route_class``.

        :param summary: per-method summary; falls back to the router's ``summary``
            and then to ``"<group_name> _ <method>"``.
        :param tags: extra tags placed before the router's tags.
        :return: a decorator that appends the route and returns the function unchanged.
        :raises TypeError: (from the decorator) if the function has no ``__name__``.
        :raises ValueError: (from the decorator) if the name is not an HTTP method.
        """
        def decorator(func: Callable) -> Callable:
            method = getattr(func, "__name__", None)
            if not method:
                raise TypeError("装饰器使用方式错误")
            if method not in self._HTTP_METHODS:
                raise ValueError("请将方法名配置为' HTTP METHOD '中的一个")

            # Build a fresh list instead of extending ``tags`` in place: mutating
            # a shared default (or the caller's list) made tags accumulate.
            route_tags = list(tags or []) + list(self.tags)

            route = self.route_class(
                self.path,
                endpoint=func,
                response_model=response_model,
                status_code=status_code,

                tags=route_tags,
                description=self.description,
                methods=[method],
                operation_id=f"{self.name}_{self.path[1:]}_{method}",
                summary=summary or self.summary or f"{self.name} _ {method}",

                dependencies=dependencies,
                deprecated=deprecated,
                response_description=response_description,
                responses=responses or {},
                response_model_include=response_model_include,
                response_model_exclude=response_model_exclude,
                response_model_by_alias=response_model_by_alias,
                response_model_exclude_unset=response_model_exclude_unset,
                response_model_exclude_defaults=response_model_exclude_defaults,
                response_model_exclude_none=response_model_exclude_none,
                include_in_schema=include_in_schema,
                response_class=response_class or self.default_response_class,
                name=name,
                dependency_overrides_provider=self.dependency_overrides_provider,
                callbacks=callbacks,
            )
            self.routes.append(route)

            return func

        return decorator
# CBV style: the function name ("get") selects the HTTP method; the path comes from the router.
@router.method()
def get(self):
-----------------------------
# Equivalent plain FastAPI style:
@app.get(path)
def xxx():

即上述两种写法是等价的:在类中以HTTP方法名命名的方法上使用 @router.method(),相当于用 @app.get(path) 注册一个普通的路由函数。

接下来的部分是在原仓库代码基础上所做的少量修改:
T = TypeVar("T")


def API(item: object):
    """Class decorator that registers a CBV class's HTTP-method functions as routes.

    For flexibility there are two modes, so ``item`` can be two different things.

    Active mode: the ``CBVRouter`` is a class attribute built with the FastAPI
    app, and the router is included into that app right away::

        app = FastAPI()

        @API
        class TestClass:
            router = CBVRouter(app, path="/user", group_name="User")

    This is a good choice when this module is the application's entry module.

    Passive mode: the router is built without an app and exposed. The module
    that owns the app imports the router and includes it itself::

        router = CBVRouter(path="/user", group_name="User")

        @API(router)
        class TestClass:
            ...

    The main difference between the two modes is how they are written.

    :param item: a ``CBVRouter`` (passive mode) or the decorated class (active mode)
    :return: a class decorator (passive mode) or the decorated class (active mode)
    :raises TypeError: active mode, the class has no ``CBVRouter`` attribute
    :raises ValueError: active mode, the router has no ``app`` to mount on
    """
    if isinstance(item, CBVRouter):
        router = item

        def decorator(cls: Type[T]) -> Type[T]:
            _get_method(cls, router)
            return cls

        return decorator

    # -------- Active mode: the class must carry a router bound to an app --------
    cls = item
    router = None
    # If several CBVRouter attributes exist, the last one in definition order wins.
    for attr in vars(cls).values():
        if isinstance(attr, CBVRouter):
            router = attr
    # Explicit raises instead of ``assert``: asserts are stripped under ``python -O``.
    if router is None:
        raise TypeError("请配置一个Router到类属性router")
    app = getattr(router, "app", None)
    if app is None:
        raise ValueError("请指定要挂载的app")

    _get_method(cls, router)
    app.include_router(router)
    return cls
    
def _get_method(cls, router):
    """Bind ``router``'s routes to ``cls`` and drop routes that do not belong to it.

    Shared by both modes of ``API``. The steps are:

    * rewrite ``cls.__init__`` through ``update_cbv_class_init``;
    * keep only the ``Route``/``WebSocketRoute`` entries whose endpoint is a
      plain function found on ``cls``;
    * rewrite each kept endpoint's ``self`` parameter.

    ``router.routes`` is reassigned to the kept routes, in their original order.
    """
    # Turn annotated class attributes into constructor dependencies.
    update_cbv_class_init(cls)

    # Candidate endpoints: every plain function reachable on the class.
    class_functions = {member for _, member in inspect.getmembers(cls, inspect.isfunction)}

    kept_routes = []
    for route in router.routes:
        if not isinstance(route, (Route, WebSocketRoute)):
            continue
        if route.endpoint not in class_functions:
            continue
        _update_endpoint_self_param(cls, route)
        kept_routes.append(route)

    router.routes = kept_routes


def update_cbv_class_init(cls: Type[Any]) -> None:
    """
    Redefine ``cls.__init__`` so that annotated class attributes become injectable.

    Every type hint of ``cls`` that is not a ``ClassVar`` is appended to the
    class signature as a keyword-only parameter. Its default is the class
    attribute's value, or ``Ellipsis`` when the attribute is only annotated.

    The replacement ``__init__`` pops those keyword arguments and stores each
    one as an instance attribute. It then calls the original ``__init__`` with
    the remaining arguments.

    ``*args``/``**kwargs`` of the original ``__init__`` are left out of the
    published signature. The class is marked with ``__cbv_class__`` so a second
    call is a no-op.
    """
    marker = "__cbv_class__"
    if getattr(cls, marker, False):
        return  # Already initialized

    original_init: Callable[..., Any] = cls.__init__
    original_signature = inspect.signature(original_init)
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    # Skip ``self`` (the first parameter) and any variadic parameters.
    published: List[inspect.Parameter] = []
    for param in list(original_signature.parameters.values())[1:]:
        if param.kind not in variadic:
            published.append(param)

    injected_names: List[str] = []
    for attr_name, attr_hint in get_type_hints(cls).items():
        if is_classvar(attr_hint):
            continue  # ClassVar annotations stay plain class attributes
        injected_names.append(attr_name)
        published.append(
            inspect.Parameter(
                attr_name,
                inspect.Parameter.KEYWORD_ONLY,
                default=getattr(cls, attr_name, Ellipsis),
                annotation=attr_hint,
            )
        )

    def patched_init(self: Any, *args: Any, **kwargs: Any) -> None:
        # Each injected name must be present in kwargs (``pop`` without default).
        for name in injected_names:
            setattr(self, name, kwargs.pop(name))
        original_init(self, *args, **kwargs)

    cls.__signature__ = original_signature.replace(parameters=published)
    cls.__init__ = patched_init
    setattr(cls, marker, True)


def _update_endpoint_self_param(cls: Type[Any], route: Union[Route, WebSocketRoute]) -> None:
    """
    Give the endpoint's first parameter (``self``) the default ``Depends(cls)``.

    When the dependency is resolved, an instance of ``cls`` is therefore created
    and passed as ``self``. Every other parameter is made keyword-only. The new
    signature is stored on the endpoint function's ``__signature__``.
    """
    endpoint = route.endpoint
    signature = inspect.signature(endpoint)
    params: List[inspect.Parameter] = list(signature.parameters.values())

    # ``self`` now has a default, and a positional parameter without a default
    # may not follow one that has a default, so the rest become keyword-only.
    patched = [params[0].replace(default=Depends(cls))]
    patched.extend(p.replace(kind=inspect.Parameter.KEYWORD_ONLY) for p in params[1:])

    route.endpoint.__signature__ = signature.replace(parameters=patched)

使用示例

主动方式:
def dependency(num: int) -> int:
    """Example dependency: hand back the received ``num`` unchanged."""
    value = num
    return value

# FastAPI application that the active-mode example below mounts its routes on.
app = FastAPI()

# Active mode: ``@API`` looks up the CBVRouter class attribute, registers the
# HTTP-method functions on it and includes the router into ``router.app``.
@API
class TestClass:
    router = CBVRouter(app, path="/user", group_name="User")

    # Annotated, non-ClassVar attributes become keyword-only constructor parameters,
    # so FastAPI resolves ``Depends(dependency)`` when it builds the instance.
    x: int = Depends(dependency)
    # ClassVar annotations are skipped by the CBV machinery: a plain class attribute.
    cx: ClassVar[int] = 1
    # Annotated but never assigned, so instances have no ``cy`` attribute.
    cy: ClassVar[int]

    # Original __init__ parameters stay in the signature alongside the injected ones.
    def __init__(self, z: int = Depends(dependency)):
        self.y = 1
        self.z = z

    # The function name ``get`` selects the HTTP method; the path comes from the router.
    @router.method(response_model=int)
    def get(self) -> int:
        return self.cx + self.x + self.y + self.z

    # Returns False: ``cy`` is only annotated, never assigned.
    @router.method(response_model=bool)
    def post(self) -> bool:
        return hasattr(self, "cy")

    @router.method()
    def put(self):
        return {"msg": "put"}
    
    @router.method()
    def delete(self):
        return {"msg": "delete"}
被动方式:
# Passive mode: build and expose a router without an app; the module that owns the
# FastAPI app is expected to import this router and include it.
# NOTE(review): this call omits ``app``; confirm CBVRouter.__init__ does not require it.
router = CBVRouter(path="/user", group_name="User")

@API(router)
class TestClass:
    x: int = Depends(dependency)
    cx: ClassVar[int] = 1
    cy: ClassVar[int]
    # (remaining members elided in the article; same shape as the active-mode example)
    ......
运行成功(successful),能够正确发现依赖。
上一篇下一篇

猜你喜欢

热点阅读