pytorch学习笔记 | 深度学习 | 目标跟踪 && 目标检测

pytorch学习(十四)—孪生网络训练数据集制作

2019-01-03  本文已影响0人  侠之大者_7d3f
import torch
from torch.utils.data import Dataset,DataLoader
import linecache
import random
from PIL import Image


class MyDataset(Dataset):
    """Siamese-network pair dataset built from a plain-text list file.

    Each non-blank line of ``txt_file`` must contain at least two
    whitespace-separated fields: an image path and a class label (the
    person id), e.g. ``./data/att_faces/s1/1.pgm 0``.  Item ``index`` is a
    pair ``[img0, img1]`` where ``img0`` is the image on record ``index``
    and ``img1`` is drawn at random -- from the same class or from a
    different class with equal probability -- plus a label that is 1 when
    both images share a class and 0 otherwise.

    Args:
        txt_file: Path to the list file.
        transform: Optional callable applied to each loaded PIL image.

    Raises:
        ValueError: if a non-blank line has fewer than two fields.
    """

    def __init__(self, txt_file, transform=None):
        self.transform = transform
        self.txt = txt_file
        # Parse the list once. The original re-read the whole file on every
        # __len__ call, which __getitem__ made repeatedly (even inside its
        # retry loop). Blank lines are skipped so that a trailing newline
        # cannot yield an empty record.
        with open(txt_file, 'r') as f:
            self.samples = [line.split() for line in f if line.strip()]
        # label -> indices of all records with that label, for pair sampling.
        self._indices_by_label = {}
        for i, fields in enumerate(self.samples):
            if len(fields) < 2:
                raise ValueError('expected "<image path> <label>" per line in %s, got %r'
                                 % (txt_file, ' '.join(fields)))
            self._indices_by_label.setdefault(fields[1], []).append(i)

    def __getitem__(self, index):
        """Return ``{'image': [img0, img1], 'label': torch.tensor(0 or 1)}``.

        Raises:
            IndexError: if ``index`` is out of range. This is also what lets
                plain ``for sample in dataset`` iteration terminate.
        """
        n = len(self.samples)
        if index < 0:
            index += n
        if not 0 <= index < n:
            raise IndexError('MyDataset index out of range')

        # The anchor is the requested record; with DataLoader(shuffle=True)
        # this is random, and every image is used once per epoch.
        img0_list = self.samples[index]

        # Coin flip: 0 -> different person, 1 -> same person.
        should_get_same_class = random.randint(0, 1)
        if should_get_same_class:
            # May pick img0 itself; that is still a valid positive pair.
            same = self._indices_by_label[img0_list[1]]
            img1_list = self.samples[random.choice(same)]
        elif len(self._indices_by_label) > 1:
            # Rejection-sample until the partner belongs to another class.
            while True:
                img1_list = random.choice(self.samples)
                if img1_list[1] != img0_list[1]:
                    break
        else:
            # Only one class exists, so a negative pair is impossible.
            img1_list = random.choice(self.samples)

        # Load images.
        img0 = Image.open(img0_list[0])
        img1 = Image.open(img1_list[0])

        # Apply the optional transform.
        if self.transform:
            img0 = self.transform(img0)
            img1 = self.transform(img1)

        # The label is computed from the actual pair, not from the coin flip.
        label = 1 if img1_list[1] == img0_list[1] else 0

        return {'image': [img0, img1], 'label': torch.tensor(label)}

    def __len__(self):
        """Return the number of non-blank records in the list file."""
        return len(self.samples)


import torch
from torch.utils.data import Dataset,DataLoader
from torchvision.transforms import transforms
from torchvision.utils import make_grid
import linecache
import random
from PIL import Image
import dataset
import matplotlib.pyplot as plt
import numpy as np


# https://www.cnblogs.com/king-lps/p/8342452.html

def show_image(sample):
    """Draw an image pair side by side on the current matplotlib figure.

    Args:
        sample: dict as returned by ``MyDataset.__getitem__``.
            ``sample['image']`` holds two image tensors, presumably CHW as
            produced by ``transforms.ToTensor()``. ``make_grid`` needs them
            to be the same size. ``sample['label']`` is a 0-d tensor shown
            as the title.
    """
    image0 = sample['image'][0]
    image1 = sample['image'][1]

    # Pad with "white" in the tensors' own value range. ToTensor yields
    # floats in [0, 1], where 255 is out of range and makes imshow clip (and
    # warn). 255 remains correct for uint8 tensors.
    pad_value = 1.0 if image0.is_floating_point() else 255
    grid = make_grid([image0, image1], pad_value=pad_value)
    # CHW -> HWC, the layout matplotlib expects.
    plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
    plt.axis('off')
    plt.title(sample['label'].numpy())


my_dataset = dataset.MyDataset('./data/att_faces/list.txt', transform=transforms.ToTensor())

# Index explicitly instead of iterating the dataset object: plain iteration
# relies on __getitem__ raising IndexError at the end, which the original
# MyDataset never did, so the loop never terminated.
for i in range(len(my_dataset)):
    sample = my_dataset[i]
    print(sample)
    # One figure per pair; plt.show() blocks until the window is closed.
    plt.figure()
    show_image(sample)
    plt.show()


[原文此处有 6 张图片（image.png），抓取时未保留]
上一篇 下一篇

猜你喜欢

热点阅读