JAX 记录
2020-10-16 本文已影响0人
yxd886
测试官方sample里的resnet50,用的机器是单卡v100.
batch size 设置为32.
先测试了一下对update函数默认带了@jit的,也就是开启了XLA Jit优化。
执行时间大概是0.16s/step
然后关闭jit发现时间变成了2~3s/step。两者差异巨大。
于是用nvprof profile了一下。
使用jit的情况:
GPU kernel 执行情况:

API call 情况:

关闭jit的情况:
GPU kernel 情况:

API call 情况:
