Pytorch Tips

2019-10-18  本文已影响0人  SnorlaxSE
# 保存和加载整个模型
torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl')
# 仅保存和加载模型参数(推荐使用)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))
try:
    train_net(net=net, epochs=args.epochs, batch_size=args.batchsize,
              lr=args.lr, gpu=args.gpu, img_scale=args.scale)
except KeyboardInterrupt:  # 用户中断执行(通常是输入^C)
    import time
    save_time = time.strftime("%Y-%m-%d-%H-%M", time.localtime())
    torch.save(net.state_dict(), '{}_INTERRUPTED.pth'.format(save_time))
    print('Saved interrupt')
    try:
        sys.exit(0)
    except SystemExit:
        os._exit(0)

将该代码添加至save_model合适的位置,可实现“Early Stopping”

上一篇 下一篇

猜你喜欢

热点阅读