mmdetection源码阅读笔记(0)--创建模型
之前做天池比赛用mmdetection取得了还不错的成绩,就想仔细读读mmdetection的源码,了解下具体实现。
这个系列,准备按照目标检测和实例分割的pipeline来写。
训练脚本
官方提供了分布式训练,并且推荐使用分布式训练,即使在单机器上dist_train.sh
。
#!/usr/bin/env bash
PYTHON=${PYTHON:-"python3"}
$PYTHON -m torch.distributed.launch --nproc_per_node=$2 $(dirname "$0")/train.py $1 --launcher pytorch ${@:3}
该脚本主要使用了torch.distributed.launch
辅助启动工具,这个工具可以辅助在每个节点上启动多个进程process,支持Python2 和 Python3.
更多关于分布式训练的细节可以参考pytorch 分布式训练 distributed parallel 笔记
创建模型
train.py
的main()
函数,先做了一些config文件,work_dir以及log的操作,之后调用了build_detector()
来创建模型。
build_detector()
build_detector()
定义在mmdet/models/builder.py
中。
下面是主要用到的几个函数。
mmdet/models/builder.py
from .registry import BACKBONES, NECKS, ROI_EXTRACTORS, HEADS, DETECTORS
def build_detector(cfg, train_cfg=None, test_cfg=None):
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
在build_detector()
中有一个DETECTORS
这是一个注册器,里面保存了所有支持的detector。具体的实现方式和Python装饰器有点像。
下面以cascade_rcnn
为例,看下是怎么进行注册过来的。
- 首先在
mmdet/models/__init__.py
里面from .detectors import *
- 在
mmdet/models/detectors/__init__.py
里面from .cascade_rcnn import CascadeRCNN
- 在
mmdet/models/detectors/cascade_rcnn.py
中
from ..registry import DETECTORS
@DETECTORS.register_module
class CascadeRCNN(BaseDetector, RPNTestMixin):
other codes
用@DETECTORS.register_module
这一行代码,将CascadeRCNN
注册到了DETECTORS
中。
这里简单的说下@
的用法,Python当解释器读到@
的这样的修饰符之后,会先解析@
后的内容,直接就把@
下一行的函数或者类作为@
后边的函数的参数,然后将返回值赋值给下一行修饰的函数对象。
例如:
def a():
print("func a")
def b():
print("func b")
@a
@b
def c():
print("func c")
python会按照自下而上的顺序把各自的函数结果作为下一个函数(上面的函数)的输入,也就是a(b(c()))
回到我们的DETECTORS
,也就是上面的操作将CascadeRCNN
传给了DETECTORS.register_module
mmdet/models/registry.py
class Registry(object):
def __init__(self, name):
self._name = name
self._module_dict = dict()
def _register_module(self, module_class):
"""Register a module.
Args:
module (:obj:`nn.Module`): Module to be registered.
"""
if not issubclass(module_class, nn.Module):
raise TypeError(
'module must be a child of nn.Module, but got {}'.format(
module_class))
module_name = module_class.__name__
if module_name in self._module_dict:
raise KeyError('{} is already registered in {}'.format(
module_name, self.name))
self._module_dict[module_name] = module_class
def register_module(self, cls):
self._register_module(cls)
return cls
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
HEADS = Registry('head')
DETECTORS = Registry('detector')
注册的模型被保存到了,self._module_dict
中。
再回到builder.py
mmdet/models/builder.py
def build(cfg, registry, default_args=None):
if isinstance(cfg, list):
modules = [_build_module(cfg_, registry, default_args) for cfg_ in cfg]
return nn.Sequential(*modules)
else:
return _build_module(cfg, registry, default_args)
def _build_module(cfg, registry, default_args):
assert isinstance(cfg, dict) and 'type' in cfg
assert isinstance(default_args, dict) or default_args is None
args = cfg.copy()
obj_type = args.pop('type')
if mmcv.is_str(obj_type):
if obj_type not in registry.module_dict:
raise KeyError('{} is not in the {} registry'.format(
obj_type, registry.name))
obj_type = registry.module_dict[obj_type]
elif not isinstance(obj_type, type):
raise TypeError('type must be a str or valid type, but got {}'.format(
type(obj_type)))
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
return obj_type(**args)
build()
中主要通过_build_module()
从registry.module_dict
中实例化注册过的模型。
最后
这篇主要讲了mmdetection中的创建模型,下一篇准备以Cascade Rcnn
为例看下网络的具体搭建。