语义分割损失函数

2021-04-27  本文已影响0人  blair_liu

1.交叉熵损失

交叉熵损失在pytorch中有直接的实现
一般我们搜索交叉熵损失都会搜到二分类的交叉熵损失
L=-ylogy'-(1-y)log(1-y')=\begin{cases} -\log y', & y=1 \\ -\log(1-y'), & y=0 \end{cases}
y'是经过sigmoid激活函数的输出,所以在0-1之间,对应预测样本为类别1的概率,1-y'就是对应预测样本为类别0的概率。
普通的交叉熵对于正样本而言,输出概率越大损失越小。对于负样本而言,输出概率越小则损失越小。
但上面的公式并不对应pytorch中的交叉熵损失
因为上面公式认为y'就一个值,而实际网络输出有多少类对应于有多少个y',可以说是y_i^{'}
然后交叉熵损失就是
L=-logy_i^{'}
现在看看pytorch的交叉熵损失实现
\text{loss}(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right) = -x[class] + \log\left(\sum_j \exp(x[j])\right)
只看\text{loss}(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right)部分
和上面L=-logy_i^{'}也不一样
其实括号里面的一堆就是y_i^{'} ,只不过是对网络输出x做了一个softmax,让其值在0-1之间,对应预测样本为类别i 的概率y_i^{'}

loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)

存在的问题:
当负样本数量太大,占总的loss的大部分,而且多是容易分类的,因此使得模型的优化方向并不是我们所希望的那样。
改进:加权重
nn.CrossEntropyLoss()有一个参数weight
\text{loss}(x, class) = weight[class] \left(-x[class] + \log\left(\sum_j \exp(x[j])\right)\right)

2.Focal Loss

Focal loss主要是为了解决正负样本比例严重失衡的问题。该损失函数降低了大量简单负样本在训练中所占的权重,也可理解为一种困难样本挖掘
L=-(1-y_i^{'})^{\gamma}logy_i^{'}
原文写作:
\mathrm{FL}\left(p_{\mathrm{t}}\right)=-\left(1-p_{\mathrm{t}}\right)^{\gamma} \log \left(p_{\mathrm{t}}\right)

Focal Loss
focal loss的两个重要性质:
1、当一个样本被分错的时候,pt是很小的
(结合公式,比如当y=1时,p<0.5才是错分类,此时pt就比较小,反之当y=0时,p>0.5是错分了),因此调制系数就趋于1,也就是说相比原来的loss是没有什么大的改变的。当pt趋于1的时候(此时分类正确而且是易分类样本),调制系数趋于0,也就是对于总的loss的贡献很小。
2、当γ=0的时候,focal loss就是传统的交叉熵损失,当 γ 增加的时候,调制系数也会增加
当然,Focal Loss也可以加权重alpha,用来平衡正负样本本身的比例不均:文中alpha取0.25,即正样本要比负样本占比小,这是因为负例易分
Focal Loss加权重
class FocalLoss(nn.Module):
    def __init__(self, gamma=2, alpha=None, ignore_index=255, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.size_average = size_average
        self.CE_loss = nn.CrossEntropyLoss(ignore_index=ignore_index, weight=alpha)

    def forward(self, output, target):
        logpt = self.CE_loss(output, target)
        pt = torch.exp(-logpt)
        loss = ((1 - pt) ** self.gamma) * logpt
        if self.size_average:
            return loss.mean()
        return loss.sum()

3.DiceLoss

Dice系数,是一种集合相似度度量函数,通常用于计算两个样本的相似度(范围为[0, 1])

dice coefficient
在语义分割中,X是Ground Truth分割图像 ,Y是预测的分割图像
由此可以得到dice loss
dice loss的提出就是解决前景比例太小的问题
dice loss
Dice系数是分割效果的一个评判指标,其公式相当于预测结果区域和ground truth区域的交并比,所以它是把一个类别的所有像素作为一个整体去计算Loss的。因为Dice Loss直接把分割效果评估指标作为Loss去监督网络,而且计算交并比时还忽略了大量背景像素,解决了正负样本不均衡的问题,所以收敛速度很快。
一个简单的流程帮助理解:
image.png
image.png
image.png
image.png
Dice Loss和交叉熵函数的比较:
交叉熵损失函数中交叉熵值梯度计算形式类似于p-t,其中,p 是softmax输出,t为target。
而关于 dice-coefficient 的可微形式,loss 值为 \frac{2pt}{p^{2}+t^{2}}\frac{2pt}{p+t} ,其关于 p 的梯度形式是比较复杂的:\frac{2t^{2}}{p^{2}+t^{2}}\frac{2 t\left(t^{2}-p^{2}\right)}{\left(p^{2}+t^{2}\right)^{2}}
极端场景下,当 p 和 t 的值都非常小时,计算得到的梯度值可能会非常大. 通常情况下,可能导致训练更加不稳定。
直接采用 dice-coefficient 或者 IoU 作为损失函数的原因,是因为分割的真实目标就是最大化 dice-coefficient 和 IoU 度量。而交叉熵仅是一种代理形式,利用其在 BP 中易于最大化优化的特点。
另外,Dice-coefficient 对于类别不均衡问题,效果可能更优。然而,类别不均衡往往可以通过简单的对于每一个类别赋予不同的 loss 因子,以使得网络能够针对性的处理某个类别出现比较频繁的情况。因此,对于 Dice-coefficient 是否真的适用于类别不均衡场景,还有待探讨。
def make_one_hot(labels, classes):
    one_hot = torch.FloatTensor(labels.size()[0], classes, labels.size()[2], labels.size()[3]).zero_().to(labels.device)
    target = one_hot.scatter_(1, labels.data, 1)
    return target


class DiceLoss(nn.Module):
    def __init__(self, smooth=1., ignore_index=255):
        super(DiceLoss, self).__init__()
        self.ignore_index = ignore_index
        self.smooth = smooth

    def forward(self, output, target):
        if self.ignore_index not in range(target.min(), target.max()):
            if (target == self.ignore_index).sum() > 0:
                target[target == self.ignore_index] = target.min()
        target = make_one_hot(target.unsqueeze(dim=1), classes=output.size()[1])
        output = F.softmax(output, dim=1)
        output_flat = output.contiguous().view(-1)
        target_flat = target.contiguous().view(-1)
        intersection = (output_flat * target_flat).sum()
        loss = 1 - ((2. * intersection + self.smooth) /
                    (output_flat.sum() + target_flat.sum() + self.smooth))
        return loss

LovaszLoss

1

参考:
https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss
https://www.cnblogs.com/king-lps/p/9497836.html
https://zhuanlan.zhihu.com/p/57008984
https://zhuanlan.zhihu.com/p/94326225
https://blog.csdn.net/xijuezhu8128/article/details/111164936

上一篇下一篇

猜你喜欢

热点阅读