RetinaNet源码分析(1):anchor

2019-08-29  本文已影响0人  wctgogo

代码是RetinaNet的pytorch版本,链接为
GitHub - yhenon/pytorch-retinanet: Pytorch implementation of RetinaNet object detection.

class Anchors()

__init__:

self.pyramid_levels=[3,4,5,6,7]  # feature map的标号,分辨率从大到小
self.strides=[8,16,32,64,128]  # 滑窗的步长
self.sizes=[32,64,128,256,512]  # anchor面积:32*32等
self.ratios=[0.5,1,2]  # anchor的长宽比
self.scales=[2^0 ,2^\frac{1}{3}  ,2^\frac{2}{3}  ]  # 面积增比

forward(input_image):

input_image是原图以(608,1024)为标准等比例缩放得到的,我的原始image均为1080*1920,resize之后变为608*1056,参考代码dataloader/Resizer()。

以下为anchor生成的两个函数:

1. anchors = generate_anchors(sizes[i], ratios, scales)
    输出anchors.shape=(9,4),对应面积为sizes[i]的9个不同ratios,scales的anchor坐标
    每行为一个anchor:(x1,y1,x2,y2),中心坐标(x_c,y_c)=(0,0),故x1,y1为负,x2,y2为正

2. shifted_anchors = shift(shapes[i], strides[i], anchors)
    shapes为对应feature map的尺寸
    输出shifted_anchors.shape=(shapes[0]*shapes[1]*9, 4)
    anchor是在input_image上以strides为步长滑动生成的坐标(x2, y1, x2, y2)

后续还包括一些增加fake dimension,转化tensor,加载cuda的操作

上一篇下一篇

猜你喜欢

热点阅读