6. pytorch-保存与恢复
2018-07-01 本文已影响134人
FantDing
1. 只保存参数
推荐
1.1 示例
- 训练过程: main.py
# file: main.py
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
class Net(torch.nn.Module):
def __init__(self):
super().__init__()
self.h1 = torch.nn.Linear(1, 10)
self.h2 = torch.nn.Linear(10, 1)
def forward(self, x):
x = F.relu(self.h1(x))
x = self.h2(x)
return x
def prepare_data():
torch.manual_seed(1) # 保证每次生成的随机数相同
x = torch.linspace(-1, 1, 50)
x = torch.unsqueeze(x, 1)
y = x ** 2 + 0.2 * torch.rand(x.size())
return x, y
if __name__ == "__main__":
# 1. 数据准备
x,y=prepare_data()
plt.scatter(x.numpy(), y.numpy())
plt.show()
# 2. 网络搭建
net = Net()
# 3. 训练
optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
loss_F = torch.nn.MSELoss()
for iter in range(100):
pred = net(x)
loss = loss_F(pred, y)
optimizer.zero_grad()
loss.backward()
print(loss.detach().numpy())
optimizer.step()
# 只保存网络状态
torch.save(net.state_dict(), "./net_param.pkl")
- 恢复过程: test.py
from main import Net, prepare_data
import torch
import matplotlib.pyplot as plt
if __name__ == "__main__":
net = Net()
x, y = prepare_data()
plt.scatter(x.numpy(), y.numpy())
plt.show()
# load是加载成dict形式
net.load_state_dict(torch.load("net_param.pkl"))
loss_F = torch.nn.MSELoss()
pred = net(x)
loss = loss_F(pred, y) # loos值与训练最后一次迭代的loss值相同
print(loss.detach().numpy())
1.2 好处
- 可以定义新的类。在
test.py
中可以定义新的class,forward
可以有不同的方式。只要有相同名字的参数,都可以load
成功
class New_Net(torch.nn.Module): # class名字也修改了
def __init__(self):
super().__init__()
self.h1 = torch.nn.Linear(1, 10)
self.h2 = torch.nn.Linear(10, 1)
def forward(self, x):
x = F.tanh(self.h1(x)) # 修改了激活函数
x = self.h2(x)
return x
- 加载更快
2. 保存网络结构和参数
2.1 示例
- main.py
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
class Net(torch.nn.Module):
def __init__(self):
super().__init__()
self.h1 = torch.nn.Linear(1, 10)
self.h2 = torch.nn.Linear(10, 1)
def forward(self, x):
x = F.relu(self.h1(x))
x = self.h2(x)
return x
def prepare_data():
torch.manual_seed(1) # 保证每次生成的随机数相同
x = torch.linspace(-1, 1, 50)
x = torch.unsqueeze(x, 1)
y = x ** 2 + 0.2 * torch.rand(x.size())
return x, y
if __name__ == "__main__":
# 1. 数据准备
x,y=prepare_data()
plt.scatter(x.numpy(), y.numpy())
plt.show()
# 2. 网络搭建
net = Net()
# 3. 训练
optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
loss_F = torch.nn.MSELoss()
for iter in range(100):
pred = net(x)
loss = loss_F(pred, y)
optimizer.zero_grad()
loss.backward()
print(loss.detach().numpy())
optimizer.step()
# 只保存网络状态
torch.save(net, "./net.pkl") #直接保存net,而不是net.state_dict()
- test.py
from main import Net, prepare_data
import torch
import matplotlib.pyplot as plt
if __name__ == "__main__":
x, y = prepare_data()
plt.scatter(x.numpy(), y.numpy())
plt.show()
# load是加载成dict形式
net = torch.load("net.pkl")
loss_F = torch.nn.MSELoss()
pred = net(x)
loss = loss_F(pred, y)
print(loss.detach().numpy())
2.2 弊端
- 与特定class绑定了。即: 虽然
test.py
中的net
是通过load的来的, 但是还是需要import
训练时候的那个类Net
(否则会报错) - 不灵活。因为结构被定死了,不能定义新的层等等。