Generalized Focal Loss: Learning
个人感觉从理论上在实际场景下应该很有效,在自己的数据集上使用也涨了2个点。现实中标注的数据大部分不确定性都很很强、场景也更复杂。引入Generalized Focal Loss
可以有效的提升鲁棒性。
一、主要贡献
作者认为现有的密集检测器存在问题(以FCOS
为例):
1、
classification score
和 iou/centerness score
的训练过程和推理过程不一致
- 在训练过程中两个分支各自独立进行训练,但是在推理阶段,却相乘作为
nms score
,这中间必然会存在一个gap
。 - 训练阶段使用的监督信号不同,
classification score
使用正负样本进行训练;iou/centerness score
仅仅使用正样本进行训练。就很有可能引发这么一个情况:一个classification score
相对低的真正的负样本,由于预测了一个不可信的极高的centerness score
,而导致它可能排到一个真正的正样本的前面。
image.png
2、bounding box regression
的表示不够灵活
- 目前的广泛使用的方法是将
target box coordinates
视作一个Dirac delta
分布。 - 在
Gaussian YOLOV3
中将其视作一个Gaussian
分布。
但是在实际场景下,真实的分布有很强的不确定性。如下图所示,在不确定的边界处,分布比较平坦。
image.png
因此作者将classification score
和iou score
合并在一个分支,将target box coordinates
使用任意分布进行建模来表示回归框,提出了Generalized Focal Loss
,由Quality Focal Loss
和Distribution Focal Loss
组成。
-
QFL
:保留Focal Loss
平衡正负、难易样本的特性,又需要让其支持连续数值的监督。 -
DFL
:使用任意分布进行建模来表示回归框。
二、具体方法
image.png可以看出GFL
和之前方法的不同之处。
1、Focal loss
2、QFL
完整的保留了focal loss
的结构,为了支持连续值监督,将变成,全局最小解即是时。实际场景下测试时,效果最好。
以下参考mmdetection
中的实现,先将所有样本都视作negative samples
计算loss
,再计算positive samples
的loss
。
- 多类别分类使用
sigmoid
-
y = iou score * one hot label
,作者称之为soft one hot label
- 损失最后除以正样本数量进行平均
@weighted_loss
def quality_focal_loss(pred, target, beta=2.0):
r"""Quality Focal Loss (QFL) is from `Generalized Focal Loss: Learning
Qualified and Distributed Bounding Boxes for Dense Object Detection
<https://arxiv.org/abs/2006.04388>`_.
Args:
pred (torch.Tensor): Predicted joint representation of classification
and quality (IoU) estimation with shape (N, C), C is the number of
classes.
target (tuple([torch.Tensor])): Target category label with shape (N,)
and target quality label with shape (N,).
beta (float): The beta parameter for calculating the modulating factor.
Defaults to 2.0.
Returns:
torch.Tensor: Loss tensor with shape (N,).
"""
assert len(target) == 2, """target for QFL must be a tuple of two elements,
including category label and quality label, respectively"""
# label denotes the category id, score denotes the quality score
label, score = target
# negatives are supervised by 0 quality score
pred_sigmoid = pred.sigmoid()
scale_factor = pred_sigmoid
zerolabel = scale_factor.new_zeros(pred.shape)
loss = F.binary_cross_entropy_with_logits(
pred, zerolabel, reduction='none') * scale_factor.pow(beta)
# FG cat_id: [0, num_classes -1], BG cat_id: num_classes
bg_class_ind = pred.size(1)
pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
pos_label = label[pos].long()
# positives are supervised by bbox quality (IoU) score
scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
pred[pos, pos_label], score[pos],
reduction='none') * scale_factor.abs().pow(beta)
loss = loss.sum(dim=1, keepdim=False)
return loss
3、DFL
考虑到真实的分布通常不会距离标注的位置太远,所以作者又额外加了个DFL
,希望网络能够快速地聚焦到标注位置附近的数值,使得他们概率尽可能大。以下为推导过程:
概率密度函数:
因此,可以理解对应的期望
根据ATSS
中的结论,易求得其值域,在COCO
数据集上是16,,因此可得下式。
使用离散值进行近似:
所以,在经过softmax
函数后,可以保证,可以进行端到端的训练。为了使网络快速地聚焦到目标位置的邻近区域的分布中,最终DFL
为:
因为DFL
仅仅使用positive samples
进行训练,因此不存在不平衡的问题,只需要用简单的cross entroy
。
以下参考mmdetection
中的实现:
@weighted_loss
def distribution_focal_loss(pred, label):
r"""Distribution Focal Loss (DFL) is from `Generalized Focal Loss: Learning
Qualified and Distributed Bounding Boxes for Dense Object Detection
<https://arxiv.org/abs/2006.04388>`_.
Args:
pred (torch.Tensor): Predicted general distribution of bounding boxes
(before softmax) with shape (N, n+1), n is the max value of the
integral set `{0, ..., n}` in paper.
label (torch.Tensor): Target distance label for bounding boxes with
shape (N,).
Returns:
torch.Tensor: Loss tensor with shape (N,).
"""
dis_left = label.long()
dis_right = dis_left + 1
weight_left = dis_right.float() - label
weight_right = label - dis_left.float()
loss = F.cross_entropy(pred, dis_left, reduction='none') * weight_left \
+ F.cross_entropy(pred, dis_right, reduction='none') * weight_right
return loss
由于已经被离散化,但是实际应该是一个连续值,因此作者使用了左右取整得到和,然后用距离进行加权,计算DFL
。
4、GIOU
box
的损失函数使用了GIOU
,这里不再赘述。
最后可得GFL
损失函数:由QFL
、DFL
、GIOU
组成。
三、相关的实验结果
1、和不同的质量表示的方法进行对比、cls loss
替换成QFL
的涨点、系数的选择
2、针对框回归不同建模方法的可视化差异
3、DFL
的有效性, n的选择、n间隔的选择
4、QFL
和DFL
的增益是正交的
5、和别的方法进行对比
四、补充实验
1、不同分布的建模方法
2、相比Dirac delta
分布建模,General
更稳定更鲁棒
加入0.1的扰动,对比两种方法的误差
3、使用iou label
比使用centerness label
更加稳定
4、可视化图片
image.png
五、对应yolo系列中的改动
yolo系列中obj loss是正负样本参与训练,cls和box只有正样本参与训练,在推理时会取obj score和cls score的乘积判断是否是正例。因此需要将obj 和 cls进行合并,并且乘以iou,计算QFL损失;box改成DFL loss。