3D-R2N2代码分析

2017-07-20  本文已影响0人  ClarenceHoo

      文章地址:https://chrischoy.github.io/publication/r2n2/

      代码地址:https://github.com/chrischoy/3d-r2n2/

      整体代码基于theano框架,github上有详细的安装过程,所有代码均在virtualenv中实现,代码对应python版本应在3.5及其以上

1、该代码的主程序使用demo为直接调用已有的模型参数,进行demo的重建,

2、训练程序主要通过sh工程文件中的参数来运行

3、solver来运行程序

      代码主程序有demo.py和main.py两个部分,其中demo.py是为了直接运行已有权重,main.py则是加入了训练的部分。网络结构结构放在model中,res_gru_net为论文主要模型。其余主要库文件在lib文件夹下。本代码中的网络结构多数为自写,少量使用了theano部分代码,主要使用函数为theano.function和theano.scan。

       lib文件夹下,config.py存放全局所有调用的参变量。lib中的layer编写了每个层面的各个操作以及最优化所用的方法。

       demo流程:1、记录保存的文件名,2、load图片 3、下载模型4、netclass 通过load——model调用,load_model的返回值为两个model的object,netclass为models中的ResidualGRUNet,5、net调用ResidualGRUNet所继承的net中的load函数,通过函数enumerate(http://blog.csdn.net/churximi/article/details/51648388),读取路径文件的各个数据的索引 权重,按顺序依次存放在net.params中,6、这时调用theano.function,以net.x=demo_imgs,net.y=none为输入,以net.output net.loss为输出,其中函数调用为gru_net中的net.output和net.loss

       如何修改程序:

1、调用部分网络参数部分依然可以沿用load函数,返回的是list,而在额外初始化的部分可以继续append添加添加层的参数,

main流程:

1、附加参数初始化,

2、train_net()

3、train_net()中的main()

4、train_net()中的train_net(),读取网络类型参数值为cfg.CONST.NETWORK_CLASS,通过函数make_data_processes读取测试数据,调用solver中的train函数

上一篇下一篇

猜你喜欢

热点阅读