PyTorch Workflow

2019-08-08

Abstract

My personal understanding of the working paradigm for training an artificial neural network (ANN) with PyTorch.


Paradigm

1. Data (torch.utils.data.DataLoader)
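A minimal sketch of this step, using toy random tensors in place of a real dataset (the shapes, batch size, and TensorDataset below are placeholder assumptions; any torch.utils.data.Dataset works here):
(
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # toy data standing in for a real dataset: 100 RGB images of size 32x32, 10 classes
    inputs_all = torch.randn(100, 3, 32, 32)
    labels_all = torch.randint(0, 10, (100,))
    train_set = TensorDataset(inputs_all, labels_all)
    train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2)
)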
--> 
2. Model (torch.nn)
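A minimal sketch of the model step, assuming a small fully connected network (the layer sizes are arbitrary; any torch.nn.Module subclass fits here):
(
    import torch
    import torch.nn as nn

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(3 * 32 * 32, 120)   # flattened 3x32x32 input
            self.fc2 = nn.Linear(120, 10)            # 10 output classes

        def forward(self, x):
            x = x.view(x.size(0), -1)                # flatten each sample
            x = torch.relu(self.fc1(x))
            return self.fc2(x)

    net = Net()
)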
 --> 
3. Strategy (loss function, criterion = torch.nn.CrossEntropyLoss, torch.nn.MSELoss, ...) + Algorithm (optimization algorithm, optimizer = torch.optim.SGD | Adam | Adadelta | ...)
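A minimal sketch, assuming a classification task with cross-entropy loss and SGD (net is the model from step 2; any loss from torch.nn and any optimizer from torch.optim can be substituted):
(
    import torch.nn as nn
    import torch.optim as optim

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
)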
 --> 
4. Iterative training
(
    # assumes train_loader from step 1 and a chosen number of epochs, e.g. num_epochs = 2
    for epoch in range(num_epochs):
        for inputs_train, labels_train in train_loader:
            # 1. clear the gradients accumulated in the previous iteration
            optimizer.zero_grad()
            # 2. forward pass
            outputs_train = net(inputs_train)
            # 3. compute the loss
            loss_train = criterion(outputs_train, labels_train)
            # 4. backward pass: compute gradients of the loss w.r.t. the parameters
            loss_train.backward()
            # 5. update the parameters
            optimizer.step()
)
--> 
5. Hyperparameter tuning / testing (tune hyperparameters on the validation set; report the final score on the test set)
(
    net.eval()                    # switch to evaluation mode (affects dropout/batch-norm layers)
    with torch.no_grad():         # no gradient tracking needed during evaluation
        outputs_test = net(inputs_test)
        loss_test = criterion(outputs_test, labels_test)
        ... other evaluation metrics ...
)
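For classification, accuracy is a common extra test metric; a minimal sketch, assuming outputs_test holds per-class scores and labels_test holds integer class labels:
(
    predictions = outputs_test.argmax(dim=1)                        # predicted class per sample
    accuracy = (predictions == labels_test).float().mean().item()
    print(f"test loss: {loss_test.item():.4f}, test accuracy: {accuracy:.4f}")
)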
--> 
6. Acceleration (optional)
(
    # Train on GPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net.to(device) 
    inputs, labels = inputs.to(device), labels.to(device)
    # Data parallelism: split each batch across multiple GPUs (if available)
    net = nn.DataParallel(net)
)
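Note that moving inputs and labels to the device happens per batch; a minimal sketch of how this plugs into the training loop from step 4 (num_epochs and train_loader are the assumed names from the earlier sketches):
(
    for epoch in range(num_epochs):
        for inputs_train, labels_train in train_loader:
            inputs_train = inputs_train.to(device)    # move each batch to the same device as net
            labels_train = labels_train.to(device)
            optimizer.zero_grad()
            loss_train = criterion(net(inputs_train), labels_train)
            loss_train.backward()
            optimizer.step()
)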

to be continued...


References

PyTorch tutorial: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
Li Hang, 《统计机器学习》 (Statistical Machine Learning)
