numpy的axis是什么
2018-03-19 本文已影响0人
22790fe4fb44
numpy的axis是一个比较抽象的概念,我们先看看一个例子,从一些感性的认识开始。
import numpy as np
x = np.array([[1, 2], [3, 4]])
print('x is %s' % x)
sum0 = np.sum(x, axis=0, keepdims=True)
print('sum0 is %s' % sum0)
sum1 = np.sum(x, axis=1, keepdims=True)
print('sum1 is %s' % sum1)
x is [[1 2]
[3 4]]
sum0 is [[4 6]]
sum1 is [[3]
[7]]
从中可以看到一个大致的规律,就是axis=0的时候,表示以外面的括号为准,将括号里面的元素相加,那么就是[1, 2] + [3, 4], 而axis=1的时候,表示以里面的括号为准,将括号里面的元素相加,那么就是[1 + 2][3 + 4]。
我们再来看看axis的定义。
In Numpy dimensions are called axes. The number of axes is rank.
x = np.array([[1, 2], [3, 4]])
print(x.shape)
print(x.ndim)
(2, 2)
2
也就是说,numpy的axis就是矩阵的维度,通过shape属性可以获取到它是一个几乘几的矩阵。ndim表示矩阵的rank,也就是我们说他是几个维度的。我们看一个3维的矩阵。
x = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
print(x.shape)
print(x.ndim)
(2, 2, 2)
3
我们看看axis更直观的理解,在一个矩阵里,通过沿着不同的axis进行索引,我们就可以确定某个元素在该矩阵中的位置。
x = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
print(x[0, 1, 0])
# 选取axis为0的第一个元素
print(x[0, :, :])
# 选取axis为0的第二个元素
print(x[1, :, :])
# 两个元素的和
print(np.sum(x, axis=0, keepdims=True))
print(np.sum(x, axis=0))
3
[[1 2]
[3 4]]
[[5 6]
[7 8]]
[[[ 6 8]
[10 12]]]
[[ 6 8]
[10 12]]
我们了解数学上的定义后,我们结合具体的例子,看看统计学上的意味。某足球联赛的三个小组,A,B,C,按照A:B, B:C, A:C的顺序比赛,他们某一轮次的比分记为:[[0, 1], [2, 3], [2, 4]],那么比赛3轮之后的比分记为:
x = np.array([[[0, 1], [2, 3], [2, 4]], [[2, 1], [4, 4], [1, 2]], [[1, 1], [2, 1], [2, 0]]])
print(x.shape)
(3, 3, 2)
那么我们如果想知道哪一轮的进球次数最多,应该如何计算呢?
# 每一轮的进球次数
scores_sum = np.sum(np.sum(x, axis=2), axis=1)
print(scores_sum)
print('进球次数最多的是第%d轮' % (np.argmax(scores_sum) + 1))
[12 14 7]
进球次数最多的是第2轮
如果我们计算哪两个小队之间的比赛的进球次数最多,那么有:
# 两个小队之间的进球次数
pairs = ['A和B', 'B和C', 'A和C']
scores_pair = np.sum(np.sum(x, axis=0), axis=1)
print(scores_pair)
print('进球次数最多的是%s之间的比赛' % (pairs[np.argmax(scores_pair)]))
[ 6 16 11]
进球次数最多的是B和C之间的比赛
我们可以看到,在不同的axis上统计,就会使用不同的视角来索引数据,进行归并。