Pytorch剪枝代码示例和注释

2020-07-08  本文已影响0人  幽并游侠儿_1425

参考这个链接,加了一些自己的注释

导入模块

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

建立模型

LeNet 1998年提出

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

检查模块

检查未修建的conv1

module = model.conv1
print(list(module.named_parameters()))

输出结果如下:
权重有6个矩阵,每个矩阵的size是3*3,偏差为6个value。

[('weight', Parameter containing:
tensor([[[[ 0.3161, -0.2212,  0.0417],
          [ 0.2488,  0.2415,  0.2071],
          [-0.2412, -0.2400, -0.2016]]],

        [[[ 0.0419,  0.3322, -0.2106],
          [ 0.1776, -0.1845, -0.3134],
          [-0.0708,  0.1921,  0.3095]]],

        [[[-0.2070,  0.0723,  0.2876],
          [ 0.2209,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],

        [[[-0.2799, -0.1527, -0.0388],
          [-0.2043,  0.1220,  0.1032],
          [-0.0755,  0.1281,  0.1077]]],

        [[[ 0.2035,  0.2245, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.3146, -0.2145, -0.1947]]],

        [[[-0.1426,  0.2370, -0.1089],
          [-0.2491,  0.1282,  0.1067],
          [ 0.2159, -0.1725,  0.0723]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021,  0.1425], device='cuda:0',
       requires_grad=True))]

检测是否有缓冲区

print(list(module.named_buffers()))

输出为空矩阵
parameter和buffer的区别:

模型中需要保存下来的参数包括两种:
一种是反向传播需要被optimizer更新的,称之为 parameter
一种是反向传播不需要被optimizer更新,称之为 buffer

修剪模块 part1

目标:我们将在conv1层中名为weight的参数中随机修剪 30%的连接。

  1. torch.nn.utils.prune选择修建技术
  2. 指定模块和该模块中需要修剪的参数名称
    3.使用所选修剪技术所需的适当关键字参数,指定修剪参数。
    name=weight的理解:
    之前print(list(module.named_parameters()))的输出结果是以字典形式保存的,关键字有weight
    amount=0.3的理解:剪掉百分之30的连接。
prune.random_unstructured(module, name="weight", amount=0.3)

修剪函数执行时候的内部原理:

修剪是通过从参数中删除weight并将其替换为名为weight_orig的新参数(即,将"_orig"附加到初始参数name)来进行的。 weight_orig存储未修剪的张量版本。 bias未修剪,因此它将保持完整。

print(list(module.named_parameters()))

输出如下:

[('bias', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021,  0.1425], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.3161, -0.2212,  0.0417],
          [ 0.2488,  0.2415,  0.2071],
          [-0.2412, -0.2400, -0.2016]]],

        [[[ 0.0419,  0.3322, -0.2106],
          [ 0.1776, -0.1845, -0.3134],
          [-0.0708,  0.1921,  0.3095]]],

        [[[-0.2070,  0.0723,  0.2876],
          [ 0.2209,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],

        [[[-0.2799, -0.1527, -0.0388],
          [-0.2043,  0.1220,  0.1032],
          [-0.0755,  0.1281,  0.1077]]],

        [[[ 0.2035,  0.2245, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.3146, -0.2145, -0.1947]]],

        [[[-0.1426,  0.2370, -0.1089],
          [-0.2491,  0.1282,  0.1067],
          [ 0.2159, -0.1725,  0.0723]]]], device='cuda:0', requires_grad=True))]

通过以上选择的修剪技术生成的修剪掩码将保存为名为weight_mask的模块缓冲区

print(list(module.named_buffers()))

输出:

[('weight_mask', tensor([[[[0., 1., 0.],
          [1., 0., 0.],
          [1., 1., 1.]]],

        [[[1., 0., 1.],
          [1., 1., 0.],
          [1., 0., 1.]]],

        [[[1., 0., 0.],
          [0., 1., 1.],
          [1., 1., 1.]]],

        [[[1., 0., 0.],
          [1., 1., 1.],
          [1., 1., 1.]]],

        [[[1., 0., 1.],
          [1., 1., 1.],
          [0., 1., 1.]]],

        [[[1., 1., 1.],
          [1., 1., 0.],
          [1., 1., 0.]]]], device='cuda:0'))]

这里需要注意:
mask里标记为0的位置对应的weight是被pruned掉的,在retrained的时候保持为0。
对比mask和weight可以发现,mask里标记为0的量对应每个3*3矩阵里weight magnitude较小的权重。

这时候打印weight,会得到掩码和原始参数结合的版本(即pruned的权重变为0)。注意这里的weight不是一个参数,只是一个属性。

