PyTorch Series | Getting Started in 3 Steps

2018-12-27  reallocing
1. Define the Network

Subclass nn.Module and implement the __init__ and forward methods.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.fc1 = nn.Linear(240, 120)  # input size 240, output size 120
    self.fc2 = nn.Linear(120, 60)
    self.fc3 = nn.Linear(60, 10)

  def forward(self, x):
    ## define how the layers are connected
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

net = Net()
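
A quick sanity check (a minimal sketch; the random input and batch size of 4 are just illustrative): feed a batch whose last dimension matches fc1's input size of 240 and confirm the output shape.

## dummy batch: 4 samples, 240 features each
x = torch.randn(4, 240)
out = net(x)        # calling net(x) runs forward() under the hood
print(out.shape)    # torch.Size([4, 10])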

As a rule of thumb, layers that have no weights to update can be applied directly inside the forward method, while layers that do have weights to update should be defined inside __init__.
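
For example (a minimal sketch; the layer sizes and 28x28 input are assumptions), a convolutional or linear layer owns weights and is created in __init__, while ReLU and max pooling have no weights and can be called as functions from torch.nn.functional inside forward:

class ConvNet(nn.Module):
  def __init__(self):
    super(ConvNet, self).__init__()
    self.conv1 = nn.Conv2d(1, 16, kernel_size=3)  # has weights -> defined here
    self.fc1 = nn.Linear(16 * 13 * 13, 10)        # has weights -> defined here

  def forward(self, x):
    # relu and max_pool2d have no weights, so they can live in forward
    x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # (1,28,28) -> (16,26,26) -> (16,13,13)
    x = x.view(x.size(0), -1)                   # flatten to (batch, 16*13*13)
    return self.fc1(x)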

2. Load the Data

PyTorch uses Dataset and DataLoader to feed in data. We can subclass Dataset to plug in our own dataset.

import torch
import pandas as pd
from torch.utils.data import Dataset,DataLoader

class ExampleDataset(Dataset):
  def __init__(self, csv_file):
    self.data_frame = pd.read_csv(csv_file)

  def __len__(self):
    return len(self.data_frame)

  def __getitem__(self, idx):
    return self.data_frame.iloc[idx]  # return the idx-th row

example_dataset = ExampleDataset('my_datasets.csv')

example_data_loader = DataLoader(example_dataset, batch_size=10, shuffle=True, num_workers=2) # num_workers: used to load the data in parallel


## Loop over data
## enumerate() returns the index and the batch.
for batch_index, batch in enumerate(example_data_loader):
  print(batch_index,batch)
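
In the training loop below, each batch is unpacked as data, targets, so __getitem__ must return a (features, target) pair. Here is a minimal sketch of such a dataset; the class name and the assumption that the last CSV column is the target are hypothetical:

import torch
import pandas as pd
from torch.utils.data import Dataset

class CSVRegressionDataset(Dataset):  # hypothetical name
  def __init__(self, csv_file):
    self.data_frame = pd.read_csv(csv_file)

  def __len__(self):
    return len(self.data_frame)

  def __getitem__(self, idx):
    row = self.data_frame.iloc[idx]
    features = torch.tensor(row.iloc[:-1].values, dtype=torch.float32)  # all columns except the last
    target = torch.tensor([row.iloc[-1]], dtype=torch.float32)          # last column as the target
    return features, target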

3. Train the Network

import torch.optim as optim
import torch.nn as nn

## instantiate network
net = Net()
## optimizer
optimizer = optim.SGD(net.parameters(),lr=1e-3)

## define the loss function
criterion = nn.MSELoss() ## or nn.CrossEntropyLoss() for classification

for epoch in range(10): ## train for 10 epochs
  for i, batch in enumerate(example_data_loader):
    # get the inputs and targets
    data, targets = batch

    # zero the gradient buffers
    optimizer.zero_grad()
    ## pass the data through the network
    output = net(data)
    ## calculate the loss
    loss = criterion(output, targets)
    ## propagate the loss back
    loss.backward()
    ## update all weights of the network
    optimizer.step()
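
To see whether training is actually converging, it is common to accumulate the loss and print it once per epoch. A minimal sketch of the same loop with loss tracking added (the print format is just a suggestion):

for epoch in range(10):
  running_loss = 0.0
  for i, batch in enumerate(example_data_loader):
    data, targets = batch
    optimizer.zero_grad()
    output = net(data)
    loss = criterion(output, targets)
    loss.backward()
    optimizer.step()
    running_loss += loss.item()  # .item() extracts the scalar loss value
  print('epoch %d, average loss: %.4f' % (epoch, running_loss / len(example_data_loader)))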

