JAX 记录

2020-10-16  本文已影响0人  yxd886

测试官方sample里的resnet50,用的机器是单卡v100.

batch size 设置为32.
先测试了一下对update函数默认带了@jit的,也就是开启了XLA Jit优化。
执行时间大概是0.16s/step

然后关闭jit发现时间变成了2~3s/step。两者差异巨大。

于是用nvprof profile了一下。
使用jit的情况:
GPU kernel 执行情况:


image.png

API call 情况:


image.png

关闭jit的情况:

GPU kernel 情况:


image.png

API call 情况:


image.png
上一篇下一篇

猜你喜欢

热点阅读