生动形象的DataLoader
整理一下 PyTorch 的 DataLoader 。
先来看看官方文档:
DataLoader官方文档PyTorch 出这两个类的目的是想将数据集代码和模型训练代码分离,以获得更好的可读性和模块化。
数据集训练分离举个例子:
有一个数据集(如下表),数据集中有四句话 和 每句话所对应的标签,现在要把此数据集输入模型中,做分类任务。
sentence |
---|
再次证明了“无敌是多么寂寞”——逆天的中国乒乓球队! label: news_sports |
学计算机的已经烂大街了吗 label: news_edu |
你见过最努力的人现在都混成什么样子了? label: news_edu |
为什么茶会成为世界级的流行饮料? label: new_culture |
把每个句子分词、去除停用词(百度停用词),并放到一个列表中。
sentence |
---|
['再次', '证明', '无敌', '多么', '寂寞', '—', '—', '逆天', '中国乒乓球队'] label: news_sports |
['学', '计算机', '烂', '大街'] label: news_edu |
['见', '最', '努力', '人', '都', '混成', '样子', '?'] label: news_edu |
['茶会', '世界级', '流行', '饮料', '?'] label: new_culture |
再把句子中的每个词换成id, 标签也换成相映id
sentence |
---|
[830, 2311, 2218, 1975, 2701, 1, 1, 7509, 24627] label: 0 |
[636, 6924, 2991, 3017] label: 1 |
[511, 74, 738, 25, 26, 20052, 2140, 18] label: 1 |
[1, 8831, 1952, 3308, 18] label: 2 |
接下来,看看 torch.utils.data.Dataset 。
Dataset文档所有集成 Dataset 类的子类都要复写 __getitem__() 方法和 __len_方法。
__getitem__():返回一个样本(sample)
__len__(): 返回样本的数量(the size of the dataset)
DataLoader 可以理解为从 Dataset 中取数据,然后对数据进行处理并形成多个 tensor ,最后把 tensor 送到模型中。具体文档如下图:
image参数解释(挑选几个我认为重要的):
- dataset:从哪加载数据
- batch_size:一个batch加载几个样本(samples)
- shuffle:是否打乱数据
- sampler:采样策略
- batch_sampler:批采样方法
- num_workers:加载数据的进程数
- collate_fn:合并样本,形成一个 batch 的 tensor
- drop_last:True -- 如果最后一个样本的 size 小于 batch_size, 则丢弃该样本。 False -- 不丢弃
DataLoader 处理数据的流程如下:
image使用采样策略到 DataSet 中取数据,返回多个样本(sample),然后经过 collate_fn 合并样本,形成一个 batch 的 tensor,最后将 tensor 送到模型中。
那采样策略是什么呢?
采样策略可以简单的理解为以什么样的方式取数据。
这里举两个简单的例子:
torch.utils.data.SequentialSampler -- 顺序采样:就是按照 Dataset 样本的顺序,依次返回一个样本。
torch.utils.data.RandomSampler -- 随机采样:就是打乱 Dataset 样本的顺序,随机返回一个样本。
当然采样的方式有多种,这里只举两种采样进行对比,目的是让大家明白什么是采样策略。
DataLoader 中根据采样策略(Sampler)到 Dataset 中取数据。在图中可以看到,从 Dataset 中可以返回多个样本(sample)。其中样本是什么?为何是多个?
样本(sample)-- 抽象的可以理解为 一个 x 和对应 y 的集合。
举个具体的“样本”例子就是:
([830, 2311, 2218, 1975, 2701, 1, 1, 7509, 24627], 0) #(从上面的表格拷贝的)
将一个句子(x)以及对应的 label=0 (y) 放在一个元组(也可以是列表、字典等)中,这就是一个样本
多个样本 -- 是依据采样策略。采样策略可以返回一个样本,也可以返回多个样本。
假设返回多个 sample ,就是返回2个样本。举个具体的“2个样本”例子就是:
[([830, 2311, 2218, 1975, 2701, 1, 1, 7509, 24627], 0),
([636, 6924, 2991, 3017], 1)]
现在就需要 collate_fn 将这两个样本进行处理,并形成一个 tensor(batch_data) 。
collate_fn 里面的具体操作就是将属于 x 的部分合并为一个 tensor_x,将属于 y 的部分合并为一个 tensor_y ,最后将 tensor_x 和 tensor_y 返回。在合并 x 的时候通常要做的操作是 填充截断(截断句子长的,填充句子短的)。
经过 collate_fn 举个具体的“ tensor_x 和 tensor_y ”的例子就是:
tensor_x = tensor([[830, 2311, 2218, 1975, 2701, 1, 1, 7509, 24627],
[636, 6924, 2991, 3017, 0, 0, 0, 0, 0]])
tensor_y = tensor([0,1])
我和别人不一样的点在于举了个例子吧!
欢迎批评指正。