Pytorch学习笔记(9) 通过DataSet、Dataset

2020-06-17  本文已影响0人  银色尘埃010

如何将我们准备好的数据放入模型中呢? Pytorch 给出的答案都在torch.utils.data 包中。

一、先看看所有的类

这个模块中方法并不多,所以让我们先全部列出来看看,看看名字猜猜功能。

二、Dataset和DatasetLoader

一般情况下,使用Dataset和DatasetLoader两个类已经可以完成大部分的数据导入。首先来看Dataset类。
在此对象中,必须重写以下两个方法。

def __getitem__(self, index)
      return  index对应的一条数据,可以是一张图,可以是一句话,总之 记住,一条数据。
     
def  __len__():
    return  带训练数据的总长度, 如果是dataframe, 直接len(df)即可,或者在init的时候传入了长度,直接返回

接下来看DataLoader 类

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)

关键的几个参数:

看看实例:
想到sklearn中提供了一些小数据集,使用鸢尾花(iris)的数据集:

def loaddata():
    iris_data = datasets.load_iris()
    return iris_data["data"], iris_data["target"]

class IrisDataset(Dataset):
    def __init__(self,irisdata,target):
        #   传入参数
        #   ndarray 类型的,可以是任何类型的
        self.irisdata = irisdata
        self.target = target
        self.lens = len(irisdata)

    def __getitem__(self, index):
        # index是方法自带的参数,获取相应的第index条数据
        return self.irisdata[index,:],self.target[index]

    def __len__(self):
        return self.lens

数据集就构架完成了,大家也可以通过DataFrame来处理数据。
然后结合DataLoader使用:

data,target = loaddata()
dataset_iris = IrisDataset(data,target)
train_loader = torch.utils.data.DataLoader(dataset_iris, batch_size=10,   shuffle=True, num_workers=4)

for i, (input, target) in enumerate(tqdm.tqdm(train_loader)):
        print(input.size())
        # 在这之后就可以进行训练了
输出

三、random_split 介绍

pytorch 中 random_split可以将实现sklearn 的 train_test_split类似的功能,大家可能注意到了,在上面的例子中只有训练数据,一般还需要有test set和valid set。
那么我们用random_split来划分数据集吧:

    data,target = loaddata()
    dataset_iris = IrisDataset(data,target)

    all_length = len(dataset_iris)
    train_size = int(0.80 * all_length)
    test_size = all_length - train_size

    train_dataset,test_dataset = torch.utils.data.random_split(dataset_iris,[train_size,test_size])

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=4)

    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=10, shuffle=False, num_workers=4)

到这里就已经分好了,不过还是建议先通过其他方法提前分好。为了使每次结果都相同,可以设置好seed。

上一篇 下一篇

猜你喜欢

热点阅读