PyTorch自定义数据集示例(2019-12-19)

2019-12-19  本文已影响0人  yzuzhangxr

文章结构

自定义Dataset的基本结构

from torch.utils.data.dataset import Dataset

class MyCustomDataset(Dataset):
    # Template of a custom dataset: the three methods below are the ones
    # this skeleton expects you to fill in. `...`, `img`, `label` and
    # `count` are placeholders, not working code.
    def __init__(self, ...):
        # Fill in: one-time setup goes here
        
    def __getitem__(self, index):
        # Fill in: produce the sample for the given index
        return (img, label)

    def __len__(self):
        return count # how many images you have
__init__() #函数是初始逻辑发生的地方,比如读取csv、分配转换等
__getitem__()#函数返回数据和标签。这个函数是从dataloader中被调用的,如下所示:
img, label = MyCustomDataset.__getitem__(99)  # 取出索引为99的样本(即第100个样本)

__len__()#返回你的样本数量

使用Torchvision进行类型转换

from torch.utils.data.dataset import Dataset
from torchvision import transforms

class MyCustomDataset(Dataset):
    # Template of a dataset that receives its transform through __init__.
    # `...`, `img`, `label` and `count` are placeholders, not working code.
    def __init__(self, ..., transforms=None):
        # Fill in
        #...
        # Keep the caller-supplied transform (may be None) for __getitem__
        self.transforms = transforms
        
    def __getitem__(self, index):
        # Fill in
        #...
        data = # data read from a file or an image
        if self.transforms is not None:
            data = self.transforms(data)
        # If a transform was supplied (is not None),
        # the data is converted with it.
        # NOTE(review): the transformed value lives in `data`, while the
        # return below uses the placeholder names `img` and `label`.
        return (img, label)

    def __len__(self):
        return count
        
if __name__ == '__main__':
    # Build the transform pipeline: CenterCrop(100) followed by ToTensor()
    transformations = transforms.Compose([
        transforms.CenterCrop(100),
        transforms.ToTensor(),
    ])
    # Create the dataset, handing it the pipeline (`...` is a placeholder
    # for the remaining constructor arguments)
    custom_dataset = MyCustomDataset(..., transformations)

使用Torchvision的另一种方法

from torch.utils.data.dataset import Dataset
from torchvision import transforms

class MyCustomDataset(Dataset):
    def __init__(self, ...):
        # 填充
        #...
        # 单独定义转换
        self.center_crop = transforms.CenterCrop(100)
        self.to_tensor = transforms.ToTensor()
        
        # 也可以组合定义
        self.transformations = transforms.Compose([
                                transforms.CenterCrop(100),
                                transforms.ToTensor()])
        
    def __getitem__(self, index):
        # 填充
        #...
        data = # 从文件或者图像中读取的数据
        
        #对应了在__init__()中定义的三个transforms
        data = self.center_crop(data)  
        data = self.to_tensor(data)  
        data = self.trasnformations(data) 
        
        return (img, label)

    def __len__(self):
        return count 
        
if __name__ == '__main__':
    # Create the dataset (`...` stands in for the constructor arguments)
    custom_dataset = MyCustomDataset(...)

Incorporating Pandas

File Name Label Extra Operation
tr_0.png 5 TRUE
tr_1.png 0 FALSE
tr_2.png 4 FALSE
class CustomDatasetFromImages(Dataset):
    """Dataset of image files indexed by a CSV file.

    The CSV is read with ``header=None``. Column 0 is used as the image
    path, column 1 as the label and column 2 as a per-row operation flag.
    """

    def __init__(self, csv_path):
        """Read the CSV index and cache its columns as NumPy arrays.

        Args:
            csv_path (string): path to the csv file
        """
        # Tensor conversion applied to every image in __getitem__
        self.to_tensor = transforms.ToTensor()
        # Load the index; header=None means no row is taken as column names
        self.data_info = pd.read_csv(csv_path, header=None)
        frame = self.data_info
        # Column 0: image paths
        self.image_arr = np.asarray(frame.iloc[:, 0])
        # Column 1: labels
        self.label_arr = np.asarray(frame.iloc[:, 1])
        # Column 2: operation indicator
        self.operation_arr = np.asarray(frame.iloc[:, 2])
        # Number of rows in the index, reported by __len__
        self.data_len = len(frame.index)

    def __getitem__(self, index):
        """Return the ``(image tensor, label)`` pair for row ``index``."""
        # Open the image whose path is stored for this row
        image = Image.open(self.image_arr[index])

        # Placeholder hook: rows with a truthy operation flag would get
        # extra image processing here; nothing is done at present.
        if self.operation_arr[index]:
            pass

        # Convert the image to a tensor and pair it with the row's label
        return (self.to_tensor(image), self.label_arr[index])

    def __len__(self):
        """Return the number of rows read from the CSV."""
        return self.data_len

