pytorch 中间层输出

2019-10-23  本文已影响0人  Janeshurmin

想获取网络的中间输出,但是尝试后,发现

所以最终决定还是直接使用list保存,代码中的修改部分如下:

    def forward(self, x):
        all_output = [] #新增
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        all_output.append(x) #新增
        x = F.relu(self.fc2(x))
        all_output.append(x) #新增
        x = self.fc3(x)
        all_output.append(x) #新增
        return x, all_output
# 训练fc
def FCN_train(lr, epochs, train_loader):
    model = FCN(28 * 28, 256, 512, 10)
    loss_func = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.5)

    step_list = []
    loss_list = []
    for epoch in range(epochs):
        for step, (x, y) in enumerate(train_loader):
            output = model(x)[0] #修改,原本是model(x)

#第一层输出 all_output[0]
#第二层输出 all_output[1]
......

参考资料

http://www.yanglajiao.com/article/LEILEI18A/80389229

上一篇 下一篇

猜你喜欢

热点阅读