pytorch的dim理解
2024-04-24 本文已影响0人
sretik
import numpy as np
mynp = np.arange(24).reshape(2,3,4)
print(mynp)
print(mynp.mean(0))
print(mynp.mean(1))
print(mynp.mean(2))
print(mynp.mean(-3))
print(mynp.mean(-2))
print(mynp.mean(-1))
#输出
[[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]]
#dim=0
[[ 6. 7. 8. 9.]
[10. 11. 12. 13.]
[14. 15. 16. 17.]]
#dim=1
[[ 4. 5. 6. 7.]
[16. 17. 18. 19.]]
#dim=2
[[ 1.5 5.5 9.5]
[13.5 17.5 21.5]]
#dim=-3
[[ 6. 7. 8. 9.]
[10. 11. 12. 13.]
[14. 15. 16. 17.]]
#dim=-2
[[ 4. 5. 6. 7.]
[16. 17. 18. 19.]]
#dim=-1
[[ 1.5 5.5 9.5]
[13.5 17.5 21.5]]
pytorch中shape(a,b,c)中的a、b、c表示表示对应维度的size,而维度也有一个序号:正序是0、1、2;逆序是-1、-2、-3。
如示例代码所示,0和-3、1和-2、2和-1维度计算的均值是相同的。
另外,指定维度的运算是对应维度内元素的运算,如dim=0的均值,维度0的size为2,即其由2个元素
[[ 0 1 2 3][ 4 5 6 7][ 8 9 10 11]]
[[12 13 14 15][16 17 18 19][20 21 22 23]]
其均值计算是元素对应位置相加求平均:0和12、1和13...