Focal Loss损失函数(pytorch实现)
2023-06-29 本文已影响0人
小黄不头秃
![](https://img.haomeiwen.com/i12824314/ecc46b9478e29d06.png)
Focal Loss 是何凯明大神提出的一个新的损失函数,其基于交叉熵损失函数做了一些修改。Focal Loss源自ICCV2017的一篇论文:Best student paper——Focal Loss for Dense Object Detection。
论文下载链接为:Lin_Focal_Loss_for_ICCV_2017_paper.pdf。
Focal Loss的提出主要是解决机器视觉领域中的样本数量不均衡的问题,它还强调了样本的难易性。当数据集中的某一类的样本过少,其训练难度也相对较大,那么Focal Loss就是为了解决这个问题。
一、Focal Loss 损失函数
首先我们看一下,交叉熵损失函数的公式:
![](https://img.haomeiwen.com/i12824314/b10ca730d606c36d.png)
这里y为真实样本的概率分布,p为预测的概率分布。这里为了简化推导我们可以重新定义pt:
![](https://img.haomeiwen.com/i12824314/6f6ab06ba59daeb1.png)
所以,上述的交叉熵损失函数就变成了,如下形式:
![](https://img.haomeiwen.com/i12824314/56c30ebcd7a268c7.png)
并且有人提出带权重的交叉熵损失函数,其公式如下:
![](https://img.haomeiwen.com/i12824314/d63af01384ab9e92.png)
这个由人为设定的虽然能够解决一定的正负样本不均衡问题,但是其还是没有办法让神经网络去区分样本的难易程度。Focal Loss认为,数据的难易程度其实是由模型来进行判断的,也就是说我们可以将模型的输出作为数据的难易程度判断的标准。于是大佬们设计出了如下的Focal Loss:
![](https://img.haomeiwen.com/i12824314/fc30bfcdaaaa0d57.png)
使用来作为难易程度的代表,并且我们可以发现当
时,Focal Loss就等于原来的交叉熵。
![](https://img.haomeiwen.com/i12824314/d846df5fb52708d4.png)
二、pytorch代码实现
"""
以二分类任务为例
"""
from torch import nn
import torch
class FocalLoss(nn.Module):
def __init__(self, gama=1.5, alpha=0.25, weight=None, reduction="mean") -> None:
super().__init__()
self.loss_fcn = torch.nn.CrossEntropyLoss(weight=weight, reduction=reduction)
self.gama = gama
self.alpha = alpha
def forward(self, pre, target):
logp = self.loss_fcn(pre, target)
p = torch.exp(-logp)
loss = (1-p)**self.gama * self.alpha * logp
return loss.mean()