Faster-rcnn源码解析5

2018-05-04  本文已影响238人  haoshengup

训练fast rcnn:Stage 1 Fast R-CNN using RPN proposals, init from ImageNet model

首先设置参数:

rpn_file:上一步中(generate proposals)产生的结果,然后创建进程。来看一下进程中的函数train_fast_rcnn。

先是设置一些基本的参数,然后通过get_roidb函数获得roidb数据和imdb类。关于get_roidb函数前面已经详细讲解过。这里再来看一下,因为参数cfg.TRAIN.PROPOSAL_METHOD ='rpn'发生了变化,所以产生的结果也不同。

注意这个语句:imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD),由于参数改变了,所以执行的结果变成了:rpn_roidb(self)函数。进入rpn_roidb(self)函数来看一下:

首先利用gt_roidb函数得到gt_roidb,我们前面仔细分析过这个函数,这里不再多说了,得到的结果是一个列表,列表中的每个元素是一个字典,每个元素对应于数据库中的一张图片。

然后,利用_load_rpn_roidb函数,得到rpn_roidb:

_load_rpn_roidb函数内容很简单,首先读取保存的rpn_file文件来获取每张图片的proposals,然后将其保存为一个列表:box_list。然后利用create_roidb_from_box_list函数来返回rpn_roidb。

下面,看一下create_roidb_from_box_list函数:

我们先来看一下这个函数的返回结果roidb,是一个列表,这个列表中的元素对应于读取的图片中的每一张图片,当然和box_list中的每一个元素也是对应的。再来看列表中的元素,我们发现列表中的每一个元素都是字典,而且字典的key和gt_roidb列表中的元素的key是一样的。下面,我们就来看一下,这些key的value值是如何得到的,和gt_roidb中的有何区别。

首先检查box_list中的图片数量和读取的图片数量是否一致。然后,对图片做遍历,开始生成roidb列表中的字典元素。

首先,读取这张图片的proposals:boxes = box_list[i],这是一个num_proposals行4列的矩阵

初始化overlaps矩阵:overlaps = np.zeros((num_boxes,self.num_classes),dtype=np.float32)

接下来读取gt_roidb中元素的一些信息,并得到gt_overlaps:

gt_overlaps = bbox_overlaps(boxes.astype(np.float), gt_boxes.astype(np.float))

gt_overlaps 是计算proposals和gt_box之间的IOU,因此是一个num_proposals * num_gt_box 的矩阵

下面根据gt_overlaps 来得到overlaps:

overlaps[I, gt_classes[argmaxes[I]]] = maxes[I]

overlaps矩阵的含义是:每个proposal对应的物体,在overlaps矩阵中,每行只有一个非0元素,即:这行代表的proposal对应的物体号所在的列。而且,取值为:这个proposal和对应的gt_box的overlap。

最后,把对应的元素放在字典中,并添加到roidb列表返回。这样,我们就得到了:rpn_roidb。

回到rpn_roidb函数中,还需要把得到的rpn_roidb做进一步的处理,用到的函数是:imdb.merge_roidbs函数。

来看一下imdb.merge_roidbs函数:

在这个函数中,做的事情很简单,就是把gt_roidb列表和rpn_roidb列表中的每个元素字典的value进行了合并(这个两个字典元素的key是相同的)。换句话说,就是把gt_roidb和rpn_roidb中的数据合并在了一起,当然,列表的长度没变。最后,把合并得到的结果roidb返回。

回到get_roidb函数中,还需要对得到的roidb进行一个get_training_roidb操作,这个很熟悉了,就是对原来的图片进行翻转,然后get_roidb的长度加倍。之后在进行一个prepare_roidb操作,给列表中的字典元素增加两个key:max_classes、max_overlaps。OK,这些修饰工作完成之后,就得到了最终的roidb。

好了,回到train_fast_rcnn函数中,我们现在得到了数据:roidb,下面利用这个roidb数据建立fast_rcnn网络并进行训练:model_paths = train_net(solver, roidb, output_dir,pretrained_model=init_model,max_iters=max_iters)。

这里的solver为:stage1_fast_rcnn_solver30k40k.pt, max_iters:40000。好了,有了参数,开始进入train_net函数,和train_rpn时候一样,先过滤数据:roidb = filter_roidb(roidb),然后初始化类:

sw = SolverWrapper(solver_prototxt, roidb, output_dir, pretrained_model=pretrained_model)。

进入SolverWrapper类,在初始化函数中:

由于__C.TRAIN.BBOX_REG =True,所以需要执行if函数中内容,其实只有一个语句:

self.bbox_means,self.bbox_stds = rdl_roidb.add_bbox_regression_targets(roidb)。其实就是给roidb列表中的每个字典元素增加了一个key:bbox_targets,然后返回标准化的均值和方差。

下面来一下add_bbox_regression_targets函数:

前面的只是一些数据的读取工作,添加key:bbox_targets,是由_compute_targets函数来完成的,进入_compute_targets函数:

_compute_targets函数的流程大概是这样的:在rois = roidb[im_i]['boxes'] 矩阵中,找出属于前景的proposals,得到它们的坐标:ex_rois = rois[ex_inds, :] ; 找出ex_rois 所对应的gt_box的坐标:gt_rois = rois[gt_inds[gt_assignment], :], 然后根据gt_box计算每个proposals所需要回归的target:

targets[ex_inds,1:] = bbox_transform(ex_rois, gt_rois)。关于targets矩阵,是这么定义的:targets = np.zeros((rois.shape[0],5),dtype=np.float32),行数和传入的proposals的行数是一致的,列数:5。其中,第1列的值为:每行的proposals所对应的物体号,targets[ex_inds,0] = labels[ex_inds]。这个函数,最后返回targets。

总结一下targets,是一个rois.shape[0]行5列的矩阵,第1列是proposal对应的物体号,后4列是proposal需要回归的目标。注意一下,只有符合一定条件的前景才有回归目标,不符合条件的后4列的元素为0,也就是不需要回归。

好了,下面返回add_bbox_regression_targets函数,现在已经给每个字典元素增加了一个key:bbox_targets。这个函数后面的内容是:对刚才添加的key:bbox_targets所对应的value值按照物体号进行标准化,并返回标准化的均值means(21行4列的矩阵)和标准差stds(21行4列的矩阵)。这里有一点注意一下,返回的means和stds是经过平铺之后返回的:return means.ravel(), stds.ravel()。

好了,回到SolverWrapper类的初始化函数中,还有一个步骤需要操作:self.solver.net.layers[0].set_roidb(roidb)。这个函数我们前面已经用过一次,主要实现两个功能:第一,把roidb传入到网络中(self._roidb = roidb),第二,随机打乱self._roidb列表中的元素。

然后再一次回到train_net函数,开始训练网络并返回网络保存的文件列表:model_paths = sw.train_model(max_iters)。        在train_model这个函数中,我们知道:训练10000次保存一次网络文件,因此model_paths是一个文件列表,列表中的每一个元素都是网络文件的绝对路径。

还有一个需要注意的是snapshot函数,这一次和前面使用的时候有点不同,这一次在执行的时候,由于if的取值为True,因此需要执行if中的语句:

其实,也就是对最终预测的bbox_pred进行反标准化,因为我们的回归目标targets = roidb[im_i]['bbox_targets']是经过标准化的。

最后,来看一下fast rcnn的网络结构,主要是:stage1_fast_rcnn_train.pt文件。

上一篇下一篇

猜你喜欢

热点阅读