统一 MXNet & PyTorch & TensorFlow

2021-02-21  本文已影响0人  水之心

为了在 MXNet、PyTorch 以及 TensorFlow 之间进行互转,本文研究它们之间的基础运算的异同。

基础函数

对于 MXNet 有 np(numpy)模块和 npx(numpy_extension)模块。np 模块包含了 NumPy 支持的函数。而 npx 模块包含了一组扩展函数,用来在类似 NumPy 的环境中实现深度学习开发。当使用张量时,我们几乎总是会调用 set_np 函数:这是为了兼容 MXNet 其他的张量处理组件。

from mxnet import np, npx

npx.set_np()

同样,我们可以将 PyTorch 和 TensorFlow 写作 np,这样有:

def import_np(module_name):
    if module_name == 'mxnet':
        from mxnet import np, npx
        npx.set_np()
        np.randn = np.random.randn
        return np
    elif module_name == 'torch':
        import torch as np
        np.array = np.tensor
        np.concatenate = np.cat
        return np
    elif module_name == 'tensorflow':
        from tensorflow.experimental import numpy as np
        return np

这样 MXNet,TensorFlow 与 PyTorch 有相同的函数:

MXNet TensorFlow PyTorch

当然也有许多不同,比如求张量的大小,标准正太分布与张量定义,MXNet:

MXNet TensorFlow PyTorch

张量运算

MXNet,TensorFLow 与 PyTorch 是几乎一致的。

逐元素运算

MXNet TensorFlow PyTorch

张量运算

张量的真值、元素求和、拼接:

MXNet TensorFlow PyTorch

仔细观察也可以看到细微的不同,np.arange 的数据类型的默认值不同,故而,建议:定义张量最好也把 dtype 也指定了。

广播机制

MXNet TensorFlow PyTorch

索引和切片

MXNet TensorFlow PyTorch

节省内存

运行一些操作可能会导致为新结果分配内存。例如,如果我们用 Y = X + Y,我们将取消引用 Y 指向的张量,而是指向新分配的内存处的张量。下面展示了节省内存的方法:

MXNet & PyTorch TensorFlow

__array__ 的妙用

由于 MXNet,TensorFLow 与 PyTorch 均实现了 __array__,故而,它们均可直接传入 NumPy 数据到张量。下面仅以 TensorFlow 为例:

这样的好处是,可以直接使用 Matplotlib 画图:

BUG

nvidia-smi指令报错:Failed to initialize NVML: Driver解决 - 知乎 (zhihu.com)

上一篇下一篇

猜你喜欢

热点阅读