pytorch学习笔记-dataloader忽略异常值
2022-03-24 本文已影响0人
升不上三段的大鱼
在使用自己的数据时,如果希望只有满足某些条件的数据才用于训练,一种方法是在预处理阶段把不满足条件的数据去掉;另一种方法是重写 DataLoader 的 collate_fn 函数,在组成 batch 时把不满足条件的数据过滤掉。
class DataSet():
    """Map-style dataset that hands out each valid sample at most once.

    ``__getitem__`` returns ``None`` for an index that was already requested
    or whose sample fails the validity check, so that a ``collate_fn`` can
    drop those entries from the batch.
    """

    def __init__(self, data):
        """Store ``data`` and mark every index as not yet requested.

        Args:
            data: Sized, integer-indexable container of samples.
        """
        self.data = data
        # One flag per sample (0 = not requested yet, 1 = already requested).
        # Nothing in this class clears the flags again.
        self.visited = np.zeros(len(data))

    def __len__(self):
        """Return the number of stored samples (valid or not)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx``, or ``None`` if it must be skipped.

        Args:
            idx: Integer position of the requested sample.

        Returns:
            The stored sample, or ``None`` when the index was requested
            before or the sample is rejected by the validity check.
        """
        if self.visited[idx] == 1:  # never hand out the same index twice
            return None
        data = self.data[idx]
        # Flag the index before the validity check so rejected samples are
        # also recorded as visited.
        self.visited[idx] = 1
        if data is None:  # put the rejection condition for unwanted samples here
            return None
        return data
# Build the filtering dataset and a loader that batches it through my_collate.
# NOTE: my_collate must already be defined when the DataLoader is constructed,
# because the function object is passed here as collate_fn.
dataset = DataSet(data)
dataloader = DataLoader(dataset, batch_size=4,
                        shuffle=True, num_workers=1, collate_fn=my_collate)
def my_collate(batch, max_attempts_per_item=10):
    """Collate a batch after dropping ``None`` samples and refilling the gaps.

    Rejected samples arrive as ``None``. They are removed, and replacements
    are drawn at random from the module-level ``dataset``. The number of
    redraws is bounded, so the function terminates even when ``dataset`` has
    no usable sample left; in that case the returned batch is smaller than
    requested.

    Args:
        batch: List of samples produced by the dataset; entries may be ``None``.
        max_attempts_per_item: Maximum number of random redraws allowed for
            each missing sample before giving up on refilling.

    Returns:
        The result of ``default_collate`` on the valid samples.

    Raises:
        RuntimeError: If no valid sample is available for the batch at all.
    """
    requested = len(batch)  # original batch length
    batch = [item for item in batch if item is not None]  # drop all the Nones
    missing = requested - len(batch)

    # Total redraw budget for this call. Without a bound, the loop would spin
    # forever once every remaining index yields None.
    attempts_left = missing * max_attempts_per_item
    while missing > 0 and attempts_left > 0:
        attempts_left -= 1
        item = dataset[np.random.randint(0, len(dataset))]
        if item is not None:
            batch.append(item)
            missing -= 1

    if not batch:
        raise RuntimeError(
            "my_collate: no valid sample available to build a batch"
        )
    return torch.utils.data.dataloader.default_collate(batch)
参考: https://stackoverflow.com/questions/57815001/pytorch-collate-fn-reject-sample-and-yield-another