
6 Module - Dissecting PyTorch

2018-10-24  readilen

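Module holds the state and functions shared by all module classes

Modules in PyTorch are very easy to use: derive from Module and override two functions (typically __init__ and forward). So what does Module itself do?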

class Module(object):
    def __init__(self):
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
        self.training = True

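The constructor creates a set of ordered dictionaries that hold the module's state; we come back to them below. First, self._backend: it refers to a global THNNFunctionBackend() instance that stores a table of function references. THNNFunctionBackend derives from FunctionBackend: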

class FunctionBackend(object):
    def __init__(self):
        self.function_classes = {}
    def register_function(self, name, function_class):
        self.function_classes[name] = function_class

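In this class, function_classes is a dictionary whose keys are names and whose values are functions or Function classes; entries are added with register_function. Once registration is complete there are about 118 entries (PyTorch 0.4.1 is the version used in this article):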

RNN                                      <function RNN at 0x7f4330534378>
RNNTanhCell                              <function RNNTanhCell at 0x7f4330530d90>
RNNReLUCell                              <function RNNReLUCell at 0x7f43305309d8>
LSTMCell                                 <function LSTMCell at 0x7f4330530e18>
GRUCell                                  <function GRUCell at 0x7f4330530ea0>
Dropout                                  <class 'torch.nn._functions.dropout.Dropout'>
Dropout2d                                <class 'torch.nn._functions.dropout.FeatureDropout'>
Dropout3d                                <class 'torch.nn._functions.dropout.FeatureDropout'>
MarginCriterion                          <class 'torch.nn._functions.thnn.auto.MarginCriterion'>
MarginCriterionBackward                  <class 'torch.nn._functions.thnn.auto.MarginCriterionBackward'>
GatedLinear                              <class 'torch.nn._functions.thnn.auto.GatedLinear'>
GatedLinearBackward                      <class 'torch.nn._functions.thnn.auto.GatedLinearBackward'>
SpatialFullConvolutionMap                <class 'torch.nn._functions.thnn.auto.SpatialFullConvolutionMap'>
SpatialFullConvolutionMapBackward        <class 'torch.nn._functions.thnn.auto.SpatialFullConvolutionMapBackward'>
VolumetricFractionalMaxPooling           <class 'torch.nn._functions.thnn.auto.VolumetricFractionalMaxPooling'>
VolumetricFractionalMaxPoolingBackward   <class 'torch.nn._functions.thnn.auto.VolumetricFractionalMaxPoolingBackward'>
VolumetricFullDilatedConvolution         <class 'torch.nn._functions.thnn.auto.VolumetricFullDilatedConvolution'>
VolumetricFullDilatedConvolutionBackward <class 'torch.nn._functions.thnn.auto.VolumetricFullDilatedConvolutionBackward'>
Col2Im                                   <class 'torch.nn._functions.thnn.auto.Col2Im'>
Col2ImBackward                           <class 'torch.nn._functions.thnn.auto.Col2ImBackward'>
DilatedConv2d                            <class 'torch.nn._functions.thnn.auto.DilatedConv2d'>
DilatedConv2dBackward                    <class 'torch.nn._functions.thnn.auto.DilatedConv2dBackward'>
SpatialConvolutionLocal                  <class 'torch.nn._functions.thnn.auto.SpatialConvolutionLocal'>
SpatialConvolutionLocalBackward          <class 'torch.nn._functions.thnn.auto.SpatialConvolutionLocalBackward'>
FeatureLPPooling                         <class 'torch.nn._functions.thnn.auto.FeatureLPPooling'>
FeatureLPPoolingBackward                 <class 'torch.nn._functions.thnn.auto.FeatureLPPoolingBackward'>
VolumetricGridSamplerBilinear            <class 'torch.nn._functions.thnn.auto.VolumetricGridSamplerBilinear'>
VolumetricGridSamplerBilinearBackward    <class 'torch.nn._functions.thnn.auto.VolumetricGridSamplerBilinearBackward'>
TemporalUpSamplingNearest                <class 'torch.nn._functions.thnn.auto.TemporalUpSamplingNearest'>
TemporalUpSamplingNearestBackward        <class 'torch.nn._functions.thnn.auto.TemporalUpSamplingNearestBackward'>
SpatialUpSamplingNearest                 <class 'torch.nn._functions.thnn.auto.SpatialUpSamplingNearest'>
SpatialUpSamplingNearestBackward         <class 'torch.nn._functions.thnn.auto.SpatialUpSamplingNearestBackward'>
ReflectionPad1d                          <class 'torch.nn._functions.thnn.auto.ReflectionPad1d'>
ReflectionPad1dBackward                  <class 'torch.nn._functions.thnn.auto.ReflectionPad1dBackward'>
SpatialConvolutionMap                    <class 'torch.nn._functions.thnn.auto.SpatialConvolutionMap'>
SpatialConvolutionMapBackward            <class 'torch.nn._functions.thnn.auto.SpatialConvolutionMapBackward'>
NLLLoss                                  <class 'torch.nn._functions.thnn.auto.NLLLoss'>
NLLLossBackward                          <class 'torch.nn._functions.thnn.auto.NLLLossBackward'>
Softplus                                 <class 'torch.nn._functions.thnn.auto.Softplus'>
SoftplusBackward                         <class 'torch.nn._functions.thnn.auto.SoftplusBackward'>
LogSigmoid                               <class 'torch.nn._functions.thnn.auto.LogSigmoid'>
LogSigmoidBackward                       <class 'torch.nn._functions.thnn.auto.LogSigmoidBackward'>
SpatialUpSamplingBilinear                <class 'torch.nn._functions.thnn.auto.SpatialUpSamplingBilinear'>
SpatialUpSamplingBilinearBackward        <class 'torch.nn._functions.thnn.auto.SpatialUpSamplingBilinearBackward'>
ReplicationPad3d                         <class 'torch.nn._functions.thnn.auto.ReplicationPad3d'>
ReplicationPad3dBackward                 <class 'torch.nn._functions.thnn.auto.ReplicationPad3dBackward'>
MultiMarginLoss                          <class 'torch.nn._functions.thnn.auto.MultiMarginLoss'>
MultiMarginLossBackward                  <class 'torch.nn._functions.thnn.auto.MultiMarginLossBackward'>
ReplicationPad1d                         <class 'torch.nn._functions.thnn.auto.ReplicationPad1d'>
ReplicationPad1dBackward                 <class 'torch.nn._functions.thnn.auto.ReplicationPad1dBackward'>
MultiLabelMarginLoss                     <class 'torch.nn._functions.thnn.auto.MultiLabelMarginLoss'>
MultiLabelMarginLossBackward             <class 'torch.nn._functions.thnn.auto.MultiLabelMarginLossBackward'>
SpatialFullDilatedConvolution            <class 'torch.nn._functions.thnn.auto.SpatialFullDilatedConvolution'>
SpatialFullDilatedConvolutionBackward    <class 'torch.nn._functions.thnn.auto.SpatialFullDilatedConvolutionBackward'>
SoftMarginLoss                           <class 'torch.nn._functions.thnn.auto.SoftMarginLoss'>
SoftMarginLossBackward                   <class 'torch.nn._functions.thnn.auto.SoftMarginLossBackward'>
NLLLoss2d                                <class 'torch.nn._functions.thnn.auto.NLLLoss2d'>
NLLLoss2dBackward                        <class 'torch.nn._functions.thnn.auto.NLLLoss2dBackward'>
MSELoss                                  <class 'torch.nn._functions.thnn.auto.MSELoss'>
MSELossBackward                          <class 'torch.nn._functions.thnn.auto.MSELossBackward'>
Sigmoid                                  <class 'torch.nn._functions.thnn.auto.Sigmoid'>
SigmoidBackward                          <class 'torch.nn._functions.thnn.auto.SigmoidBackward'>
VolumetricUpSamplingTrilinear            <class 'torch.nn._functions.thnn.auto.VolumetricUpSamplingTrilinear'>
VolumetricUpSamplingTrilinearBackward    <class 'torch.nn._functions.thnn.auto.VolumetricUpSamplingTrilinearBackward'>
BCELoss                                  <class 'torch.nn._functions.thnn.auto.BCELoss'>
BCELossBackward                          <class 'torch.nn._functions.thnn.auto.BCELossBackward'>
Square                                   <class 'torch.nn._functions.thnn.auto.Square'>
SquareBackward                           <class 'torch.nn._functions.thnn.auto.SquareBackward'>
ReplicationPad2d                         <class 'torch.nn._functions.thnn.auto.ReplicationPad2d'>
ReplicationPad2dBackward                 <class 'torch.nn._functions.thnn.auto.ReplicationPad2dBackward'>
L1Loss                                   <class 'torch.nn._functions.thnn.auto.L1Loss'>
L1LossBackward                           <class 'torch.nn._functions.thnn.auto.L1LossBackward'>
SpatialGridSamplerBilinear               <class 'torch.nn._functions.thnn.auto.SpatialGridSamplerBilinear'>
SpatialGridSamplerBilinearBackward       <class 'torch.nn._functions.thnn.auto.SpatialGridSamplerBilinearBackward'>
Sqrt                                     <class 'torch.nn._functions.thnn.auto.Sqrt'>
SqrtBackward                             <class 'torch.nn._functions.thnn.auto.SqrtBackward'>
TemporalRowConvolution                   <class 'torch.nn._functions.thnn.auto.TemporalRowConvolution'>
TemporalRowConvolutionBackward           <class 'torch.nn._functions.thnn.auto.TemporalRowConvolutionBackward'>
SpatialFractionalMaxPooling              <class 'torch.nn._functions.thnn.auto.SpatialFractionalMaxPooling'>
SpatialFractionalMaxPoolingBackward      <class 'torch.nn._functions.thnn.auto.SpatialFractionalMaxPoolingBackward'>
TemporalUpSamplingLinear                 <class 'torch.nn._functions.thnn.auto.TemporalUpSamplingLinear'>
TemporalUpSamplingLinearBackward         <class 'torch.nn._functions.thnn.auto.TemporalUpSamplingLinearBackward'>
VolumetricDilatedMaxPooling              <class 'torch.nn._functions.thnn.auto.VolumetricDilatedMaxPooling'>
VolumetricDilatedMaxPoolingBackward      <class 'torch.nn._functions.thnn.auto.VolumetricDilatedMaxPoolingBackward'>
Threshold                                <class 'torch.nn._functions.thnn.auto.Threshold'>
ThresholdBackward                        <class 'torch.nn._functions.thnn.auto.ThresholdBackward'>
Abs                                      <class 'torch.nn._functions.thnn.auto.Abs'>
AbsBackward                              <class 'torch.nn._functions.thnn.auto.AbsBackward'>
Softshrink                               <class 'torch.nn._functions.thnn.auto.Softshrink'>
SoftshrinkBackward                       <class 'torch.nn._functions.thnn.auto.SoftshrinkBackward'>
LeakyReLU                                <class 'torch.nn._functions.thnn.auto.LeakyReLU'>
LeakyReLUBackward                        <class 'torch.nn._functions.thnn.auto.LeakyReLUBackward'>
VolumetricUpSamplingNearest              <class 'torch.nn._functions.thnn.auto.VolumetricUpSamplingNearest'>
VolumetricUpSamplingNearestBackward      <class 'torch.nn._functions.thnn.auto.VolumetricUpSamplingNearestBackward'>
VolumetricDilatedConvolution             <class 'torch.nn._functions.thnn.auto.VolumetricDilatedConvolution'>
VolumetricDilatedConvolutionBackward     <class 'torch.nn._functions.thnn.auto.VolumetricDilatedConvolutionBackward'>
Tanh                                     <class 'torch.nn._functions.thnn.auto.Tanh'>
TanhBackward                             <class 'torch.nn._functions.thnn.auto.TanhBackward'>
TemporalSubSampling                      <class 'torch.nn._functions.thnn.auto.TemporalSubSampling'>
TemporalSubSamplingBackward              <class 'torch.nn._functions.thnn.auto.TemporalSubSamplingBackward'>
ELU                                      <class 'torch.nn._functions.thnn.auto.ELU'>
ELUBackward                              <class 'torch.nn._functions.thnn.auto.ELUBackward'>
Hardtanh                                 <class 'torch.nn._functions.thnn.auto.Hardtanh'>
HardtanhBackward                         <class 'torch.nn._functions.thnn.auto.HardtanhBackward'>
L1Cost                                   <class 'torch.nn._functions.thnn.auto.L1Cost'>
L1CostBackward                           <class 'torch.nn._functions.thnn.auto.L1CostBackward'>
SpatialSubSampling                       <class 'torch.nn._functions.thnn.auto.SpatialSubSampling'>
SpatialSubSamplingBackward               <class 'torch.nn._functions.thnn.auto.SpatialSubSamplingBackward'>
Im2Col                                   <class 'torch.nn._functions.thnn.auto.Im2Col'>
Im2ColBackward                           <class 'torch.nn._functions.thnn.auto.Im2ColBackward'>
KLDivLoss                                <class 'torch.nn._functions.thnn.auto.KLDivLoss'>
KLDivLossBackward                        <class 'torch.nn._functions.thnn.auto.KLDivLossBackward'>
SmoothL1Loss                             <class 'torch.nn._functions.thnn.auto.SmoothL1Loss'>
SmoothL1LossBackward                     <class 'torch.nn._functions.thnn.auto.SmoothL1LossBackward'>
ReflectionPad2d                          <class 'torch.nn._functions.thnn.auto.ReflectionPad2d'>
ReflectionPad2dBackward                  <class 'torch.nn._functions.thnn.auto.ReflectionPad2dBackward'>
CrossMapLRN2d                            <class 'torch.nn._functions.thnn.normalization.CrossMapLRN2d'>
EmbeddingBag                             <class 'torch.nn._functions.thnn.sparse.EmbeddingBag'>

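Without quite meaning to, this listing shows the backend functions behind essentially all of the predefined modules PyTorch supports. The implementation of these predefined modules is covered later.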
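The name-to-function registry described above can be sketched in a few lines. This is a simplified illustration, not PyTorch's own code: ToyBackend, toy_relu and ToySquare are hypothetical names, and the attribute-style lookup through __getattr__ is an assumption added here to show how a registered name could be reached as backend.<name>.

```python
class ToyBackend(object):
    """Hypothetical name -> callable registry modelled on FunctionBackend."""

    def __init__(self):
        self.function_classes = {}

    def register_function(self, name, function_class):
        self.function_classes[name] = function_class

    def __getattr__(self, name):
        # Reached only when normal attribute lookup fails, so registered
        # names can be used as backend.<name>.
        try:
            return self.__dict__["function_classes"][name]
        except KeyError:
            raise AttributeError(name)


def toy_relu(x):
    # Plain function entry, like the RNN / LSTMCell rows in the listing.
    return max(0.0, x)


class ToySquare(object):
    # Class entry, like the torch.nn._functions.thnn.auto.* rows.
    @staticmethod
    def apply(x):
        return x * x


backend = ToyBackend()
backend.register_function("ReLU", toy_relu)
backend.register_function("Square", ToySquare)

# Print the registry in the same two-column layout as the listing above.
for name, fn in backend.function_classes.items():
    print("{:<10}{!r}".format(name, fn))

relu_result = backend.ReLU(-3.0)
square_result = backend.Square.apply(4)
```

Keeping the table keyed by name lets callers pick an implementation without importing it directly, which is presumably why the registry holds both plain functions and classes side by side.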

The other ordered dictionaries

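        self._parameters = OrderedDict()         # the module's learnable parameters
        self._buffers = OrderedDict()            # persistent state that is not a parameter (e.g. running statistics)
        self._backward_hooks = OrderedDict()     # backward hook functions
        self._forward_hooks = OrderedDict()      # forward hook functions
        self._forward_pre_hooks = OrderedDict()  # hook functions run before forward is called
        self._modules = OrderedDict()            # child modules
        self.training = True                     # training mode (True) or evaluation mode (False)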
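To make the role of these dictionaries concrete, here is a minimal pure-Python sketch of how attribute assignment could route values into them. It is a toy model under stated assumptions, not PyTorch's implementation: ToyParameter, ToyModule, ToyLinear and ToyNet are hypothetical names, and values are plain floats instead of tensors.

```python
from collections import OrderedDict


class ToyParameter(object):
    """Hypothetical stand-in for a learnable tensor."""

    def __init__(self, value):
        self.value = value


class ToyModule(object):
    def __init__(self):
        # Bypass our own __setattr__ while creating the containers.
        object.__setattr__(self, "_parameters", OrderedDict())
        object.__setattr__(self, "_buffers", OrderedDict())
        object.__setattr__(self, "_modules", OrderedDict())
        object.__setattr__(self, "training", True)

    def register_buffer(self, name, value):
        self._buffers[name] = value

    def __setattr__(self, name, value):
        # Route parameters and child modules into their dictionaries.
        if isinstance(value, ToyParameter):
            self._parameters[name] = value
        elif isinstance(value, ToyModule):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Called only when normal lookup fails: search the containers.
        for store in ("_parameters", "_buffers", "_modules"):
            entries = self.__dict__.get(store, {})
            if name in entries:
                return entries[name]
        raise AttributeError(name)

    def named_parameters(self, prefix=""):
        # Own parameters first, then those of child modules, depth first.
        for name, param in self._parameters.items():
            yield prefix + name, param
        for child_name, child in self._modules.items():
            for item in child.named_parameters(prefix + child_name + "."):
                yield item


class ToyLinear(ToyModule):
    def __init__(self, weight, bias):
        super(ToyLinear, self).__init__()
        self.weight = ToyParameter(weight)   # lands in _parameters
        self.bias = ToyParameter(bias)       # lands in _parameters
        self.register_buffer("calls", 0)     # lands in _buffers

    def forward(self, x):
        return self.weight.value * x + self.bias.value


class ToyNet(ToyModule):
    def __init__(self):
        super(ToyNet, self).__init__()
        self.fc1 = ToyLinear(2.0, 1.0)       # lands in _modules
        self.fc2 = ToyLinear(3.0, 0.0)

    def forward(self, x):
        return self.fc2.forward(self.fc1.forward(x))


net = ToyNet()
param_names = [name for name, _ in net.named_parameters()]
output = net.forward(1.0)
print(param_names, output)
```

Because the dictionaries are ordered, walking them yields parameters and child modules in the order they were assigned, which keeps the dotted names stable from run to run.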

Module functions

What each function does can mostly be guessed from its name; they are only listed here and not described in detail.

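Name                       Purpose
forward                    forward computation; meant to be overridden by subclasses
register_buffer            registers a persistent buffer
register_parameter         registers a parameter
add_module                 adds a child module
_apply                     applies a function to all parameters and buffers, recursing into child modules
apply                      applies a function to every submodule and to the module itself
cuda                       moves the module to the GPU
cpu                        moves the module to the CPU
type                       casts all parameters and buffers to the given type
float                      casts floating-point parameters and buffers to single precision
double                     casts floating-point parameters and buffers to double precision
half                       casts floating-point parameters and buffers to half precision (two bytes)
to                         user-facing interface for changing type and moving between CPU and GPU; in fact it still calls _
register_backward_hook     registers a backward hook
register_forward_pre_hook  registers a hook that runs before forward is called
register_forward_hook      registers a forward hook
_slow_forward              forward path without acceleration
__call__                   runs the forward call when the module is invoked with arguments
__setstate__               restores the state of all dictionaries at once
__getattr__                gets an attribute
__setattr__                sets an attribute
__delattr__                deletes an attribute
state_dict                 returns the current state as a dictionary
_load_from_state_dict      internal function that does the loading from a state dictionary
load_state_dict            user-facing interface for loading state
children                   immediate child modules
modules                    all modules, recursively
train                      switches to training mode
eval                       switches to evaluation mode
zero_grad                  zeroes the gradients of all parameters
share_memory               moves the tensors to shared memory
__repr__                   string representation of the module
__dir__                    lists attribute names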
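How the hook registration functions, the call operator and train/eval from the table fit together can be sketched as follows. This is a simplified pure-Python model, not PyTorch's code: ToyModule and Doubler are hypothetical, hooks are keyed by a plain counter instead of a removable handle, and backward hooks are left out.

```python
from collections import OrderedDict


class ToyModule(object):
    def __init__(self):
        self._forward_pre_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self.training = True

    def register_forward_pre_hook(self, hook):
        # Assumption: a simple integer key; hook removal is not modelled.
        key = len(self._forward_pre_hooks)
        self._forward_pre_hooks[key] = hook
        return key

    def register_forward_hook(self, hook):
        key = len(self._forward_hooks)
        self._forward_hooks[key] = hook
        return key

    def forward(self, *inputs):
        # Subclasses are expected to override this.
        raise NotImplementedError

    def __call__(self, *inputs):
        # Pre-hooks run first, then forward, then the forward hooks.
        for hook in self._forward_pre_hooks.values():
            hook(self, inputs)
        result = self.forward(*inputs)
        for hook in self._forward_hooks.values():
            hook(self, inputs, result)
        return result

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        # Evaluation mode is training mode switched off.
        return self.train(False)


class Doubler(ToyModule):
    def forward(self, x):
        return 2 * x


events = []
m = Doubler()
m.register_forward_pre_hook(lambda mod, inp: events.append(("pre", inp)))
m.register_forward_hook(lambda mod, inp, out: events.append(("post", out)))

y = m(21)   # goes through __call__, not forward directly
m.eval()
print(events, y, m.training)
```

Calling the module object rather than forward directly is what gives the hooks a chance to run, which is why user code normally writes m(x) instead of m.forward(x).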