PyTorch study notes: making the DataLoader skip invalid samples

2022-03-24  Author: 升不上三段的大鱼

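When training on your own data, you may want only the samples that satisfy certain conditions to be used, with everything else kept out of training. One option is preprocessing: remove the samples that fail the conditions before training starts. The other is to override the DataLoader's collate_fn function, which is the approach this note takes.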
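For comparison, the preprocessing option could look roughly like the sketch below. It assumes the dataset is small enough to scan once up front. `ToyDataset`, `valid_indices` and the `is_valid` predicate are hypothetical names introduced for illustration; `Subset` is the standard wrapper from `torch.utils.data`.

```python
import torch
from torch.utils.data import DataLoader, Dataset, Subset


class ToyDataset(Dataset):
    """Hypothetical dataset; None stands in for a sample that fails the condition."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def valid_indices(dataset, is_valid):
    """Scan the dataset once and keep the indices whose sample passes is_valid."""
    return [i for i in range(len(dataset)) if is_valid(dataset[i])]


raw = ToyDataset([torch.tensor([1.0]), None, torch.tensor([3.0]),
                  None, torch.tensor([5.0])])
keep = valid_indices(raw, lambda sample: sample is not None)
clean = Subset(raw, keep)  # the DataLoader only ever sees the valid samples

loader = DataLoader(clean, batch_size=2, shuffle=False)
batches = list(loader)
```

The cost is one full pass over the data before training, which is why the collate_fn route can be more convenient when loading a sample is expensive.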

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class FilteredDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        if data is None:  # put the condition for rejecting a sample here
            return None
        return data


def my_collate(batch):
    len_batch = len(batch)  # original batch length
    batch = [x for x in batch if x is not None]  # filter out all the Nones
    # Refill the batch with valid samples drawn at random from the dataset.
    # A replacement may repeat a sample that is also used elsewhere in the epoch.
    for _ in range(len_batch - len(batch)):
        item = dataset[np.random.randint(0, len(dataset))]
        while item is None:  # assumes the dataset holds at least one valid sample
            item = dataset[np.random.randint(0, len(dataset))]
        batch.append(item)
    return torch.utils.data.dataloader.default_collate(batch)


dataset = FilteredDataset(data)  # `data` is your own indexable collection of samples

dataloader = DataLoader(dataset, batch_size=4,
                        shuffle=True, num_workers=1, collate_fn=my_collate)
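To try the idea end to end, here is a small self-contained sketch. It is a variation, not the code from the referenced answer. The dataset is passed to the collate function through `functools.partial` instead of a global variable, and the number of redraws is capped so the loop cannot spin forever. `EvenOnly`, `refill_collate` and `max_tries` are hypothetical names, and `num_workers=0` is used only so the snippet runs as a plain script.

```python
import random
from functools import partial

import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate


class EvenOnly(Dataset):
    """Hypothetical dataset over 0..n-1 in which odd numbers are rejected."""

    def __init__(self, n):
        self.values = list(range(n))

    def __len__(self):
        return len(self.values)

    def __getitem__(self, idx):
        value = self.values[idx]
        if value % 2 == 1:  # rejection condition
            return None
        return torch.tensor(value)


def refill_collate(batch, dataset, max_tries=100):
    """Drop None samples, then top the batch up with random valid samples."""
    target = len(batch)
    batch = [x for x in batch if x is not None]
    tries = 0
    while len(batch) < target and tries < max_tries:
        item = dataset[random.randrange(len(dataset))]
        tries += 1
        if item is not None:
            batch.append(item)
    if not batch:
        raise RuntimeError("no valid sample found for this batch")
    return default_collate(batch)


dataset = EvenOnly(20)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0,
                    collate_fn=partial(refill_collate, dataset=dataset))
batches = list(loader)
```

Every batch produced this way contains only even numbers. Because replacements are drawn at random, the same sample can appear more than once within an epoch.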

Reference: https://stackoverflow.com/questions/57815001/pytorch-collate-fn-reject-sample-and-yield-another