print(module.weight)
tensor([[[[ 0.0000, -0.2212,  0.0000],
          [ 0.2488,  0.0000,  0.0000],
          [-0.2412, -0.2400, -0.2016]]],

        [[[ 0.0419,  0.0000, -0.2106],
          [ 0.1776, -0.1845, -0.0000],
          [-0.0708,  0.0000,  0.3095]]],

        [[[-0.2070,  0.0000,  0.0000],
          [ 0.0000,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],

        [[[-0.2799, -0.0000, -0.0000],
          [-0.2043,  0.1220,  0.1032],
          [-0.0755,  0.1281,  0.1077]]],

        [[[ 0.2035,  0.0000, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.0000, -0.2145, -0.1947]]],

        [[[-0.1426,  0.2370, -0.1089],
          [-0.2491,  0.1282,  0.0000],
          [ 0.2159, -0.1725,  0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

剪枝需要在每次前向传播之前被应用。通过PyTorch 的forward_pre_hooks可以应用剪枝。
当模型被剪枝时,它将为与该模型关联的每个参数获取forward_pre_hook进行修剪。(注意,在这里模型不是指整个网络模型,而是指被剪枝的子模型,比如在这里是指conv1

在这种情况下,由于到目前为止我们只修剪了名称为weight的原始参数,因此只会出现一个钩子。

print(module._forward_pre_hooks)

输出为:

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f1e6c425400>)])

修建模块 part2

为了完整起见,我们现在也可以修剪bias,以查看module的参数parameter,缓冲区buffer,挂钩hook和属性property如何变化。
在这里我们尝试另一种修剪方法,按 L1 范数修剪掉最小的3个偏差bias

prune.l1_unstructured(module, name="bias", amount=3)

预计目标:
现在,我们希望命名的参数同时包含之前的weight_orig和bias_orig。 缓冲区buffer将包括weight_mask和bias_mask。 两个张量(weight和bias)的修剪版本将作为模块属性存在,并且该模块现在将具有两个forward_pre_hooks。
实际输出:
参数:

[('weight_orig', Parameter containing:
tensor([[[[ 0.3161, -0.2212,  0.0417],
          [ 0.2488,  0.2415,  0.2071],
          [-0.2412, -0.2400, -0.2016]]],

        [[[ 0.0419,  0.3322, -0.2106],
          [ 0.1776, -0.1845, -0.3134],
          [-0.0708,  0.1921,  0.3095]]],

        [[[-0.2070,  0.0723,  0.2876],
          [ 0.2209,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],

        [[[-0.2799, -0.1527, -0.0388],
          [-0.2043,  0.1220,  0.1032],
          [-0.0755,  0.1281,  0.1077]]],

        [[[ 0.2035,  0.2245, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.3146, -0.2145, -0.1947]]],

        [[[-0.1426,  0.2370, -0.1089],
          [-0.2491,  0.1282,  0.1067],
          [ 0.2159, -0.1725,  0.0723]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021,  0.1425], device='cuda:0',
       requires_grad=True))]

缓冲区:

[('weight_mask', tensor([[[[0., 1., 0.],
          [1., 0., 0.],
          [1., 1., 1.]]],

        [[[1., 0., 1.],
          [1., 1., 0.],
          [1., 0., 1.]]],

        [[[1., 0., 0.],
          [0., 1., 1.],
          [1., 1., 1.]]],

        [[[1., 0., 0.],
          [1., 1., 1.],
          [1., 1., 1.]]],

        [[[1., 0., 1.],
          [1., 1., 1.],
          [0., 1., 1.]]],

        [[[1., 1., 1.],
          [1., 1., 0.],
          [1., 1., 0.]]]], device='cuda:0')), ('bias_mask', tensor([0., 0., 1., 1., 0., 1.], device='cuda:0'))]

属性:

tensor([-0.0000, -0.0000, -0.2656, -0.1519, -0.0000,  0.1425], device='cuda:0',
       grad_fn=<MulBackward0>)

钩子:可以看到有两个钩子

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f1e6c425400>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7f1e6c425550>)])

迭代修剪

一个模块中的同一参数可以被多次修剪。
暂时用不到这一块,因此不深入下去了。

序列化修剪的模型

所有相关的张量,包括掩码缓冲区和用于计算修剪的张量的原始参数,都存储在模型的state_dict中,因此可以根据需要轻松地序列化和保存。

print(model.state_dict().keys())

输出

odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

删除剪枝重新参数化

torch.nn.utils.prune中的remove

修建模型中的多个参数

下面这段代码对不同的层采用了不同的sparsity percentage

new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)
print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist

全局剪枝(Global Pruning)

之前我们的剪枝方法为“局部剪枝”(local pruning)研究了通常被称为“局部”修剪的方法,即通过比较每个条目的统计信息(weight magnitude, activation, gradient, etc.)来逐一修剪模型中的张量的做法。 但是,一种常见且可能更强大的技术是通过删除(例如)删除整个模型中最低的20%的连接,而不是删除每一层中最低的 20%的连接来一次修剪模型。 这很可能导致每个层的修剪百分比不同。 让我们看看如何使用torch.nn.utils.prune中的global_unstructured进行操作。

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

下面我们检查在每个修剪参数中引起的稀疏性,该稀疏性将不等于每层中的 20%。 但是,全局稀疏度大约为 20%。

print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100\. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)

使用自定义修剪功能扩展torch.nn.utils.prune

这部分暂时用不到

上一篇 下一篇

猜你喜欢

热点阅读