@IT·互联网深度学习

深度学习第三篇---数据加载

2023-11-17  本文已影响0人  LooperJing

Pytorch的数据加载主要依赖torch.utils.data.Dataset和torch.utils.data.DataLoader两个模块,可以完成如下格式的傻瓜式加载。

train_dataset = CustomDataset(train_data_path) 
train_loader = torch.utils.data.DataLoader(train_dataset)

1 Dataset

阅读源码后,我们可以指导,继承该方法实现3个方法:
init():主要是数据格式的转换,还有一部分处理
getItem():主要是从数据集里面获取数据项的item和label。
lens():返回数据的个数

2 DataLoader

提供对Dataset的操作,操作如下:‘

torch.utils.data.DataLoader(dataset,batch_size,shuffle,drop_last,num_workers)

● dataset: 加载torch.utils.data.Dataset对象数据
● batch_size: 每个batch的大小
● shuffle:是否对数据进行打乱
● drop_last:是否对无法整除的最后一个datasize进行丢弃
● num_workers:表示加载的时候子进程数

3 案例

import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split


# 自定义数据集类
class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]  # 根据索引获取样本
        label = self.labels[idx]  # 根据索引获取标签
        return sample, label


# 创建数据集实例
data = [1, 2, 3, 4, 5]
labels = [0, 1, 0, 1, 0]
X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=42)
train_dataset = CustomDataset(X_train, y_train)
test_dataset = CustomDataset(X_test, y_test)

# 创建数据加载器
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False)

# 迭代加载训练集数据
print("训练集")
for batch in train_dataloader:
    samples, labels = batch
    print(samples, labels)

# 迭代加载测试集数据
print("测试集")
for batch in test_dataloader:
    samples, labels = batch
    print(samples, labels)

输出:

训练集
tensor([4, 1]) tensor([1, 0])
tensor([3, 5]) tensor([0, 0])
测试集
tensor([2]) tensor([1])
上一篇 下一篇

猜你喜欢

热点阅读