CenterNet(一)论文解读

2019-11-06  本文已影响0人  blackmanba_084b

2019年最火的目标检测模型就是CenterNet,其实它是基于CenterNet的基础上进行改进。在看CenterNet之前自己已经将CornerNet代码也梳理了一遍,对于立即CenterNet也是有很大的帮助的。下面我们通过论文并且结合代码来详细理解一下这个模型。(这里需要强调一点关于CenterNet有两篇论文,但是还是推荐这篇CenterNet :Objects as Points文章)

论文地址:CenterNet :Objects as Points
代码地址地址:Github CenterNet


一、 论文解读

它的基本思想就是找到目标的中心点,再基于目标中心点的基础上进行boudingbox回归.

首先我们展示一下CenterNet效果。 CenterNet效果图
该模型可以进行bbox,3D框预测,姿态,等多种任务。这里我们主要介绍的是回归框的预测。

下面我们也来展示一下其模型结构


image.png
1. 步骤

预测主要分为两个步骤

步骤 图示
1. 预测目标中心点heatmap Gussian heatmap.png
2. 基于中心点预测框的高以及宽 预测框高以及宽
2. Loss函数

Loss = Loss_{k} + Loss_{wh} + Loss_{off}

a. Loss_{k}目标中心点损失
heat_map loss.png

在研究loss之前, 我们首先来看一下heat_map是如何表示



因此我们的heatmap矩阵为


这里的W,H分别表示原图的宽与高,这里我们R设置为输出对应原图的步长。因为在特征提取模型中我们会缩放4倍,即最终生成的feature为原图的1/4。第三个通道C表示的是目标的种类数目。这里我们首先说明一下每个这里每个字符所表示的意义。\color{blue}{\hat{Y}_{xyc}}表示feature map在某个通道上上每个点预测的值[0~1]。 \color{blue}{Y_{xyz}} 为feature map某一通道上坐标点为((x, y))的真实值,要么为1(表示当前(x,y)坐标点检测到c这个类别的物体), 要么为0(表示当前这个(x,y)坐标点不存在类别为c的物体)。\color{blue}{\widetilde{p}=[\frac{p}{R}]}这里的p为我们要预测目标中心点坐标(p = (\frac{x_1 + x_2}{2}, \frac{y_1 + y_2}{2})), \color{blue}{R}表示下采样因子,这里取决于获取feature map下采样比例, 这里默认为4。接着我们会以这个点为中心点并已radius画出对应点的高斯函数\color{blue}{exp(-\frac{(x-\widetilde{p}_x)^2+(y-\widetilde{p}_y^2)^2}{2\sigma^2_p})},这里radius也是选用了ConerNet的做法进行计算的,后面会详细介绍。如果两个高斯函数出现重合的情况我们对于每个高斯函数我们选取对应重合部分最大值作为高斯的高斯函数值。最后画出的高斯函数如上图Gussian heatmap.png所示。

