Starlette 解读 by Gascognya

Starlette 源码阅读 (四) request

2020-08-11  本文已影响0人  Gascognya

本篇开始解读requests.py

requests.py
包含三个类,三个函数。
ClientDisconnect继承于Exception,内容为pass

HTTPConnection类

class HTTPConnection(Mapping):
    """
    Base class for an incoming HTTP connection.

    It carries everything that ``Request`` and ``WebSocket`` have in common.
    The instance is a read-only mapping view over the ASGI ``scope`` and
    exposes lazily computed, cached accessors (URL, headers, query
    parameters, cookies, state) on top of it.
    """

    def __init__(self, scope: Scope, receive: Receive = None) -> None:
        # ``receive`` is accepted for signature compatibility; this base
        # class itself never reads it.
        assert scope["type"] in ("http", "websocket")
        self.scope = scope

    # -- Mapping protocol: delegate straight to the scope ------------------

    def __getitem__(self, key: str) -> str:
        """Return the scope entry stored under ``key``."""
        return self.scope[key]

    def __iter__(self) -> typing.Iterator[str]:
        """Iterate over the keys of the scope."""
        return iter(self.scope)

    def __len__(self) -> int:
        """Return the number of entries in the scope."""
        return len(self.scope)

    # -- Accessors ---------------------------------------------------------

    @property
    def app(self) -> typing.Any:
        """The object stored under ``scope["app"]``."""
        return self.scope["app"]

    @property
    def url(self) -> URL:
        """``URL`` built from the scope; created once, then cached."""
        if hasattr(self, "_url"):
            return self._url
        self._url = URL(scope=self.scope)
        return self._url

    @property
    def base_url(self) -> URL:
        """
        ``URL`` built from a copy of the scope with the path reset to "/"
        and the query string emptied; created once, then cached.
        """
        if hasattr(self, "_base_url"):
            return self._base_url
        # Work on a shallow copy so the real scope is left untouched.
        root_scope = dict(self.scope)
        # Prefer "app_root_path", fall back to "root_path", then to "".
        root_path = root_scope.get(
            "app_root_path", root_scope.get("root_path", "")
        )
        root_scope.update(path="/", query_string=b"", root_path=root_path)
        self._base_url = URL(scope=root_scope)
        return self._base_url

    @property
    def headers(self) -> Headers:
        """``Headers`` built from the scope; created once, then cached."""
        if hasattr(self, "_headers"):
            return self._headers
        self._headers = Headers(scope=self.scope)
        return self._headers

    @property
    def query_params(self) -> QueryParams:
        """``QueryParams`` built from ``scope["query_string"]``; cached."""
        if hasattr(self, "_query_params"):
            return self._query_params
        self._query_params = QueryParams(self.scope["query_string"])
        return self._query_params

    @property
    def path_params(self) -> dict:
        """``scope["path_params"]``, or an empty dict when it is absent."""
        return self.scope.get("path_params", {})

    @property
    def cookies(self) -> typing.Dict[str, str]:
        """
        Cookies parsed from the "cookie" header; parsed once, then cached.

        A missing or empty header yields an empty dict.
        """
        if not hasattr(self, "_cookies"):
            raw_cookie = self.headers.get("cookie")
            self._cookies = cookie_parser(raw_cookie) if raw_cookie else {}
        return self._cookies

    @property
    def client(self) -> Address:
        """
        ``Address`` built from ``scope["client"]``.

        Host and port are both ``None`` when the scope has no (or a falsy)
        "client" entry.
        """
        peer = self.scope.get("client")
        if not peer:
            peer = (None, None)
        host, port = peer
        return Address(host=host, port=port)

    @property
    def session(self) -> dict:
        """``scope["session"]``; asserts that the key is present."""
        assert (
            "session" in self.scope
        ), "SessionMiddleware must be installed to access request.session"
        return self.scope["session"]

    @property
    def auth(self) -> typing.Any:
        """``scope["auth"]``; asserts that the key is present."""
        assert (
            "auth" in self.scope
        ), "AuthenticationMiddleware must be installed to access request.auth"
        return self.scope["auth"]

    @property
    def user(self) -> typing.Any:
        """``scope["user"]``; asserts that the key is present."""
        assert (
            "user" in self.scope
        ), "AuthenticationMiddleware must be installed to access request.user"
        return self.scope["user"]

    @property
    def state(self) -> State:
        """
        ``State`` wrapping the dict stored under ``scope["state"]``.

        The dict is created in the scope if it does not exist yet, so the
        ``State`` instance holds a reference to the very dict that lives in
        the scope. Created once, then cached.
        """
        if not hasattr(self, "_state"):
            backing = self.scope.setdefault("state", {})
            self._state = State(backing)
        return self._state

    def url_for(self, name: str, **path_params: typing.Any) -> str:
        """
        Resolve the route called ``name`` through ``scope["router"]`` and
        make the resulting path absolute against ``base_url``.
        """
        return (
            self.scope["router"]
            .url_path_for(name, **path_params)
            .make_absolute_url(base_url=self.base_url)
        )

