mxnet:交叉熵损失函数

2020-04-25  本文已影响0人  AI秘籍

1.定义

softmax运算将输出变换为一个合法的概率分布;
对于真实标签,也可以用类别分布表达:
对于样本i,仅样本i的类别的离散数值为1,其余为0.


image.png

为什么不用平方损失函数?


image.png

因此,
改善上述问题的⼀个⽅法是使⽤更适合衡量两个概率分布差异的测量函数。
其中,交叉熵(cross entropy)是⼀个常⽤的衡量⽅法:


image.png

其实,就是熵的定义公式.


image.png

假设训练数据集的样本数为n,交叉熵损失函数定义为


image.png

最小化交叉熵损失函数等价于最⼤化训练数据集所有标签类别的联合预测概率


image.png

2.交叉熵损失函数的实现

为了得到标签的预测概率,我们可以使⽤pick函数。

    # y是两个样本的标签类别,分别是0,2
    y = nd.array([0, 2], dtype='int32')
    # y_hat是两个样本在3个类别的预测概率
    y_hat = nd.array([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
    print(nd.pick(y_hat, y))
image.png

第一个样本,0类别的预测概率是0.1;
第二个样本,2类别的预测概率是0.5.

交叉熵损失函数:

def cross_entropy(y_hat, y):
  return -nd.pick(y_hat, y).log()

参考:

动手学深度学习

上一篇下一篇

猜你喜欢

热点阅读