numpy的axis
2020-01-17 本文已影响0人
Ginkgo
这是别人的理解
原文
Numpy中有许多函数都带有一个参数:axis(对应于pytorch中的dim参数),用来指定函数计算是在哪个维度上进行的。这些函数包括:mean() 、sum()等。 想必很多人都曾为axis(dim)的用法感到困惑.
这里提供了一个简单的记忆方式:
假设输入的形状是(m, n, k):
- 如果指定axis(dim)=0, 输出的size就是(1, n, k)或者(n, k)
- 如果指定axis(dim)=1, 输出的size就是(m, 1, k)或者(m, k)
- 如果指定axis(dim)=2, 输出的size就是(m, n, 1)或者(m, n).
size中是否有“1”,取决于参数keepdims(或keepdim)。keepdims=True会保留维度1,通常默认是等于False。
注意:以上只是经验总结,对于绝大部分函数适用,少数函数如cumsum(累加)不符合。
此外,对于二维数组(矩阵),可以用一句话记住:axis=0表示分别对每一列做运算;如果axis=1表示分别对每一行做运算。这个tip对几乎所有场景适用(包括cumsum函数等)。
下面为自己的总结:
引用请注明出处,这是本人躺了很多坑之后的总结
最重要的就是记住axis=0位跨行,axis=1为跨列