Computing the Normalize values in PyTorch

2021-06-15  zeolite

Using MNIST as an example, compute the mean and standard deviation to pass to transforms.Normalize.
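For reference, Normalize maps each pixel x to (x − μ)/σ per channel. Here μ and σ are taken to be the population mean and standard deviation over all N pixels of the (single-channel) training set, after ToTensor has scaled pixels to [0, 1]. The second part, which assumes every image has the same pixel count n (N = Mn) and treats the per-image σ_i as population standard deviations, is the law of total variance followed by Jensen's inequality. It shows that averaging per-image means gives μ exactly, but averaging per-image standard deviations can only underestimate σ.

```latex
\frac{x - \mu}{\sigma}, \qquad
\mu = \frac{1}{N}\sum_{k=1}^{N} x_k, \qquad
\sigma = \sqrt{\frac{1}{N}\sum_{k=1}^{N} x_k^2 - \mu^2}

% M images of n pixels each (N = Mn), per-image mean \mu_i and std \sigma_i:
\mu = \frac{1}{M}\sum_{i=1}^{M}\mu_i, \qquad
\sigma^2 = \frac{1}{M}\sum_{i=1}^{M}\sigma_i^2 + \frac{1}{M}\sum_{i=1}^{M}(\mu_i - \mu)^2
\;\ge\; \frac{1}{M}\sum_{i=1}^{M}\sigma_i^2
\;\ge\; \Big(\frac{1}{M}\sum_{i=1}^{M}\sigma_i\Big)^2
```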

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

batch_size = 1

# ToTensor() converts each image to a float tensor in [0, 1] with shape (1, 28, 28)
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_data = datasets.MNIST(
    root='.',
    train=True,
    download=True,
    transform=transform
)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

train_len = len(train_data)

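# Per-image mean of the pixel values and of the squared pixel values
means = torch.zeros((train_len,), dtype=torch.float32)
sq_means = torch.zeros((train_len,), dtype=torch.float32)
for idx, (image, label) in enumerate(train_loader):
    image = torch.squeeze(image)

    means[idx] = image.mean()
    sq_means[idx] = (image ** 2).mean()

# All images have the same number of pixels, so the average of the
# per-image means equals the mean over every pixel in the dataset
mean = means.mean()
# Averaging per-image stds would underestimate the dataset std, so use
# Var[x] = E[x^2] - (E[x])^2 over all pixels instead
std = torch.sqrt(sq_means.mean() - mean ** 2)
print(mean, std)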
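Looping with batch_size = 1 over 60,000 images is slow. The sketch below is one way to use larger batches. It assumes the loader yields (images, labels) batches shaped (B, C, H, W) and accumulates per-channel sums and sums of squares in float64, so tens of millions of pixels can be added up without losing precision. compute_mean_std is a hypothetical helper name, not a torchvision API.

```python
from typing import List, Tuple

import torch
from torch.utils.data import DataLoader


def compute_mean_std(loader: DataLoader) -> Tuple[List[float], List[float]]:
    """Per-channel population mean and std over every pixel from `loader`.

    Hypothetical helper; assumes batches of (images, labels) with images
    shaped (B, C, H, W).
    """
    total = None     # running per-channel sum of pixel values
    total_sq = None  # running per-channel sum of squared pixel values
    count = 0        # number of pixels seen in each channel
    for images, _ in loader:
        x = images.to(torch.float64)           # accumulate in float64
        b, c = x.shape[0], x.shape[1]
        x = x.reshape(b, c, -1)                # (B, C, H*W)
        if total is None:
            total = torch.zeros(c, dtype=torch.float64)
            total_sq = torch.zeros(c, dtype=torch.float64)
        total += x.sum(dim=(0, 2))
        total_sq += (x ** 2).sum(dim=(0, 2))
        count += b * x.shape[2]
    if count == 0:
        raise ValueError("loader produced no images")
    mean = total / count
    # Population variance over all pixels: E[x^2] - (E[x])^2
    std = torch.sqrt(total_sq / count - mean ** 2)
    return mean.tolist(), std.tolist()
```

With the train_data defined above, something like mean, std = compute_mean_std(DataLoader(train_data, batch_size=1000)) returns one-element lists that can be passed to transforms.Normalize(mean, std). For comparison, the official PyTorch MNIST example hardcodes transforms.Normalize((0.1307,), (0.3081,)), so the result should come out close to those values.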