Generalized Focal Loss: Learning
个人感觉从理论上在实际场景下应该很有效,在自己的数据集上使用也涨了2个点。现实中标注的数据大部分不确定性都很很强、场景也更复杂。引入Generalized Focal Loss
可以有效的提升鲁棒性。
一、主要贡献
作者认为现有的密集检测器存在问题(以FCOS
为例):
![](https://img.haomeiwen.com/i10449208/a4dbafb8e57b851c.png)
1、
classification score
和 iou/centerness score
的训练过程和推理过程不一致
- 在训练过程中两个分支各自独立进行训练,但是在推理阶段,却相乘作为
nms score
,这中间必然会存在一个gap
。 - 训练阶段使用的监督信号不同,
classification score
使用正负样本进行训练;iou/centerness score
仅仅使用正样本进行训练。就很有可能引发这么一个情况:一个classification score
相对低的真正的负样本,由于预测了一个不可信的极高的centerness score
,而导致它可能排到一个真正的正样本的前面。
image.png
2、bounding box regression
的表示不够灵活
- 目前的广泛使用的方法是将
target box coordinates
视作一个Dirac delta
分布。 - 在
Gaussian YOLOV3
中将其视作一个Gaussian
分布。
但是在实际场景下,真实的分布有很强的不确定性。如下图所示,在不确定的边界处,分布比较平坦。
image.png
因此作者将classification score
和iou score
合并在一个分支,将target box coordinates
使用任意分布进行建模来表示回归框,提出了Generalized Focal Loss
,由Quality Focal Loss
和Distribution Focal Loss
组成。
-
QFL
:保留Focal Loss
平衡正负、难易样本的特性,又需要让其支持连续数值的监督。 -
DFL
:使用任意分布进行建模来表示回归框。
二、具体方法
![](https://img.haomeiwen.com/i10449208/b8d7dc2ae8cdcf05.png)
可以看出GFL
和之前方法的不同之处。
1、Focal loss
2、QFL
完整的保留了focal loss
的结构,为了支持连续值监督,将变成
,全局最小解即是
时。实际场景下测试
时,效果最好。
以下参考mmdetection
中的实现,先将所有样本都视作negative samples
计算loss
,再计算positive samples
的loss
。
- 多类别分类使用
sigmoid
-
y = iou score * one hot label
,作者称之为soft one hot label
- 损失最后除以正样本数量进行平均
@weighted_loss
def quality_focal_loss(pred, target, beta=2.0):
r"""Quality Focal Loss (QFL) is from `Generalized Focal Loss: Learning
Qualified and Distributed Bounding Boxes for Dense Object Detection
<https://arxiv.org/abs/2006.04388>`_.
Args:
pred (torch.Tensor): Predicted joint representation of classification
and quality (IoU) estimation with shape (N, C), C is the number of
classes.
target (tuple([torch.Tensor])): Target category label with shape (N,)
and target quality label with shape (N,).
beta (float): The beta parameter for calculating the modulating factor.
Defaults to 2.0.
Returns:
torch.Tensor: Loss tensor with shape (N,).
"""
assert len(target) == 2, """target for QFL must be a tuple of two elements,
including category label and quality label, respectively"""
# label denotes the category id, score denotes the quality score
label, score = target
# negatives are supervised by 0 quality score
pred_sigmoid = pred.sigmoid()
scale_factor = pred_sigmoid
zerolabel = scale_factor.new_zeros(pred.shape)
loss = F.binary_cross_entropy_with_logits(
pred, zerolabel, reduction='none') * scale_factor.pow(beta)
# FG cat_id: [0, num_classes -1], BG cat_id: num_classes
bg_class_ind = pred.size(1)
pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
pos_label = label[pos].long()
# positives are supervised by bbox quality (IoU) score
scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
pred[pos, pos_label], score[pos],
reduction='none') * scale_factor.abs().pow(beta)
loss = loss.sum(dim=1, keepdim=False)
return loss
3、DFL
考虑到真实的分布通常不会距离标注的位置太远,所以作者又额外加了个DFL
,希望网络能够快速地聚焦到标注位置附近的数值,使得他们概率尽可能大。以下为推导过程:
概率密度函数:
因此,可以理解对应的期望
根据ATSS
中的结论,易求得其值域,在COCO
数据集上是16,,因此可得下式。
使用离散值进行近似:
所以,在经过softmax
函数后,可以保证,可以进行端到端的训练。为了使网络快速地聚焦到目标位置的邻近区域的分布中,最终
DFL
为:
因为DFL
仅仅使用positive samples
进行训练,因此不存在不平衡的问题,只需要用简单的cross entroy
。
以下参考mmdetection
中的实现:
@weighted_loss
def distribution_focal_loss(pred, label):
r"""Distribution Focal Loss (DFL) is from `Generalized Focal Loss: Learning
Qualified and Distributed Bounding Boxes for Dense Object Detection
<https://arxiv.org/abs/2006.04388>`_.
Args:
pred (torch.Tensor): Predicted general distribution of bounding boxes
(before softmax) with shape (N, n+1), n is the max value of the
integral set `{0, ..., n}` in paper.
label (torch.Tensor): Target distance label for bounding boxes with
shape (N,).
Returns:
torch.Tensor: Loss tensor with shape (N,).
"""
dis_left = label.long()
dis_right = dis_left + 1
weight_left = dis_right.float() - label
weight_right = label - dis_left.float()
loss = F.cross_entropy(pred, dis_left, reduction='none') * weight_left \
+ F.cross_entropy(pred, dis_right, reduction='none') * weight_right
return loss
由于已经被离散化,但是实际
应该是一个连续值,因此作者使用了左右取整得到
和
,然后用距离进行加权,计算
DFL
。
4、GIOU
box
的损失函数使用了GIOU
,这里不再赘述。
最后可得GFL
损失函数:由QFL
、DFL
、GIOU
组成。
三、相关的实验结果
1、和不同的质量表示的方法进行对比、cls loss
替换成QFL
的涨点、系数的选择
![](https://img.haomeiwen.com/i10449208/25d82a53fd0ad804.png)
2、针对框回归不同建模方法的可视化差异
![](https://img.haomeiwen.com/i10449208/60c89df50fa93ac9.png)
3、DFL
的有效性, n的选择、n间隔的选择
![](https://img.haomeiwen.com/i10449208/6dc7de6639677939.png)
4、QFL
和DFL
的增益是正交的
![](https://img.haomeiwen.com/i10449208/7216208b4c9b7287.png)
5、和别的方法进行对比
![](https://img.haomeiwen.com/i10449208/8f740485d52692e8.png)
四、补充实验
1、不同分布的建模方法
![](https://img.haomeiwen.com/i10449208/12519077bde7691e.png)
2、相比Dirac delta
分布建模,General
更稳定更鲁棒
加入0.1的扰动,对比两种方法的误差
![](https://img.haomeiwen.com/i10449208/121304e835569840.png)
3、使用iou label
比使用centerness label
更加稳定
![](https://img.haomeiwen.com/i10449208/519ad5372305914d.png)
4、可视化图片
![](https://img.haomeiwen.com/i10449208/7f2141f137aa0a9b.png)
![](https://img.haomeiwen.com/i10449208/8284e673623f1467.png)
五、对应yolo系列中的改动
yolo系列中obj loss是正负样本参与训练,cls和box只有正样本参与训练,在推理时会取obj score和cls score的乘积判断是否是正例。因此需要将obj 和 cls进行合并,并且乘以iou,计算QFL损失;box改成DFL loss。