在说明heat_map_loss 之前让我们先了解一下什么是Focal Loss吧。Focal主要是为了解决不平衡的问题,其实很多情况下会采用OHEM(online hard example mining)方法来解决这个问题,不了解的可以参考我写的这篇文章服饰关键点论文及代码解读(keras)。OHEM 是仅将损失较大的部分反向传播,直接忽略简单样本的损失,这种直接忽略肯定也会带来一定的影响,所以 Focal Loss 将简单样本的损失降低,而不是直接忽略,可以得到更好的结果。我们知道对于二分类的交叉熵损失函数为
\color{blue}{L\_{CE}(p,y) =-ylogp - (1-y)log(1-p) = \left\{ \begin{aligned} -log(p) && 如果 y =1 \\ -log(1-p) && 如果 y = 0\\ \end{aligned} \right. }
这里\color{blue}{p = \frac{1}{1+e^{-\theta x}}}

Focal Loss相当于在此基础上加上权重系数。
\color{blue}{FL(p_t) = \left\{ \begin{aligned} -(1-p_t)^\gamma log(p_t) && 如果 y =1 \\ -(p_t)^\gamma log(1-p_t) && 如果y= 0 \\ \end{aligned} \right. }



这里的heat_map_loss可以认为是中心点预测的损失函数, 下面我们来具体说一下。
\color{blue}{L\_k = \frac{-1}{N}\sum_{xyc}\left\{ \begin{aligned} x & =(1-\hat{Y_{xyc}}) ^\alpha log(\hat{Y}_{xyc}) && 如果 Y_{xyc} = 1 \\ y & = (\hat{Y}_{xyc})^\alpha(1 - Y_{xyc})^\beta log(1-\hat{Y}_{xyc}) && otherwise\\ \end{aligned} \right. }
对于这个公式大家看起来是不是很熟悉, 其实就是Focal Loss的改进版本。

公式中\alpha以及\beta就是这个Loss函数的超参数,论文中作者分别设置为2和4。N表示每张图像I关键点数量,用于归一化。
这里作者设计的focal非常巧妙,我们来好好分析这个focal loss。这里我觉得 扔掉anchor!真正的CenterNet——Objects as Points论文解读这篇作者说的非常好,我用我的理解解释一下。

\color{red}{所以通过上面的描述可以得出两个结论:}

  1. (1-\hat{Y_{xyc}}) ^\alpha以及(\hat{Y}_{xyc})^\alpha用来限制easy example导致的gradient被easy example dominant的问题。
  2. (1 - Y_{xyc}) ^\beta则用来处理正负样本的不平衡问题(因为每一个物体只有一个实际中心点,其余的都是负样本,但是负样本相较于一个中心点显得有很多$

btw, 这里顺便提一下(1 - Y_{xyc}) ^\beta这里其实就是faster rcnn(Faster-RCNN的原理及演变)中的做正负比例平衡, 即训练过程中使positive和negative的box比例为1:3来减少negative box的比例。

b. Loss_{wh}目标大小损失

上面的loss可以帮助我们的目标中心点的定位,那么剩下的就应该确定我们的目标的大小了。\color{blue}{L_{size}= \sum^N_{k=1}|\hat{S}_{pk} - s_k|}
假设我们对于目标k种类为c_k来说它的bounding box 为(x_1^{(k)},y_1^{(k)},x_2^{(k)},y_2^{(k)}),并且通过这些点坐标我们可以得出其中心点坐标为p_k = (\frac{x_1^{(k)} + x_2^{(k)}}{2}, \frac{y_1^{(k)} + y_2^{(k)}}{2}), \hat{S}_{pk}为我们预测的值, 为了减少计算麻烦,这里我们设置\hat{S}\in R^{\frac{W}{R}\times \frac{F}{R}\times 2}对于每一个目标k size 进行回归s_k = (x_2^{(k)} - x_1^{(k)}, y_2^{(k)} -y_q^{(k)}), 这里我们使用L1损失函数。

c. Loss_{off}目标中心点偏置损失

最后我们来说一下Loss_{off}
\color{blue}{L_{off} = \frac{1}{N}\sum_p |\hat{O}_p - (\frac{p}{R}-\widetilde{p})|}
这个式子中\hat{O}_p是我们预测的偏置(也就是说我们预测出会偏移多少误差), 而(\frac{p}{R}-p)则是在训练过程中提前计算的数值
因为我们在一开始有提到过我们在进行R=4(即缩放4倍)的下采再将feature map映射到原图取整的时候会有精度上面的损失,所以这里添加一个local offset 去补充。同样这个偏置项也是用L1损失函数。举一个简单的栗子,比如说我们最终在大小为[128, 128]的feature map上面预测出的中心点为[12.34231, 77.32123], 我们会按照[12, 77]映射到原来的图片大小为[512, 512]的原图上这样会有小数点之间损失,我们需要把这些损失也需要考虑进去。

最终我们将这些loss进行汇总\color{blue}{L_{det}=L_k + \lambda_{size}L_{size} + \lambda_{off}L_{off}}
这里我们设置\lambda_{size}=0.1,\lambda_{off}=1

3. 预测阶段

上面已经详细介绍了模型以及Loss的组成, 下面我们来好好介绍下我们如何预测各种不同的结果。利用我们得出的目标中心点以及目标框大小等信息我们可以预测出对应目标的bounding box, 3D detection, 以及我们所需要的keypoints以及Human pose。

a. bounding box

这里需要提及一点的是我们对于每一个类会提取对应的feature map上的预测中心点,但是我们如何正确提取这里面的中心点的?首先我们类似做一次3\times3的Max Polling选取某个点的值比周围8个点值都大或等于的点,收集好这些点之后我们再进行一次值排序选取前100个采样点(即n=100)。(\color{red}{这里的效果可以类似于Faster RCNN中的NMS的效果, 即根据IOU进行选取候选框})
首先我们设\hat{P_c}为我们检测到某个类别的中心点,所以\hat{P} = \{(\hat{x}_i, \hat{y}_i)\}^n_{i=1}这里的n我们默认设置为100。
这里bouding box的坐标为:
(\hat{x_i} + \delta\hat{x}_i - \hat{w}_i/2, \hat{y}_i + \delta\hat{y}_i- \hat{h}_i/2,\hat{x}_i + \delta\hat{x}_i + \hat{w}_i/2, \hat{y}_i+\delta\hat{y}_i+\hat{h}_i/2)
我们设\hat{Y}_{x_i,y_i,c}作为检测目标的置信度。这里(\delta\hat{x}_i, \delta\hat{y}_i)为当前点的偏置点, (\delta\hat{x}_i, \delta\hat{y}_i)=\hat{O}\hat{x}_i\hat{y}_i为当前点对原始点的偏置点, (\hat{w}_i, \hat{h}_i)= \hat{S}\hat{x}_i\hat{y}_i代表预测出当前点对应目标的长于宽。

b. 3D detection(后续会详细说明)

3D检测其实就是来预测每张图三个维度的bounding box。这里每个中心点需要添加三个附加信息分别是depth, 3D dimension, 以及orientation

因为depth很难回归,一我们使用这个公式去计算depth,即d = 1 /\delta{(\hat{d})-1}, 这里\delta表示的是sigmoid函数。我们设置depth作为关键点评估器的另一中输出通道\hat{D}\in[0,1]^{\frac{W}{R}\times\frac{H}{R}}这里分别两个卷积层,并且卷积层之间使用ReLU函数一次。我们直接使用一个独立的head.\Gamma\in[0,1]^{\frac{W}{R}\times\frac{H}{R}\times3}方向默认是单标量的值,然而其也很难回归, 用两个bins来呈现方向,且in-bin回归。特别地,方向用8个标量值来编码的形式,每个bin有4个值。对于一个bin,两个值用作softmax分类,其余两个值回归到在每个bin中的角度。

c. Human pose estimation(后续会详细说明)

未完待续。。。
更多代码的理解可以关注我后面写的CenterNet(二) 源码解读

关于如何制作自己的训练集以及训练可以参考这篇博客(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

参考

  1. 扔掉anchor!真正的CenterNet——Objects as Points论文解读
  2. Focal Loss
  3. 超越yolov3,Centernet 原理详解(object as points)
  4. (绝对详细)CenterNet训练自己的数据(pytorch0.4.1)
上一篇下一篇

猜你喜欢

热点阅读