关于 torchvision.transforms.Normal

2021-01-20  本文已影响0人  星光下的胖子

先贴一段使用代码:

from torchvision import models, transforms

# 迁移学习,预训练模型
net = models.resnet18(pretrained=True)

# 标准化
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

# 数据转换
image_transform = transforms.Compose([
    # 将输入图片resize成统一尺寸
    transforms.Resize([224, 224]),
    # 将PIL Image或numpy.ndarray转换为tensor,并除255归一化到[0,1]之间
    transforms.ToTensor(),
    # 标准化处理-->转换为标准正太分布,使模型更容易收敛
    normalize
])

transforms.Normalize(mean, std) 的计算公式:
input[channel] = (input[channel] - mean[channel]) / std[channel]

Normalize() 函数的作用是将数据转换为标准正太分布,使模型更容易收敛。

PyTorch 中我们经常看到 mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ,是从 ImageNet 数据集的数百万张图片中随机抽样计算得到的。

上一篇 下一篇

猜你喜欢

热点阅读