两个用于触发异常的函数

async def empty_receive() -> Message:
    """Placeholder receive channel: awaiting it always raises RuntimeError."""
    reason = "Receive channel has not been made available"
    raise RuntimeError(reason)

async def empty_send(message: Message) -> None:
    """Placeholder send channel: awaiting it always raises RuntimeError."""
    reason = "Send channel has not been made available"
    raise RuntimeError(reason)

Request类

class Request(HTTPConnection):
    """An ``HTTPConnection`` restricted to scopes of type "http"."""

    def __init__(
        self,
        scope: Scope,
        receive: Receive = empty_receive,
        send: Send = empty_send,
    ):
        """
        Store the scope and the receive/send channels.

        When no channel is supplied, the ``empty_receive`` / ``empty_send``
        placeholders are used as defaults.
        """
        # Let the parent class run its own initialisation with the scope.
        super().__init__(scope)
        # A Request additionally requires a plain HTTP scope.
        assert scope["type"] == "http"
        # Bookkeeping flags: body stream not read yet, client still connected.
        self._stream_consumed = False
        self._is_disconnected = False
        self._receive = receive
        self._send = send

    @property
    def method(self) -> str:
        """The value stored under ``scope["method"]``."""
        return self.scope["method"]

    @property
    def receive(self) -> Receive:
        """The receive channel this request was constructed with."""
        return self._receive

stream、body、json 三个方法逐层调用:json 调用 body,body 调用 stream。

