pytorch中的顺序容器——torch.nn.Sequenti
2020-02-10 本文已影响0人
yuanCruise
1.torch.nn.Sequential概要
pytorch官网对torch.nn.Sequential的描述如下。
使用方式:
# 写法一
net = nn.Sequential(
nn.Linear(num_inputs, 1)
# 此处还可以传入其他层
)
# 写法二
net = nn.Sequential()
net.add_module('linear', nn.Linear(num_inputs, 1))
# net.add_module ......
# 写法三
from collections import OrderedDict
net = nn.Sequential(OrderedDict([
('linear', nn.Linear(num_inputs, 1))
# ......
]))
方式一:
这是一个有顺序的容器,将特定神经网络模块按照在传入构造器的顺序依次被添加到计算图中执行。
方式二:
也可以将以特定神经网络模块为元素的有序字典(OrderedDict)为参数传入。
方式三:
也可以利用add_module函数将特定的神经网络模块插入到计算图中。add_module函数是神经网络模块的基础类(torch.nn.Module)中的方法,如下描述所示用于将子模块添加到现有模块中。
2.Sequential源码分析
先看一下初始化函数init,在初始化函数中,首先是if条件判断,如果传入的参数为1个,并且类型为OrderedDict,通过字典索引的方式利用add_module函数将子模块添加到现有模块中,否则,通过for循环遍历参数,将所有的子模块添加到现有中。
由于每一个神经网络模块都继承于nn.Module,因此都会实现__call__
与forward
函数,所以forward函数中通过for循环依次调用添加到现有模块中的子模块,最后输出经过所有神经网络层的结果。
参考文献:
https://blog.csdn.net/dss_dssssd/article/details/82980222