构建一个全连接神经网络

2022-12-18  本文已影响0人  黑熊小李
import torch.nn as nn
from torchinfo import summary

# Fully connected neural network class.
class Lr_Net(nn.Module):
    """Fully connected network: (Linear -> ReLU) per hidden layer, then a Linear output.

    Layers live in the ``nn.Sequential`` attribute ``self.Lr`` under the names
    ``In_lr``, ``Hdd_lr_<i>``, ``Act_<i>`` and ``Out_lr``. These names form the
    ``state_dict`` keys, so they are kept stable.

    Args:
        dim_in: Number of input features (size of the input's last dimension).
        dim_out: Number of output features.
        n_hidden: Widths of the hidden layers, in order. Every entry must be a
            positive ``int``. An empty list yields a single
            ``Linear(dim_in, dim_out)`` layer.
        init_std: Standard deviation of the zero-mean normal distribution used
            to initialise every Linear weight. Biases keep PyTorch's default
            initialisation.

    Raises:
        TypeError: If an entry of ``n_hidden`` is not an ``int`` (``bool`` is rejected).
        ValueError: If an entry of ``n_hidden`` is smaller than 1.
    """

    def __init__(self, dim_in: int, dim_out: int, n_hidden: list, init_std: float = 0.1):
        super().__init__()
        # Validate with explicit raises: `assert` is stripped under `python -O`.
        for width in n_hidden:
            # bool is an int subclass; a width of True/False is almost surely a bug.
            if isinstance(width, bool) or not isinstance(width, int):
                raise TypeError(
                    f"n_hidden entries must be int, got {type(width).__name__}: {width!r}"
                )
            if width < 1:
                raise ValueError(f"n_hidden entries must be >= 1, got {width}")

        self.Lr = nn.Sequential()
        # Input width of the next Linear layer; starts at the network input.
        prev = dim_in
        for i, width in enumerate(n_hidden):
            name = 'In_lr' if i == 0 else 'Hdd_lr_' + str(i)
            self.Lr.add_module(name, nn.Linear(prev, width))
            # inplace=True: ReLU overwrites the Linear output instead of allocating a new tensor.
            self.Lr.add_module('Act_' + str(i), nn.ReLU(True))
            prev = width

        # With no hidden layers this is a plain linear map dim_in -> dim_out.
        self.Lr.add_module('Out_lr', nn.Linear(prev, dim_out))

        # Re-initialise every Linear weight from N(0, init_std^2), in layer order.
        for layer in self.Lr:
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, mean=0.0, std=init_std)

    def forward(self, x):
        """Apply the network to ``x``; its last dimension must equal ``dim_in``."""
        return self.Lr(x)

# Example: build a 2 -> 8 network with five hidden layers and summarize it with torchinfo.
# Guarded so that importing this module does not build a model or produce output.
if __name__ == "__main__":
    model = Lr_Net(2, 8, n_hidden=[64, 32, 16, 16, 16])
    summary(model)
上一篇 下一篇

猜你喜欢

热点阅读