[PyTorch] Special Topic: Output - Saving and Loading Models

2020-02-03  DDuncan

I. Python Modules & Data

%matplotlib inline 
%config InlineBackend.figure_format = 'retina' 
import matplotlib.pyplot as plt 
import torch 
from torch import nn 
from torch import optim 
import torch.nn.functional as F 
from torchvision import datasets, transforms 

# Custom modules (local files, not part of PyTorch)
import helper 
import fc_model 

# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])
# Download and load the training data
trainset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data', download=True,
                                 train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Download and load the test data
testset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data', download=True,
                                train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)
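
To see what the loaders above hand to the model without downloading anything, here is a minimal sketch that uses random tensors as a stand-in for FashionMNIST. The names `fake_images`, `fake_labels` and `loader` are made up for this illustration; only the shapes and the normalization arithmetic mirror the real pipeline.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for FashionMNIST: 256 fake grayscale 28x28 images with pixel values
# in [0, 1] (the range ToTensor() produces) and random labels from 10 classes.
fake_images = torch.rand(256, 1, 28, 28)
fake_labels = torch.randint(0, 10, (256,))

# Normalize((0.5,), (0.5,)) computes (x - mean) / std per channel,
# which maps the range [0, 1] onto [-1, 1].
normalized = (fake_images - 0.5) / 0.5

loader = DataLoader(TensorDataset(normalized, fake_labels),
                    batch_size=64, shuffle=True)
images, labels = next(iter(loader))

print(images.shape)  # torch.Size([64, 1, 28, 28])
print(labels.shape)  # torch.Size([64])

# A fully connected network expects flat vectors: 1 * 28 * 28 = 784 features.
flat = images.view(images.shape[0], -1)
print(flat.shape)    # torch.Size([64, 784])
```

The 784 in the last line is the same number passed as the input size when the model is built in the next section.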

II. Building the Model & Training

# Build the model with the custom module fc_model
model = fc_model.Network(784, 10, [512, 256, 128])
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
fc_model.train(model, trainloader, testloader, criterion, optimizer, epochs=2)
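
The source of `fc_model` is not shown in this post, so the class below is only a guess at what `fc_model.Network` might look like. It is a hypothetical stand-in, chosen so that it exposes a `hidden_layers` attribute (used later when saving) and returns log-probabilities (which is what `nn.NLLLoss` expects). The single training step runs on random data and only illustrates the mechanics.

```python
import torch
from torch import nn
import torch.nn.functional as F


class Network(nn.Module):
    """Hypothetical stand-in for fc_model.Network (the real module is not shown)."""

    def __init__(self, input_size, output_size, hidden_layers, drop_p=0.5):
        super().__init__()
        sizes = [input_size] + list(hidden_layers)
        # One nn.Linear per consecutive pair of sizes, e.g. 784->512, 512->256, ...
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
        )
        self.output = nn.Linear(sizes[-1], output_size)
        self.dropout = nn.Dropout(p=drop_p)

    def forward(self, x):
        for layer in self.hidden_layers:
            x = self.dropout(F.relu(layer(x)))
        # log_softmax pairs with nn.NLLLoss, the criterion used in this post.
        return F.log_softmax(self.output(x), dim=1)


model = Network(784, 10, [512, 256, 128])
criterion = nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# One training step on a random batch (illustration only, not real data).
images = torch.randn(64, 784)
labels = torch.randint(0, 10, (64,))

optimizer.zero_grad()
log_ps = model(images)
loss = criterion(log_ps, labels)
loss.backward()
optimizer.step()

print(log_ps.shape)  # torch.Size([64, 10])
```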

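III. Saving/Loading the Model

1. Saving the model (parameters)

The checkpoint dictionary stores everything needed to rebuild the model, including the layer dimensions:

  1. Network architecture
  • input size
  • output size
  • hidden layer sizes
  2. Parameters (weights, biases) from .state_dict()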
checkpoint = {'input_size': 784,
              'output_size': 10,
              # read each hidden layer's width from the model itself
              'hidden_layers': [each.out_features for each in model.hidden_layers],
              'state_dict': model.state_dict()}

torch.save(checkpoint, 'checkpoint.pth')

# Inspect the hidden layers (notebook cell output)
model.hidden_layers

Note: each linear layer exposes the attributes in_features and out_features, which is where the hidden layer sizes come from.
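
A small sketch of those two attributes, using a standalone `nn.ModuleList` rather than the author's model (the variable names here are illustrative). It also shows what `state_dict()` contains, and saves to an in-memory buffer instead of a file so the example leaves nothing on disk.

```python
import io

import torch
from torch import nn

# Three linear layers with the same sizes used in this post.
hidden_layers = nn.ModuleList([
    nn.Linear(784, 512),
    nn.Linear(512, 256),
    nn.Linear(256, 128),
])

for layer in hidden_layers:
    print(layer.in_features, layer.out_features)

# The list stored under 'hidden_layers' in the checkpoint.
hidden_sizes = [layer.out_features for layer in hidden_layers]
print(hidden_sizes)  # [512, 256, 128]

# state_dict() maps parameter names to tensors.
state = hidden_layers.state_dict()
print(list(state.keys())[:2])   # ['0.weight', '0.bias']
print(state['0.weight'].shape)  # torch.Size([512, 784])

# torch.save accepts a file path or any writable file-like object.
buffer = io.BytesIO()
torch.save({'hidden_layers': hidden_sizes, 'state_dict': state}, buffer)
```

Note that an `nn.Linear` weight has shape `(out_features, in_features)`, so the first layer's weight is 512 x 784, not 784 x 512.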

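2. Loading the model (parameters)

The model that receives the parameters must have exactly the same architecture as the saved model; otherwise loading fails with an error.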
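
The failure mode can be reproduced in a few lines. This is a sketch with a hypothetical helper `make_net`, not the author's `fc_model`: a state dict saved from a network with 128 hidden units is loaded into a network with 64.

```python
import torch
from torch import nn


def make_net(hidden):
    """Hypothetical two-layer network; only the hidden width varies."""
    return nn.Sequential(nn.Linear(784, hidden), nn.ReLU(), nn.Linear(hidden, 10))


saved_state = make_net(128).state_dict()

# Same architecture: loads without complaint.
make_net(128).load_state_dict(saved_state)

# Different hidden size: the tensor shapes no longer match.
try:
    make_net(64).load_state_dict(saved_state)
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = err

print(type(mismatch_error).__name__)  # RuntimeError
```

This is why the checkpoint stores the layer sizes alongside the weights: the architecture can then be rebuilt from the file itself rather than from memory.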

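def load_checkpoint(filepath):
    # Read the dictionary written by torch.save
    checkpoint = torch.load(filepath)
    # Rebuild the network with the saved sizes
    # ('hidden_layers' was filled from each layer's .out_features when saving)
    model = fc_model.Network(checkpoint['input_size'],
                             checkpoint['output_size'],
                             checkpoint['hidden_layers'])
    # Copy the saved weights and biases into the new network
    model.load_state_dict(checkpoint['state_dict'])
    return model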
 
# Load the model
model = load_checkpoint('checkpoint.pth')
print(model)
(Figure: parameters of the loaded model)
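
Printing the model confirms the architecture, but not that the weights survived the round trip. One way to check, sketched here with a hypothetical `build` helper in place of `fc_model.Network` and an in-memory buffer in place of `checkpoint.pth`, is to compare the outputs of the original and restored networks on the same input.

```python
import io

import torch
from torch import nn


def build(input_size, output_size, hidden_layers):
    """Hypothetical stand-in for fc_model.Network."""
    sizes = [input_size] + list(hidden_layers)
    layers = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        layers += [nn.Linear(n_in, n_out), nn.ReLU()]
    layers.append(nn.Linear(sizes[-1], output_size))
    return nn.Sequential(*layers)


original = build(784, 10, [512, 256, 128])
checkpoint = {'input_size': 784,
              'output_size': 10,
              'hidden_layers': [512, 256, 128],
              'state_dict': original.state_dict()}

# Save, rewind the buffer, and load again.
buffer = io.BytesIO()
torch.save(checkpoint, buffer)
buffer.seek(0)
loaded = torch.load(buffer)

restored = build(loaded['input_size'],
                 loaded['output_size'],
                 loaded['hidden_layers'])
restored.load_state_dict(loaded['state_dict'])

# eval() switches off training-only behavior such as dropout before inference.
original.eval()
restored.eval()

x = torch.randn(4, 784)
with torch.no_grad():
    same_output = torch.allclose(original(x), restored(x))

print(same_output)  # True
```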