pytorch

6. pytorch-保存与恢复

2018-07-01  本文已影响134人  FantDing

官方序列化教程

1. 只保存参数

推荐

1.1 示例

# file: main.py
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

class Net(torch.nn.Module):
    """Two-layer MLP for 1-D regression: 1 -> 10 (ReLU) -> 1."""

    def __init__(self):
        super().__init__()
        self.h1 = torch.nn.Linear(1, 10)   # hidden layer
        self.h2 = torch.nn.Linear(10, 1)   # output layer

    def forward(self, x):
        hidden = F.relu(self.h1(x))
        return self.h2(hidden)

def prepare_data():
    """Build a fixed noisy quadratic dataset of 50 points.

    Returns (x, y): x has shape (50, 1) spanning [-1, 1], and
    y = x**2 plus uniform noise in [0, 0.2).
    """
    torch.manual_seed(1)  # fixed seed: identical data on every call
    xs = torch.unsqueeze(torch.linspace(-1, 1, 50), 1)
    noise = 0.2 * torch.rand(xs.size())
    return xs, xs.pow(2) + noise

if __name__ == "__main__":
    # 1. Prepare the toy regression data (deterministic via the fixed seed).
    x, y = prepare_data()
    plt.scatter(x.numpy(), y.numpy())
    plt.show()
    # 2. Build the network.
    net = Net()
    # 3. Train with plain SGD on mean-squared error.
    optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
    loss_fn = torch.nn.MSELoss()
    for epoch in range(100):  # renamed: `iter` shadowed the builtin
        pred = net(x)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        print(loss.item())  # .item() extracts the scalar; no numpy array needed
        optimizer.step()
    # 4. Save ONLY the parameters (state_dict), not the whole module.
    torch.save(net.state_dict(), "./net_param.pkl")
from main import Net, prepare_data
import torch
import matplotlib.pyplot as plt

if __name__ == "__main__":
    # Rebuild the architecture first; a state_dict holds only parameters.
    net = Net()
    x, y = prepare_data()
    plt.scatter(x.numpy(), y.numpy())
    plt.show()
    # torch.load returns the saved state_dict (a dict of parameter tensors)
    net.load_state_dict(torch.load("net_param.pkl"))
    loss_F = torch.nn.MSELoss()
    pred = net(x)
    loss = loss_F(pred, y) # loss matches the final training iteration's loss
    print(loss.detach().numpy())

1.2 好处

只保存 state_dict 时,加载端可以自由修改类名甚至 forward 逻辑(例如下面把激活函数改成 tanh):只要各层参数的名字和形状不变,load_state_dict 依然能成功加载。

class New_Net(torch.nn.Module):
    """Renamed network with a different activation.

    Layer names and shapes match the original Net, so the saved
    state_dict still loads even though the class name changed.
    """

    def __init__(self):
        super().__init__()
        self.h1 = torch.nn.Linear(1, 10)
        self.h2 = torch.nn.Linear(10, 1)

    def forward(self, x):
        # torch.tanh replaces F.tanh, which is deprecated in modern PyTorch
        x = torch.tanh(self.h1(x))
        x = self.h2(x)
        return x

2. 保存网络结构和参数

2.1 示例

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

class Net(torch.nn.Module):
    """One-hidden-layer regression network: Linear(1,10) -> ReLU -> Linear(10,1)."""

    def __init__(self):
        super().__init__()
        self.h1 = torch.nn.Linear(1, 10)
        self.h2 = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.h2(F.relu(self.h1(x)))

def prepare_data():
    """Return 50 seeded (x, y) pairs with y = x^2 plus [0, 0.2) uniform noise."""
    torch.manual_seed(1)  # reseed so repeated calls yield identical data
    inputs = torch.linspace(-1, 1, 50).unsqueeze(1)
    targets = inputs ** 2 + 0.2 * torch.rand(inputs.size())
    return inputs, targets

if __name__ == "__main__":
    # 1. Prepare the toy regression data.
    x, y = prepare_data()
    plt.scatter(x.numpy(), y.numpy())
    plt.show()
    # 2. Build the network.
    net = Net()
    # 3. Train with SGD on mean-squared error.
    optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
    loss_fn = torch.nn.MSELoss()
    for epoch in range(100):  # renamed: `iter` shadowed the builtin
        pred = net(x)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        print(loss.item())  # scalar print; no numpy round-trip needed
        optimizer.step()
    # Save the ENTIRE module (architecture + parameters),
    # not just net.state_dict() as in the previous example.
    torch.save(net, "./net.pkl")
from main import Net, prepare_data
import torch
import matplotlib.pyplot as plt

if __name__ == "__main__":
    x, y = prepare_data()
    plt.scatter(x.numpy(), y.numpy())
    plt.show()
    # Here torch.load returns the whole pickled module, NOT a state_dict;
    # NOTE(review): unpickling a full model runs arbitrary code — only load
    # files you trust, and the original class definition must be importable.
    net = torch.load("net.pkl")
    loss_F = torch.nn.MSELoss()
    pred = net(x)
    loss = loss_F(pred, y)
    print(loss.detach().numpy())

2.2 弊端

直接保存整个网络依赖 pickle 序列化类本身:加载时必须能 import 到原始的类定义(模块路径、类名都不能变),文件也更大,且在重构代码或升级 PyTorch 版本后容易加载失败。因此官方更推荐只保存 state_dict。

上一篇下一篇

猜你喜欢

热点阅读