keras to_categorical 源码分析
2019-12-27 本文已影响0人
b424191ea349
最近对keras里面提供的一些好用的方法感兴趣,遂有此文!
函数的作用
这个函数是常用的将标签处理成ont-hot向量,比如:
array([0, 2, 1, 2, 0])
经过上述函数以后,处理成:
array([[ 1., 0., 0.],
[ 0., 0., 1.],
[ 0., 1., 0.],
[ 0., 0., 1.],
[ 1., 0., 0.]], dtype=float32)
目的可以说很直观了!
简单实现
针对上面的例子,写一个简单的实现:
y = np.array([0, 2, 1, 2, 0])
categorical = np.zeros((y.shape[0], 3))
categorical[np.arange(5), y] = 1
这样最终的结果是:
array([[1., 0., 0.],
[0., 0., 1.],
[0., 1., 0.],
[0., 0., 1.],
[1., 0., 0.]])
可以看到,三行代码问题就解决了!三行代码里面,一二两行很简单,第三行利用了numpy里面的广播机制,实质构成的效果等同于categorical[0,0]=1,categorical[1,2]=1
,如此5次。
广播机制可以参考:https://www.runoob.com/numpy/numpy-broadcast.html
源码分析
上述简单例子的实现不能满足复杂的矩阵,比如:我们将y
做一个修改:y = np.array([0, 2, 1, 2, 0]).reshape(-1,1)
,此时y变成了:
[[0]
[2]
[1]
[2]
[0]]
上述代码就不能很好的工作了,所以想要实现一个通用的to_categorical
函数,还需要分析源码!
分析源码之前介绍一波numpy
中的 ravel
方法:
这个方法是把数组碾平成一维的,比如:
x = np.arange(12).reshape(2,6)
x.ravel()
# 结果是:
# array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
再看一个例子
x = np.arange(12).reshape(2,2,3)
x.ravel()
# 结果是:
# array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
源码分析如下:
def to_categorical(y, num_classes=None, dtype='float32'):
# 确保是numpy 数组
y = np.array(y, dtype='int')
input_shape = y.shape
if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
# 当shape=(4,1)需要特殊处理一下,就是后面的1不要
# 至于为什么不要,因为下面会有一个shape拼接的操作
# 如果数据是[1,2,3,4],那么最后形成的shape实际上是(4,num_class)
# 如果数据维度是(4,2),形成的数据shape是(4,2,num_class)
input_shape = tuple(input_shape[:-1])
# 把y碾平成一维数组,然后就变成了和上面简单的例子一样了
y = y.ravel()
if not num_classes:
# 如果没有指定class num的话 自动计算一下
num_classes = np.max(y) + 1
# 同上面例子那部分
n = y.shape[0]
categorical = np.zeros((n, num_classes), dtype=dtype)
categorical[np.arange(n), y] = 1
# 将最终的shape拼接好(这里也是为什么如果最后一维是1,并且总维度大于1就需要删减掉最后一个维度的原因)
output_shape = input_shape + (num_classes,)
# 将维度reshape回去
categorical = np.reshape(categorical, output_shape)
return categorical