深度互学习为什么使用kldivloss作为损失函数?

2021-03-23  本文已影响0人  不懂不学不问

最近在实践知识蒸馏过程中,在Pytorch中不同损失函数的作用也各有各的不同。在查看Loss源码时,发现具体的损失函数有_WeightedLoss,L1Loss,NLLLoss,NLLLoss2d,PoissonNLLLoss,KLDivLossMSELoss,HingeEmbeddingLoss,CrossEntropyLossMarginRankingLoss,CTCLoss等等类。

今天仔细研习了几种(着重);

1. NLLLoss —— log似然代价函数

The negative log likelihood loss. It is useful to train a classification problem with C classes.
If provided, the optional argument :attr:weight should be a 1D Tensor assigning weight to each of the classes. This is particularly useful when you have an unbalanced training set.

似然函数就是我们有一堆观察所得到的结果,然后我们用这堆观察结果对模型的参数进行估计。
常用于多分类任务,NLLLoss 函数输入 input 之前,需要对 input 进行 log_softmax 处理,即将 input 转换成概率分布的形式,并且取对数,底数为 e,在求取平均值。

class torch.nn.NLLLoss(weight=None, size_average=None, ignore_index=-100, 
                       reduce=None, reduction='mean')

代码解释得很详细

2. KLDivLoss —— 相对熵

Kullback-Leibler divergence_ is a useful distance measure for continuous distributions and is often useful when performing direct regression over the space of (discretely sampled) continuous output distributions.

和交叉熵一样都是熵的计算,其公式为:

image.png

信息量:它是用来衡量一个事件的不确定性的;一个事件发生的概率越大,不确定性越小,则它所携带的信息量就越小。
:它是用来衡量一个系统的混乱程度的,代表一个系统中信息量的总和;信息量总和越大,表明这个系统不确定性就越大。

    def __init__(self, size_average=None, reduce=None, reduction: str = 'mean', log_target: bool = False) -> None:
        super(KLDivLoss, self).__init__(size_average, reduce, reduction)
        self.log_target = log_target

3. CrossEntropyLoss —— 交叉熵

This criterion combines :class:~torch.nn.LogSoftmax and :class:~torch.nn.NLLLoss in one single class.

交叉熵:它主要刻画的是实际输出(概率)与期望输出(概率)的距离,也就是交叉熵的值越小,两个概率分布就越接近。
pytorch中的交叉熵不是公式(概率分布p为期望输出,q为实际输出):

image.png

而是以更加简洁的方式得到,主要是将softmax-log-NLLLoss合并到一块得到的结果,起源码中就曾写道。

image.png
   def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100,
                 reduce=None, reduction: str = 'mean') -> None:
        super(CrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction)
        self.ignore_index = ignore_index

通常总结来说就是:

为什么kl散度衡量了分布差异,kl散度的本质是交叉熵减信息熵,即,使用估计分布编码真实分布所需的bit数,与编码真实分布所需的最少bit数的差。当且仅当估计分布与真实分布相同时,kl散度为0。因此可以作为两个分布差异的衡量方法

上一篇 下一篇

猜你喜欢

热点阅读