pytorch学习(十七)—模型的保存与加载
2019-01-06 本文已影响0人
侠之大者_7d3f
前言
在深度学习中,模型的保存和加载很重要,当我们辛辛苦苦训练好的一个网络模型,自然需要将训练好的模型保存为文件。在测试使用时候,又需要将保存在磁盘的模型文件加载调用。
在pytorch中网络模型定义为torch.nn.Module
的子类的对象。因此模型的保存与加载涉及到2个重要概念——对象的序列化和反序列化。
目的
- 理解并掌握对象的序列化,反序列化
- 实现pytorch模型的保存与加载
开发/测试环境
- Ubuntu 18.04
- pycharm
- Anaconda3
- pytorch
- IntellJIDEA ,JDK10
对象的序列化与反序列化
序列化和反序列化听起来感觉高大上,其实是很常见的操作,下面举一个JAVA对象序列化和反序列化的例子,帮助理解。
序列化: 把对象转换为字节序列的过程称为对象的序列化。
序列化的目的:
在很多应用中,需要对某些对象进行序列化,让它们离开内存空间,入住物理硬盘,以便长期保存。比如最常见的是Web服务器中的Session对象,当有 10万用户并发访问,就有可能出现10万个Session对象,内存可能吃不消,于是Web容器就会把一些seesion先序列化到硬盘中,等要用了,再把保存在硬盘中的对象还原到内存中。
反序列化: 把字节序列恢复为对象的过程称为对象的反序列化。
当两个进程在进行远程通信时,彼此可以发送各种类型的数据。无论是何种类型的数据,都会以二进制序列的形式在网络上传送。发送方需要把这个Java对象转换为字节序列,才能在网络上传送;接收方则需要把字节序列再恢复为Java对象。 当两个进程在进行远程通信时,彼此可以发送各种类型的数据。无论是何种类型的数据,都会以二进制序列的形式在网络上传送。发送方需要把这个Java对象转换为字节序列,才能在网络上传送;接收方则需要把字节序列再恢复为Java对象。
package com.sty;
import java.io.Serializable;
/*
Java对象的序列化
实现Serializable接口
*/
public class Person implements Serializable {
private static final long serialVersionUID = -5809782578272943999L;
private int age;
private String name;
private String sex;
public int getAge() {
return age;
}
public String getName() {
return name;
}
public String getSex() {
return sex;
}
public void setAge(int age) {
this.age = age;
}
public void setSex(String sex) {
this.sex = sex;
}
public void setName(String name) {
this.name = name;
}
}
package com.sty;
import java.io.*;
//http://www.cnblogs.com/xdp-gacl/p/3777987.html
public class Main {
public static void main(String[] args) throws IOException, ClassNotFoundException {
serializePerson();
Person person = deserializePerson();
System.out.println(person);
}
/*
对象的序列化
*/
private static void serializePerson() throws IOException {
Person person = new Person();
person.setAge(25);
person.setName("LiMing");
person.setSex("male");
/*
ObjectOutputStream 对象输出流
*/
ObjectOutputStream objectOutputStream = new ObjectOutputStream(new FileOutputStream(new File("/home/weipenghui/Person.txt")));
objectOutputStream.writeObject(person);
System.out.println("对象序列化成功");
objectOutputStream.close();
}
/*
对象的反序列化
*/
private static Person deserializePerson() throws IOException, ClassNotFoundException {
ObjectInputStream objectInputStream = new ObjectInputStream(new FileInputStream("/home/weipenghui/Person.txt"));
Person person = (Person) objectInputStream.readObject();
System.out.println("Person对象序列化成功");
return person;
}
}