keras深度学习模型自然语言处理(NLP)

keras to_categorical 源码分析

2019-12-27  本文已影响0人  b424191ea349

最近对keras里面提供的一些好用的方法感兴趣,遂有此文!

函数的作用

这个函数是常用的将标签处理成ont-hot向量,比如:

array([0, 2, 1, 2, 0])

经过上述函数以后,处理成:

   array([[ 1.,  0.,  0.],
          [ 0.,  0.,  1.],
          [ 0.,  1.,  0.],
          [ 0.,  0.,  1.],
          [ 1.,  0.,  0.]], dtype=float32)

目的可以说很直观了!

简单实现

针对上面的例子,写一个简单的实现:

y = np.array([0, 2, 1, 2, 0])
categorical = np.zeros((y.shape[0], 3))
categorical[np.arange(5), y] = 1

这样最终的结果是:

array([[1., 0., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [0., 0., 1.],
       [1., 0., 0.]])

可以看到,三行代码问题就解决了!三行代码里面,一二两行很简单,第三行利用了numpy里面的广播机制,实质构成的效果等同于categorical[0,0]=1,categorical[1,2]=1,如此5次。
广播机制可以参考:https://www.runoob.com/numpy/numpy-broadcast.html

源码分析

上述简单例子的实现不能满足复杂的矩阵,比如:我们将y做一个修改:y = np.array([0, 2, 1, 2, 0]).reshape(-1,1),此时y变成了:

[[0]
 [2]
 [1]
 [2]
 [0]]

上述代码就不能很好的工作了,所以想要实现一个通用的to_categorical函数,还需要分析源码!

分析源码之前介绍一波numpy中的 ravel方法:
这个方法是把数组碾平成一维的,比如:

x = np.arange(12).reshape(2,6)
x.ravel()
# 结果是:
# array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

再看一个例子

x = np.arange(12).reshape(2,2,3)
x.ravel()
# 结果是:
# array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

源码分析如下:

def to_categorical(y, num_classes=None, dtype='float32'):
    # 确保是numpy 数组
    y = np.array(y, dtype='int')
    input_shape = y.shape
    if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
        # 当shape=(4,1)需要特殊处理一下,就是后面的1不要
        # 至于为什么不要,因为下面会有一个shape拼接的操作
        # 如果数据是[1,2,3,4],那么最后形成的shape实际上是(4,num_class)
        # 如果数据维度是(4,2),形成的数据shape是(4,2,num_class)
        input_shape = tuple(input_shape[:-1])
    # 把y碾平成一维数组,然后就变成了和上面简单的例子一样了
    y = y.ravel()
    if not num_classes:
        # 如果没有指定class num的话 自动计算一下
        num_classes = np.max(y) + 1
    # 同上面例子那部分
    n = y.shape[0]
    categorical = np.zeros((n, num_classes), dtype=dtype)
    categorical[np.arange(n), y] = 1
    # 将最终的shape拼接好(这里也是为什么如果最后一维是1,并且总维度大于1就需要删减掉最后一个维度的原因)
    output_shape = input_shape + (num_classes,)
    # 将维度reshape回去
    categorical = np.reshape(categorical, output_shape)
    return categorical
上一篇下一篇

猜你喜欢

热点阅读