JIT(上):Tensorflow如何实现即时编译?
Tensorflow的JIT(just-in-time)是指在运行@tf.function
修饰的python函数时,由jit
、tf2xla
和XLA
一起完成一系列如子图构造、子图优化、图编译和图执行等操作。编译后的可执行程序--executable
会存放到cache中,供再次调用时直接获取执行。JIT的好处在开篇已经讲过了,这里不再赘述。
JIT的流程可以概括为:Tensorflow子图构造/优化,graph -> HLO,编译/执行,合并计算结果到Tensorflow图这四部分。本文只涉及图编译和图执行。
函数中ops在子图构造阶段被包裹进一个cluster node,并替换成xla_compile
和xla_run
这两op,而XlaCompileOp
和XlaRunOp
就是它们的OpKernel
,分别用于图编译和执行。
XlaCompileOp通过XlaCompilationCache
获取或编译executable,并将其封装成XlaExecutableClosure
,并缓存在XlaExecutableClosureStore
。XlaRunOp
用从XlaCompileOp传递来的key在cache中查找并执行executable。
Signature
从编译流程图可以看到,XLA的编译结果会缓存到XlaCompilationCache,后续调用可以根据signature
在cache中查找executable。
函数的signature
是由BuildSignature(function, args)
根据函数和arguments生成的。即使是同一个函数,只要input tensors不同,signature也会不一样,这就是power()
被编译两次的原因:第三次函数调用时,由于无法通过signature在cache中找到executable而触发编译。
signature表示唯一的计算图:只要函数中的ops序列和arguments(type/shape)是确定的,那么计算图也是确定的。
Graph -> HLO
编译之前需要通过tf2xla
将图转换成XLA支持的语言HLO
。tf2xla为每个Tensorflow op创建了生成HLO的XlaOp
,因此,只要执行该Tensorflow子图,就可以生成具有相同的拓扑排序的HLO -- XlaComputation
。
HLO -> Executable
XlaComputation
(HLO)可以认为是一个运行在device上的纯函数,它的input/output会伴随着host-to-device(H2D)和device-to-host(D2H)的数据传输。
我们知道,Tensorflow图中的input tensor有两种:tf.Placeholder
和tf.Variable
,前者每个step都会将新data发送到device,而后者是模型参数,它们会常驻内存,只在store/load checkpoint才会有H2D/D2H。
而纯函数的定义是:
- 除了中间计算结果以外的所有tensors都要以arguments的形式传入函数。因此,不管是tensor还是variable都在函数的参数列表中。
- 所有的输出结果都是通过返回值(ROOT)返回。模型训练的结果是那些经过优化器更新后的参数(variable),它们会作为HLO的返回值。
不管是input还是output,虽然variable和其他argument一样存在于HLO的参数列表和返回值列表中,但它们实际上是常驻于device的,不需要也不应该H2D/D2H。
因此,HLO在编译时还需要通过argument_input_indices
、resource_input_indices
和resource_update_to_input_index
等options来区分arguments和variables。
此外,如果有input是常数,为了避免无谓的H2D开销,可以把它固化到函数内部。同理,对于常数output,它没必要出现在函数中,可以直接定义在XlaCompilationResult
的output buffer。
XlaCompilationResult是Graph -> HLO
的output,它封装了HLO以及上述部分metadata、buffers。
XlaExecutableClosure
XlaCompileOp会把编译好的executable、metadata、input/output buffers、options等统统封装进一个closure -- XlaExecutableClosure
,并将其缓存在XlaExecutableClosureStore
供XlaRunOp
获取。
XlaRunOp
XlaRunOp
可以通过一个数字字符串key(从0开始累加)从cache中查找并执行XlaExecutableClosure,这个key由XlaCompileOp提供。
execution_output = closure.executable()->RunAsync(std::move(*execution_inputs), run_options);