A simple way to build a batch generator

2020-02-12  英文名字叫dawntown

The key is the use of `yield`; 廖雪峰's article on generators explains it very clearly and thoroughly. Below are two simple ways to generate mini-batches from a training set:
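As a quick reminder of what `yield` does, here is a minimal sketch (the function name and values below are made up for illustration): calling a function that contains `yield` returns a generator, and each iteration resumes the function from where it last paused.

def count_up_to(n):
    i = 0
    while i < n:
        yield i      # pause here and hand `i` to the caller
        i += 1       # resume from this line on the next iteration

for value in count_up_to(3):
    print(value)     # prints 0, 1, 2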

Method 1:

import random
import torch

train_data = torch.tensor(...)

def data_iter(batch_size, train_data, train_labels):
    num_examples = len(train_data)
    indices = list(range(num_examples))
    random.shuffle(indices)  # shuffle so samples are read in random order
    for i in range(0, num_examples, batch_size):
        # the last batch may contain fewer than batch_size samples
        j = torch.LongTensor(indices[i:min(i + batch_size, num_examples)])
        yield train_data.index_select(0, j), train_labels.index_select(0, j)

Method 2:

import torch.utils.data as Data

# combine features and labels into one dataset
dataset = Data.TensorDataset(features, labels)

# wrap the dataset in a DataLoader
data_iter = Data.DataLoader(
    dataset=dataset,            # torch TensorDataset format
    batch_size=batch_size,      # mini-batch size
    shuffle=True,               # whether to shuffle the data
    num_workers=2,              # number of worker processes for loading data
)

Usage of the two methods:

# Method 1
for X, y in data_iter(batch_size, train_data, train_labels):
    pass
# Method 2
for X, y in data_iter:
    pass
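For reference, below is a self-contained sketch of method 2 on synthetic data; the tensor shapes, variable names, and num_workers=0 are assumptions added here for illustration, not part of the original post.

import torch
import torch.utils.data as Data

num_examples, num_features, batch_size = 100, 4, 10
features = torch.randn(num_examples, num_features)  # fake inputs
labels = torch.randn(num_examples)                  # fake targets

dataset = Data.TensorDataset(features, labels)
data_iter = Data.DataLoader(dataset, batch_size=batch_size,
                            shuffle=True, num_workers=0)

for X, y in data_iter:
    print(X.shape, y.shape)  # torch.Size([10, 4]) torch.Size([10])
    break

num_workers=0 loads data in the main process, which avoids the multiprocessing start-up issues that a short script without a __main__ guard can hit on some platforms.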