从PyTorch到TorchScript: 打通深度学习模型的生
概述
PyTorch是一款非常流行的深度学习框架,开发者和研究者常常选择它,因为它具有灵活性、易用性和良好的性能。然而,PyTorch的灵活易用性是建立在动态计算图的基础上的,相比采用静态图的TensorFlow,PyTorch在推理性能和部署方面存在明显的劣势。
为了解决这个问题,TorchScript应运而生。它将PyTorch模型转换为静态类型的优化序列化格式,以实现高效的优化和跨平台部署(包括C++、Python、移动设备和云端)。
构建 TorchScript
TorchScript将PyTorch模型转换为静态图形式,因此构建TorchScript的核心是构建模型的静态计算图。
PyTorch提供了两种方法来构建TorchScript:trace和script。
-
torch.jit.trace
:该函数接收一个已训练好的模型和实际输入样例,通过运行模型的方式来生成静态图(static graph)。这种转换方式称为"追踪模式"(tracing mode)。 -
torch.jit.script
:该函数将PyTorch代码编译成静态图。与追踪模式相反,它被称为"脚本模式"(scripting mode),因为它直接将PyTorch代码翻译成静态图,而不需要追踪执行流程。
Tracing Mode
model = torch.nn.Sequential(nn.Linear(3, 4))
input = torch.randn(1, 3)
traced_model = torch.jit.trace(model, input)
追踪模式通过运行模型一次,并根据操作序列生成静态图。因此,它需要提供输入样例(input)。通过追踪机制,自动捕捉和生成模型的计算图。这是许多AI编译器采用的JIT模式。然而,追踪模式存在一个问题,即无法处理控制流,例如if、while等语句。
class MyModel(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
if x > 0:
x += 1
else:
x -= 1
return x
model = MyModel()
input = torch.randn(1)
traced_model = torch.jit.trace(model, input)
以if
语句为例,它是Python的语句,根据具体的x
值,只能在then分支或else分支上执行。因此,追踪模式只能捕捉到一个分支上的操作。要想生成完整的控制流图,需要采用脚本模式。
Scripting Mode
与追踪模式不同,Scripting 模式直接将 Python 和 PyTorch 的语句翻译成 TorchScript 的静态图,因此不需要追踪模型的执行流程,并且能够生成完整的控制流图:
script_model = torch.jit.script(model)
print(script_model.graph)
graph(%self : __torch__.___torch_mangle_3.MyModel,
%x.1 : Tensor):
......
%x : Tensor = prim::If(%6) # <ipython-input-3-6fda6c66b1df>:6:4
block0():
%x.7 : Tensor = aten::add_(%x.1, %8, %8) # <ipython-input-3-6fda6c66b1df>:7:6
-> (%x.7)
block1():
%x.13 : Tensor = aten::sub_(%x.1, %8, %8) # <ipython-input-3-6fda6c66b1df>:9:6
-> (%x.13)
return (%x)
然而,这种模式也有其局限性:对于每个语句,都需要提供相应的转换函数,将 Python/PyTorch 语句转换成 TorchScript 语句。目前,PyTorch仅支持部分 Python 内置函数和 PyTorch 语句的转换。
Tracing + Script
因此,对于具有控制流的模型,可以采用混合模式:将追踪模式无法处理的控制流图封装为子模块,使用脚本模式来转换这些子模块,然后通过追踪机制对整个模型进行追踪(通过脚本模式转换后的子模块不会再被追踪)。有关具体实现,请参考官方示例:https://pytorch.org/docs/stable/jit.html#mixing-tracing-and-scripting。
运行 TorchScript
![](https://img.haomeiwen.com/i13575947/437e1e8cbc0b5441.png)
前面生成的计算图会封装到 TorchScript 模块的 forward()
方法中,在运行时被编译成 native code(JIT)。如上图所示,特化后的计算图经过图优化后被编译成 native code,最后通过栈机解释器执行。
Specialization
JIT(just-in-time)将静态图编译后的结果以 <signature: executable> 键值对的形式存储在缓存中。只有在缓存未命中(cache miss)时,也就是首次运行时,才会触发编译过程。
Signature 表示唯一的静态计算图。在计算流图不变的情况下,静态图由输入参数(arguments)决定。不同的 dtype、shape 的参数将生成不同的静态计算图。
Specialization 的目的是根据 torchscript 的输入(Input),为参数赋予 dtype、shape、设备类型(CPU、CUDA)等静态信息(ArgumentSpec
),生成 signature,以便为缓存搜索做准备。
# post specialization, inputs are now specialized types
graph(%x : Float(*, *),
%hx : Float(*, *),
%cx : Float(*, *),
%w_ih : Float(*, *),
%w_hh : Float(*, *),
%b_ih : Float(*),
%b_hh : Float(*)):
%7 : int = prim::Constant[value=4]()
%8 : int = prim::Constant[value=1]()
%9 : Tensor = aten::t(%w_ih)
Optimization
PyTorch JIT 使用一系列 passes(torch.jit.passes
)对图进行优化,旨在从执行效率、内存占用等方面优化计算。其中包括对 dtype、shape 和常量进行前向推导的形状推导(Shape inference)和常数传播(Const propagation)等优化,以减少实际操作的数量。
除了上述常见的优化,对于 GPU 来说,最核心的优化是算子融合(Operation fusion):将匹配的一组算子合并为一个算子。例如,将连续的一系列 element-wise 操作合并为一个操作,这样可以减少 CUDA kernels 的启动时间开销,并减少操作之间访问全局内存的次数。
图优化是 AI 编译器的标配,用于优化计算图的执行效率和内存占用等方面。PyTorch 的图优化通过一系列 passes(torch.jit.passes)来实现,包括常数折叠(Constant folding)、死代码清除(Dead code elimination)和算子融合(Operation fusion)等。
在图优化过程中,FuseGraph
pass 将可以融合的算子封装为 FusionGroup 静态子图:
graph(%x : Float(*, *),
...):
%9 : Float(*, *) = aten::t(%w_ih)
...
%77 : Tensor[] = prim::ListConstruct(%b_hh, %b_ih, %10, %12)
%78 : Tensor[] = aten::broadcast_tensors(%77)
%79 : Tensor, %80 : Tensor, %81 : Tensor, %82 : Tensor = prim::ListUnpack(%78)
%hy : Float(*, *), %cy : Float(*, *) = prim::FusionGroup_0(%cx, %82, %81, %80, %79)
%30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
return (%30);
with prim::FusionGroup_0 = graph(%13 : Float(*, *),
...):
%87 : Float(*, *), %88 : Float(*, *), %89 : Float(*, *), %90 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%86)
%82 : Float(*, *), %83 : Float(*, *), %84 : Float(*, *), %85 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%81)
%77 : Float(*, *), %78 : Float(*, *), %79 : Float(*, *), %80 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%76)
%72 : Float(*, *), %73 : Float(*, *), %74 : Float(*, *), %75 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%71)
%69 : int = prim::Constant[value=1]()
%70 : Float(*, *) = aten::add(%77, %72, %69)
%66 : Float(*, *) = aten::add(%78, %73, %69)
...
%4 : Float(*, *) = aten::tanh(%cy)
%hy : Float(*, *) = aten::mul(%outgate, %4)
return (%hy, %cy)
Codegen
优化的最后是为图(symbolic graph)中的符号操作生成加速器所需的操作内核(op kernel)。PyTorch 已经为 CPU 和 Nvidia GPU 提供了一个名为 ATen 的 C++ 算子库,像图中的 aten::add
节点就会在运行时调用 built-in 算子。
对于融合算子,PyTorch 提供了基于 LLVM 的 NNC
编译器,用于生成相应的目标代码。它将 FusionGroup 子图里的 node lowering 成 C++ functions,再基于 LLVM 将它们编译成一个大算子:
RegisterNNCLoweringsFunction aten_matmul(
{"aten::mm(Tensor self, Tensor mat2) -> (Tensor)",
"aten::matmul(Tensor self, Tensor other) -> (Tensor)"},
computeMatmul);
Tensor computeMatmul(...) {
...
return Tensor(
ResultBuf.node(),
ExternalCall::make(ResultBuf, "nnc_aten_matmul", {a, b}, {}));
}
void nnc_aten_matmul(...) {
...
try {
at::matmul_out(r, self, other);
} catch (...) {}
}
Interpreter
TorchScript 提供一个栈机解释器在C++上高效地运行计算图:
// Create a vector of inputs.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}));
// Execute the model and turn its output into a tensor.
at::Tensor output = module.forward(inputs).toTensor();
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';