其中 json 用于将请求体(body)按 JSON 格式反序列化为 Python 对象。

    async def stream(self) -> typing.AsyncGenerator[bytes, None]:
        """
        Yield the request body chunk by chunk, followed by a final ``b""``.

        If the full body has already been read and cached on ``_body``, that
        cached body is replayed instead of touching the receive channel.

        Raises ``RuntimeError`` when the stream was already consumed, and
        ``ClientDisconnect`` when an "http.disconnect" message arrives.
        """
        if hasattr(self, "_body"):
            yield self._body
            yield b""
            return

        if self._stream_consumed:
            raise RuntimeError("Stream consumed")
        # Mark the stream as used before the first message is awaited.
        self._stream_consumed = True

        expecting_more = True
        while expecting_more:
            event = await self._receive()
            kind = event["type"]
            if kind == "http.disconnect":
                # Remember the disconnect, then abort the read.
                self._is_disconnected = True
                raise ClientDisconnect()
            if kind == "http.request":
                payload = event.get("body", b"")
                if payload:
                    # Empty chunks are skipped; only real data is passed on.
                    yield payload
                # Keep looping only while the message announces more body.
                expecting_more = event.get("more_body", False)
        yield b""

    async def body(self) -> bytes:
        """
        Return the whole request body as bytes.

        The first call drains ``stream()`` and caches the joined result on
        ``_body``; later calls return the cached bytes.
        """
        if hasattr(self, "_body"):
            return self._body
        pieces = [piece async for piece in self.stream()]
        self._body = b"".join(pieces)
        return self._body

    async def json(self) -> typing.Any:
        """
        Return the request body decoded with ``json.loads``.

        The decoded value is cached on ``_json``. Call chain:
        ``json()`` -> ``body()`` -> ``stream()``.
        """
        if hasattr(self, "_json"):
            return self._json
        raw = await self.body()
        # Deserialize the raw body into Python objects.
        self._json = json.loads(raw)
        return self._json
    async def form(self) -> FormData:
        """
        Return the request body parsed as form data, cached on ``_form``.

        The parser is picked from the Content-Type header:
        "multipart/form-data" uses ``MultiPartParser``,
        "application/x-www-form-urlencoded" uses ``FormParser``, and any
        other content type produces an empty ``FormData``.
        """
        if hasattr(self, "_form"):
            return self._form

        assert (
            parse_options_header is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        # Only the media type is needed; the header options are ignored.
        media_type, _ = parse_options_header(self.headers.get("Content-Type"))

        if media_type == b"multipart/form-data":
            parser = MultiPartParser(self.headers, self.stream())
            self._form = await parser.parse()
        elif media_type == b"application/x-www-form-urlencoded":
            parser = FormParser(self.headers, self.stream())
            self._form = await parser.parse()
        else:
            # Neither supported form encoding matched.
            self._form = FormData()
        return self._form

    async def close(self) -> None:
        """Close the parsed form data, if ``form()`` has produced any."""
        if not hasattr(self, "_form"):
            return
        # NOTE(review): presumably this releases resources held by uploaded
        # files -- cannot be confirmed from this file alone.
        await self._form.close()

    async def is_disconnected(self) -> bool:
        """
        Report whether the client has disconnected.

        If no disconnect has been recorded yet, the receive channel is polled
        once with a tiny timeout; an "http.disconnect" message sets the flag.
        Any other message received by this poll is not stored by this method.
        """
        if self._is_disconnected:
            return self._is_disconnected

        try:
            # Wait only a negligible amount of time for a pending message.
            event = await asyncio.wait_for(self._receive(), timeout=0.0000001)
        except asyncio.TimeoutError:
            # Nothing arrived in time: treat it as "no message".
            event = {}

        if event.get("type") == "http.disconnect":
            self._is_disconnected = True

        return self._is_disconnected

    async def send_push_promise(self, path: str) -> None:
        """
        Send an "http.response.push" message for ``path``.

        Nothing is sent unless the scope's "extensions" entry contains
        "http.response.push". The message carries every value of the request
        headers named in ``SERVER_PUSH_HEADERS_TO_COPY``, with names and
        values encoded as latin-1 bytes.
        """
        if "http.response.push" not in self.scope.get("extensions", {}):
            return

        # One (name, value) pair per header value, for each header to copy.
        copied_headers = [
            (header.encode("latin-1"), item.encode("latin-1"))
            for header in SERVER_PUSH_HEADERS_TO_COPY
            for item in self.headers.getlist(header)
        ]
        push_message = {
            "type": "http.response.push",
            "path": path,
            "headers": copied_headers,
        }
        await self._send(push_message)

requests.py→SERVER_PUSH_HEADERS_TO_COPY

# Names of the request headers whose values Request.send_push_promise copies
# into the "http.response.push" message it sends.
SERVER_PUSH_HEADERS_TO_COPY = {
    "accept",
    "accept-encoding",
    "accept-language",
    "cache-control",
    "user-agent",
}

requests.py→cookie_parser

def cookie_parser(cookie_string: str) -> typing.Dict[str, str]:
    """
    Parse a ``Cookie`` header value into a dict of name/value pairs.

    The parsing tries to mimic what browsers do: browsers and web servers
    frequently disregard RFC 6265 when setting and reading cookies, so the
    common real-world cases are accommodated here.

    Adapted from Django 3.1.0.
    Note: ``SimpleCookie.load`` is deliberately not used, because it follows
    an outdated spec and fails on many inputs that should be supported.
    """
    parsed: typing.Dict[str, str] = {}
    for piece in cookie_string.split(";"):
        # Split on the first "=" only; the value may itself contain "=".
        name, separator, value = piece.partition("=")
        if not separator:
            # No "=" at all: assume an empty name and treat the text as value.
            # https://bugzilla.mozilla.org/show_bug.cgi?id=169091
            name, value = "", name
        name = name.strip()
        value = value.strip()
        if not (name or value):
            # Entirely empty piece (e.g. from a trailing ";"): skip it.
            continue
        # Undo cookie quoting with the standard library's own algorithm.
        parsed[name] = http_cookies._unquote(value)  # type: ignore
    return parsed

requests.py的解读到此告一段落,下篇将解读responses.py

上一篇 下一篇

猜你喜欢

热点阅读