基于 PyTorch 计算图片数据集各通道的均值与标准差
2021-08-16 本文已影响0人
深思海数_willschang
该方法来自于《Deep Learning with PyTorch Step by Step》一书第六章。
image.png
下面直接给出代码。
实际应用中，可以根据项目需要进一步优化，并将其封装为常用函数。需要注意：这里得到的标准差是各张图片逐通道标准差的平均值，它是对整个数据集所有像素标准差的近似，而不是严格意义上的总体标准差。
# IPython/Jupyter line magic: render matplotlib figures inline in the notebook.
# This is not Python syntax; remove it when running the code as a plain .py script.
# NOTE(review): matplotlib (plt) is not used by the code shown here.
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from copy import deepcopy
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
from torchvision.transforms import Compose, ToTensor, Normalize, ToPILImage, Resize
from torchvision.datasets import ImageFolder
# Preprocessing pipeline used only to measure the dataset statistics:
# resize the images (an int size matches the shorter edge to 28 px) and convert
# them to tensors.
# NOTE(review): with an int size, images of different aspect ratios end up with
# different tensor shapes and cannot be batched together — confirm the dataset
# images all share one aspect ratio.
temp_transform = Compose(transforms=[Resize(size=28), ToTensor()])

# Root directory of the images (one sub-folder per class); adjust to your data.
temp_dataset = ImageFolder(root='./data/rps', transform=temp_transform)

# Mini-batch loader over the dataset, 16 images per batch.
temp_loader = DataLoader(dataset=temp_dataset, batch_size=16)
# Compute the per-channel mean and standard deviation of an image dataset.
class GetChannelsNormalize():
    """Estimate per-channel normalization statistics from an image DataLoader.

    Follows chapter 6 of "Deep Learning with PyTorch Step by Step": for every
    image, the mean and standard deviation of each channel are computed, and
    both are then averaged over all images.  The resulting "std" is the average
    of per-image standard deviations; it approximates, but is not equal to, the
    standard deviation of all dataset pixels pooled together.

    All helpers are static/class methods; the class holds no state.
    """

    def __init__(self):
        # Stateless helper collection; nothing to initialise.
        pass

    @staticmethod
    def loader_apply(loader, func, reduce='sum'):
        """Apply ``func`` to every mini-batch of ``loader`` and reduce the results.

        Args:
            loader: Iterable yielding ``(x, y)`` pairs, e.g. a ``DataLoader``.
            func: Callable ``func(x, y)`` returning a tensor.  Every call must
                return the same shape so the results can be stacked.
            reduce: ``'sum'`` sums over batches and ``'mean'`` averages over
                batches (after casting to float).  Any other value returns the
                un-reduced stack of per-batch results.

        Returns:
            The reduced tensor, or a tensor of shape ``(n_batches, *result_shape)``
            when ``reduce`` is neither ``'sum'`` nor ``'mean'``.

        Raises:
            RuntimeError: from ``torch.stack`` when ``loader`` yields no batches.
        """
        results = torch.stack([func(x, y) for x, y in loader], dim=0)
        if reduce == 'sum':
            results = results.sum(dim=0)
        elif reduce == 'mean':
            results = results.float().mean(dim=0)
        return results

    @staticmethod
    def statistics_per_channel(images, labels):
        """Return the per-channel image count, sum of means and sum of stds of a batch.

        Args:
            images: Floating-point tensor in NCHW layout.
            labels: Unused; accepted so the function matches the ``func(x, y)``
                signature expected by ``loader_apply``.

        Returns:
            Tensor of shape ``(3, n_channels)``.  Row 0 holds the number of
            images in the batch (repeated for each channel).  Row 1 holds the sum
            over images of each image's per-channel mean.  Row 2 holds the same
            sum for the per-channel (unbiased) standard deviations.
        """
        n_samples, n_channels, _, _ = images.size()
        # Flatten H and W into one axis: (n_samples, n_channels, H * W).
        flatten_per_channel = images.reshape(n_samples, n_channels, -1)
        # Per-image, per-channel statistics, each of shape (n_samples, n_channels).
        means = flatten_per_channel.mean(dim=2)
        stds = flatten_per_channel.std(dim=2)
        # Count row of shape (n_channels,).  It is created on the same device and
        # with the same dtype as the statistics so that stacking also works for
        # GPU and double-precision batches.
        counts = torch.full(
            (n_channels,), float(n_samples), dtype=means.dtype, device=means.device
        )
        # Stack into a (3, n_channels) tensor: count, sum of means, sum of stds.
        return torch.stack([counts, means.sum(dim=0), stds.sum(dim=0)], dim=0)

    @classmethod
    def channel_mean_std(cls, loader):
        """Return the per-channel ``(mean, std)`` averaged over every image in ``loader``.

        Args:
            loader: Iterable of ``(images, labels)`` batches with images in NCHW
                layout.

        Returns:
            Two tensors of shape ``(n_channels,)``.  The first is the average
            per-image channel mean; the second is the average per-image channel
            standard deviation.
        """
        total_samples, total_means, total_stds = cls.loader_apply(
            loader, cls.statistics_per_channel
        )
        return total_means / total_samples, total_stds / total_samples

    @classmethod
    def make_normalizer(cls, loader):
        """Build a torchvision ``Normalize`` transform from the statistics of ``loader``.

        Args:
            loader: Iterable of ``(images, labels)`` batches with images in NCHW
                layout.

        Returns:
            ``Normalize(mean=..., std=...)`` using the per-channel values from
            ``channel_mean_std``.
        """
        # The helpers must be reached through the class: static methods are not
        # visible as bare names inside another method's body.
        norm_mean, norm_std = cls.channel_mean_std(loader)
        return Normalize(mean=norm_mean, std=norm_std)
# Measure the per-channel statistics over every batch of the loader and wrap
# them in a Normalize transform.
norm_data = GetChannelsNormalize.make_normalizer(loader=temp_loader)
print(norm_data)
# Output reported in the article for ./data/rps:
# Normalize(mean=tensor([0.8502, 0.8215, 0.8116]), std=tensor([0.2089, 0.2512, 0.2659]))