if __name__ == "__main__":
    # Build the dataset from the CSV index of image files
    custom_mnist_from_images = CustomDatasetFromImages(
        csv_path='../data/mnist_labels.csv',
    )

Incorporating Pandas with More Logic

Label pixel_1 pixel_2 ...
1 50 99 ...
0 21 223 ...
9 44 112 ...
... ... ... ...
class CustomDatasetFromCSV(Dataset):
    """Dataset whose images are stored as flattened pixel rows in a CSV file.

    Column 0 of each row is read as the label; the remaining columns are
    reshaped into a ``(height, width)`` 8-bit grayscale image.
    """

    def __init__(self, csv_path, height, width, transforms=None):
        """Load the CSV and remember the image geometry and transform.

        Args:
            csv_path (string): path to the csv file
            height (int): image height
            width (int): image width
            transforms: optional callable applied to the PIL image in
                __getitem__ (e.g. a pipeline ending in tensor conversion)
        """
        # Read with pandas' default header handling
        self.data = pd.read_csv(csv_path)
        self.labels = np.asarray(self.data.iloc[:, 0])
        self.height = height
        self.width = width
        # The parameter is named `transforms`; the previous code assigned
        # the undefined name `transform` and raised NameError.
        self.transforms = transforms

    def __getitem__(self, index):
        """Return ``(image, label)`` for row ``index``.

        The image is the output of ``self.transforms`` when one was given,
        otherwise the untransformed PIL image.
        """
        single_image_label = self.labels[index]
        # Everything after the label column is pixel data; reshape the 1D
        # row to the (height, width) given at construction instead of a
        # hard-coded 28x28.
        pixels = np.asarray(self.data.iloc[index][1:])
        img_as_np = pixels.reshape(self.height, self.width).astype('uint8')
        # Convert the numpy array to a PIL image; mode "L" is grayscale
        img_as_img = Image.fromarray(img_as_np)
        img_as_img = img_as_img.convert('L')
        # Without a transform, return the PIL image itself rather than
        # referencing a variable that was never assigned.
        if self.transforms is None:
            return (img_as_img, single_image_label)
        # Apply the caller-supplied transform (e.g. conversion to tensor)
        img_as_tensor = self.transforms(img_as_img)
        return (img_as_tensor, single_image_label)

    def __len__(self):
        """Return the number of data rows in the CSV."""
        return len(self.data.index)
        

if __name__ == "__main__":
    # Only tensor conversion is applied to each image
    transformations = transforms.Compose([
        transforms.ToTensor(),
    ])
    # Build the dataset from the pixel CSV, treating each row as 28x28
    custom_mnist_from_csv = CustomDatasetFromCSV(
        csv_path='../data/mnist_in_csv.csv',
        height=28,
        width=28,
        transforms=transformations,
    )

使用Data Loader

if __name__ == "__main__":
    # Define the transforms
    transformations = transforms.Compose([transforms.ToTensor()])
    # Define the dataset
    custom_mnist_from_csv = CustomDatasetFromCSV('../data/mnist_in_csv.csv', 28, 28, transformations)
    # Define the data loader: batches of 10 samples, in dataset order
    mn_dataset_loader = torch.utils.data.DataLoader(dataset=custom_mnist_from_csv,
                                                    batch_size=10,
                                                    shuffle=False)

    for images, labels in mn_dataset_loader:
        # Feed the batch to the model here. `pass` keeps the loop
        # syntactically complete until real training code is added.
        pass
上一篇 下一篇

猜你喜欢

热点阅读