人工不智能PyTorch/JIT

理解PyTorch分发机制的内部工作原理

2023-06-02  本文已影响0人  A君来了

概述

PyTorch的成功归功于其简单易用性(与Python的用法相似)和动态灵活性。即使在PyTorch 2.0时代,它仍然保持着"Faster, more pythonic and dynamic as ever"的核心特性。

PyTorch的动态性源自内部的调度器(dispatcher),它可以根据不同的输入类型自动选择正确的运算方式。当调用Python函数时,调度器会根据传入的参数类型选择正确的操作实现,这个过程称为分派(dispatch)。

例如,当执行矩阵乘法(torch.matmul(a, b))时,调度器会根据输入张量a和b的类型(dtype、shape、device等)选择正确的BLAS库(CPU还是CUDA,float还是half,是否批量计算)来进行计算。对于PyTorch来说,模型的执行过程就是将各个操作(op)分派给本地方法(native function)执行的过程。

http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/

dispatcher 为每个 op 都维护了一张跳转表(它有点像 C++ 实现多态用的虚表),如上图所示,表中每个条目存储了一个本地方法,有些方法和输入张量所属的设备有关,比如 XLA/CUDA/CPU,有的和 requires_grad 有关,比如 Autograd(这图是从 ezyang’s blog 拿来的,他这篇博客详细讲解了分派机制,建议阅读)。

当 op 被执行时,e.g. aten::addmm,调度器会在它的跳转表中找出一个方法来执行,而且一个 op 执行过程可能会调用多个方法,例如,输入张量需要求导(requires_grad = true),那会先调用 Autograd 方法来构建反向图,再调用 backend(CPU/CUDA/XLA)的方法来运算。

分派规则

http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/

跳转表里的条目是以键值对的形式来存调度方法,其中“键”称为 dispatch key,以 bit 的形式存在,bit 值越大,优先级越高,调度器会从键集(dispatch key set)中选取优先级最高的条目来执行。

从上图可以看到,键集不只有一个,每个输入张量都有自己的键集,还有 local(local includelocal exclude) 和 global 键集,这些键集最终会合并,调度器从中选取优先级最高的键值对应的方法来执行。

输入张量的键集是比较好理解的,张量本身具有很多属性,如 layout (dense or sparse)、shape 和 device (CPU or CUDA),一个属性对应一个 dispatch key(可以从 DispatchKey.h 找到所有的 key)。对于不同类型的张量,我们希望能使用不同实现的操作以实现高性能计算的目标。

Local 键集 与张量个体无关,与模型的行为有关,表示模型运行在某模式中,比如 tracing。它可以允许用户在某个范围内开启或关闭模式。要开启模式就是往 local include 里添加键,要关闭模式就是往 local exclude 里添加要屏蔽的键。

Global 则表示无论什么操作都会添加的键集(图中 autograd 已经从 global 移到 tensor 键集)。

分派流程

http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/

前面也提到,一个 op 的执行是要经历多次分派的,上图就展示了这个过程:

前面提到,调度器会调用优先级最高的 dispatch key,因此,重新分派的前提是将已经调度过的键从键集里清除,否则重新分派将会重复调用相同的方法。

Autograd 的本地方法通过在 local exclude 键集中添加要屏蔽的键(Autograd)来避免方法的重复调用。可以通过创建 AutoNonVariableTypeMode RAII guard 来实现:

class MyAddFunction : public torch::autograd::Function<MyAddFunction> {
 public:
  static Tensor forward(
      AutogradContext *ctx, torch::Tensor self, torch::Tensor other) {
    at::AutoNonVariableTypeMode g;
    return myadd(self, other);
  }
  ...
};

注册自定义操作

回想一下分派规则:调度器首先找到 op 对应的跳转表,合并键集,并调用键值最大的条目中的函数。由于 dispatch key 是 PyTorch 固定且不可扩展的,因此注册自定义操作需要注册 op 以及跳转表中键的方法。

注册 op

TORCH_LIBRARY(myops, m) {
  m.def("myadd(Tensor self, Tensor other) -> Tensor");
}

PyTorch 提供 TORCH_LIBRARY 用于将 op(也称作 schema stringsignature)注册到一个库里,用户可以在 python 通过 c = torch._ops.myops.myadd(a, b) 调用该 op。

schema 与 TensorFlow 的 op_def 和 ONNX 的 node 一样,都用于描述一个操作,只是由于 PyTorch 是动态图的,schema 不需要也不能承载更多信息。

注册 dispatch function

TORCH_LIBRARY_IMPL(myops, CUDA, m) {
  m.impl("myadd", myadd_cuda);
}

注册完 op 后,接着就可以通过 TORCH_LIBRARY_IMPL 注册 dispatch key 对应的方法。上述代码片段通过将 myadd_cuda 注册到键:CUDA。

除了为每个键单独注册一个方法,还可以为所有的键注册一个共同的方法,这类方法称为 catch-all

TORCH_LIBRARY(myops, m) {
  m.def("myadd", myadd_catchall);
}

此外,还可以为所有 op 的某个键注册一个共同的 fallback 方法:

TORCH_LIBRARY_IMPL(_, XLA, m) {
  m.fallback(xla_fallback);
}

除了 dispatch key 具有优先级外,这些方法也有优先级:impl > catch-all > fallback:

http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/

END

PyTorch的调度器(dispatcher)和分派机制是其灵活性和高性能计算的关键。调度器根据输入类型自动选择适当的操作实现,通过分派流程将操作分派给本地方法执行。分派规则通过 dispatch key 和 keyset 确定执行方法的优先级。注册自定义操作的过程允许用户扩展PyTorch的功能。了解这些原理有助于深入理解PyTorch的内部工作机制,并为模型开发和优化提供指导。

上一篇下一篇

猜你喜欢

热点阅读