PyTorch Workflow
2019-08-08
不到15不改名
Abstract
A personal summary of the usual workflow for training an artificial neural network (ANN) with PyTorch.
Paradigm
I. Data (torch.utils.data.DataLoader)
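A minimal sketch of the data step, assuming a toy in-memory dataset. The tensor shapes, the batch size, and the use of TensorDataset are illustrative assumptions, not part of the original notes.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 100 samples, 4 features each, 3 classes.
torch.manual_seed(0)
features = torch.randn(100, 4)
labels = torch.randint(0, 3, (100,))

# Wrap the tensors in a Dataset, then let DataLoader batch and shuffle them.
dataset = TensorDataset(features, labels)
loader = DataLoader(dataset, batch_size=16, shuffle=True)

# Each iteration over the loader yields one mini-batch (inputs, labels).
inputs_train, labels_train = next(iter(loader))
print(inputs_train.shape, labels_train.shape)
```

The last batch is smaller than the others here because 100 is not a multiple of 16; passing drop_last=True would discard it.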
-->
II. Model (torch.nn)
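A minimal sketch of the model step, assuming a small fully connected classifier. The class name Net and the layer sizes are hypothetical and chosen only for illustration.

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    """Hypothetical two-layer classifier: 4 features in, 3 class scores out."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 3)

    def forward(self, x):
        # Hidden layer with ReLU, then raw class scores (logits).
        return self.fc2(torch.relu(self.fc1(x)))

net = Net()
outputs = net(torch.randn(16, 4))  # a batch of 16 samples
print(outputs.shape)
```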
-->
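III. Strategy (loss function, criterion = torch.nn.BlaBlaLoss(), where BlaBlaLoss stands for any loss class such as CrossEntropyLoss or MSELoss) + Algorithm (optimization algorithm, optimizer = torch.optim.SGD|Adam|Adadelta...)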
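A minimal sketch of choosing a loss and an optimizer. CrossEntropyLoss, SGD with momentum, and the learning rate are example choices made here, and the linear layer is only a stand-in model.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
net = nn.Linear(4, 3)  # stand-in model (assumption)

# Strategy: the loss function that scores the predictions.
criterion = nn.CrossEntropyLoss()
# Algorithm: the optimizer that updates the model parameters.
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

# The criterion maps (outputs, labels) to a scalar loss tensor.
outputs = net(torch.randn(16, 4))
loss = criterion(outputs, torch.randint(0, 3, (16,)))
print(loss.item())
```

Swapping in optim.Adam(net.parameters(), lr=1e-3) or another optimizer leaves the rest of the workflow unchanged.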
-->
IV. Iterative training
(
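FOR each epoch, FOR each mini-batch (inputs_train, labels_train) from the DataLoader:
1. optimizer.zero_grad()  # clear the gradients left over from the previous step
2. outputs_train = net(inputs_train)  # forward pass
3. loss_train = criterion(outputs_train, labels_train)  # compute the loss
4. loss_train.backward()  # backpropagate to compute the gradients
5. optimizer.step()  # update the parameters
END FOR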
)
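The five steps above can be sketched as a runnable loop. The toy regression data, the linear model, the learning rate, and the epoch count are all assumptions made for illustration.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
# Hypothetical toy task: the target is the sum of the 4 input features.
x = torch.randn(64, 4)
y = x.sum(dim=1, keepdim=True)
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

net = nn.Linear(4, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.05)

epoch_losses = []
for epoch in range(20):
    running = 0.0
    for inputs_train, labels_train in loader:
        optimizer.zero_grad()                               # 1. reset gradients
        outputs_train = net(inputs_train)                   # 2. forward pass
        loss_train = criterion(outputs_train, labels_train) # 3. loss
        loss_train.backward()                               # 4. backward pass
        optimizer.step()                                    # 5. parameter update
        running += loss_train.item()
    epoch_losses.append(running / len(loader))  # mean loss over the epoch

print(epoch_losses[0], epoch_losses[-1])
```

Recording the mean loss per epoch is one simple way to check that training is making progress.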
-->
V. Tuning/testing (tune hyperparameters on the validation set; use the test set for the final score)
(
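net.eval()  # put dropout/batch-norm layers in evaluation mode
with torch.no_grad():  # gradients are not needed for evaluation
    outputs_test = net(inputs_test)
    loss_test = criterion(outputs_test, labels_test)
    ... other test metrics ...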
)
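A minimal sketch of the evaluation step, assuming a classification task. The untrained linear layer stands in for a trained model, and accuracy is just one example of an additional test metric.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Linear(4, 3)  # stand-in for a trained model (assumption)
criterion = nn.CrossEntropyLoss()
inputs_test = torch.randn(32, 4)            # hypothetical test inputs
labels_test = torch.randint(0, 3, (32,))    # hypothetical test labels

net.eval()  # evaluation mode for layers such as dropout and batch norm
with torch.no_grad():  # skip gradient tracking during evaluation
    outputs_test = net(inputs_test)
    loss_test = criterion(outputs_test, labels_test)
    # Example extra metric: fraction of correctly predicted classes.
    predicted = outputs_test.argmax(dim=1)
    accuracy = (predicted == labels_test).float().mean().item()

print(loss_test.item(), accuracy)
```

Call net.train() again before resuming training, since net.eval() changes the behavior of some layers.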
-->
VI. Acceleration (optional)
(
# Train on GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)
inputs, labels = inputs.to(device), labels.to(device)
# Data parallelism across multiple GPUs (nn is torch.nn)
net = nn.DataParallel(net)
)
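A minimal sketch of the acceleration step that also runs on a CPU-only machine. The stand-in model, the tensor shapes, and the device-count check before wrapping in DataParallel are assumptions added here.

```python
import torch
import torch.nn as nn

# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

net = nn.Linear(4, 3)  # stand-in model (assumption)
# Wrap in DataParallel only when more than one GPU is present.
if torch.cuda.device_count() > 1:
    net = nn.DataParallel(net)
net.to(device)

# The inputs and labels must live on the same device as the model.
inputs = torch.randn(16, 4).to(device)
labels = torch.randint(0, 3, (16,)).to(device)
outputs = net(inputs)
print(outputs.device)
```

Note that Tensor.to returns a new tensor and must be reassigned, whereas Module.to moves the parameters of the module in place.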
to be continued...
References
PyTorch tutorial: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
Li Hang: 《统计机器学习》 (Statistical Machine Learning)