人工不智能XLA编译器

JIT(上):Tensorflow如何实现即时编译?

2022-04-13  本文已影响0人  A君来了

Tensorflow的JIT(just-in-time)是指在运行@tf.function修饰的python函数时,由jittf2xlaXLA一起完成一系列如子图构造、子图优化、图编译和图执行等操作。编译后的可执行程序--executable会存放到cache中,供再次调用时直接获取执行。JIT的好处在开篇已经讲过了,这里不再赘述。

https://sketch2sky.com/2019/09/24/tensorflow-jit-%E6%8A%80%E6%9C%AF%E8%AF%A6%E8%A7%A3

JIT的流程可以概括为:Tensorflow子图构造/优化,graph -> HLO,编译/执行,合并计算结果到Tensorflow图这四部分。本文只涉及图编译和图执行。

函数中ops在子图构造阶段被包裹进一个cluster node,并替换成xla_compilexla_run这两op,而XlaCompileOpXlaRunOp就是它们的OpKernel,分别用于图编译和执行。

XlaCompileOp通过XlaCompilationCache获取或编译executable,并将其封装成XlaExecutableClosure,并缓存在XlaExecutableClosureStoreXlaRunOp用从XlaCompileOp传递来的key在cache中查找并执行executable。

Signature

从编译流程图可以看到,XLA的编译结果会缓存到XlaCompilationCache,后续调用可以根据signature在cache中查找executable。

函数的signature是由BuildSignature(function, args)根据函数和arguments生成的。即使是同一个函数,只要input tensors不同,signature也会不一样,这就是power()被编译两次的原因:第三次函数调用时,由于无法通过signature在cache中找到executable而触发编译。

signature表示唯一的计算图:只要函数中的ops序列和arguments(type/shape)是确定的,那么计算图也是确定的。

Graph -> HLO

编译之前需要通过tf2xla将图转换成XLA支持的语言HLO。tf2xla为每个Tensorflow op创建了生成HLO的XlaOp,因此,只要执行该Tensorflow子图,就可以生成具有相同的拓扑排序的HLO -- XlaComputation

HLO -> Executable

XlaComputation(HLO)可以认为是一个运行在device上的纯函数,它的input/output会伴随着host-to-device(H2D)和device-to-host(D2H)的数据传输。

我们知道,Tensorflow图中的input tensor有两种:tf.Placeholdertf.Variable,前者每个step都会将新data发送到device,而后者是模型参数,它们会常驻内存,只在store/load checkpoint才会有H2D/D2H。

而纯函数的定义是:

不管是input还是output,虽然variable和其他argument一样存在于HLO的参数列表和返回值列表中,但它们实际上是常驻于device的,不需要也不应该H2D/D2H。

因此,HLO在编译时还需要通过argument_input_indicesresource_input_indicesresource_update_to_input_index等options来区分arguments和variables。

此外,如果有input是常数,为了避免无谓的H2D开销,可以把它固化到函数内部。同理,对于常数output,它没必要出现在函数中,可以直接定义在XlaCompilationResult的output buffer。

XlaCompilationResult是Graph -> HLO的output,它封装了HLO以及上述部分metadata、buffers。

XlaExecutableClosure

XlaCompileOp会把编译好的executable、metadata、input/output buffers、options等统统封装进一个closure -- XlaExecutableClosure,并将其缓存在XlaExecutableClosureStoreXlaRunOp获取。

XlaRunOp

XlaRunOp可以通过一个数字字符串key(从0开始累加)从cache中查找并执行XlaExecutableClosure,这个key由XlaCompileOp提供。

execution_output = closure.executable()->RunAsync(std::move(*execution_inputs), run_options);

References

END

上一篇下一篇

猜你喜欢

热点阅读