Pytorch剪枝代码示例和注释
参考这个链接,加了一些自己的注释
导入模块
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
建立模型
LeNet 1998年提出
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
# 1 input image channel, 6 output channels, 3x3 square conv kernel
self.conv1 = nn.Conv2d(1, 6, 3)
self.conv2 = nn.Conv2d(6, 16, 3)
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5x5 image dimension
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, int(x.nelement() / x.shape[0]))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = LeNet().to(device=device)
检查模块
检查未修建的conv1
层
module = model.conv1
print(list(module.named_parameters()))
输出结果如下:
权重有6个矩阵,每个矩阵的size是3*3,偏差为6个value。
[('weight', Parameter containing:
tensor([[[[ 0.3161, -0.2212, 0.0417],
[ 0.2488, 0.2415, 0.2071],
[-0.2412, -0.2400, -0.2016]]],
[[[ 0.0419, 0.3322, -0.2106],
[ 0.1776, -0.1845, -0.3134],
[-0.0708, 0.1921, 0.3095]]],
[[[-0.2070, 0.0723, 0.2876],
[ 0.2209, 0.2077, 0.2369],
[ 0.2108, 0.0861, -0.2279]]],
[[[-0.2799, -0.1527, -0.0388],
[-0.2043, 0.1220, 0.1032],
[-0.0755, 0.1281, 0.1077]]],
[[[ 0.2035, 0.2245, -0.1129],
[ 0.3257, -0.0385, -0.0115],
[-0.3146, -0.2145, -0.1947]]],
[[[-0.1426, 0.2370, -0.1089],
[-0.2491, 0.1282, 0.1067],
[ 0.2159, -0.1725, 0.0723]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021, 0.1425], device='cuda:0',
requires_grad=True))]
检测是否有缓冲区
print(list(module.named_buffers()))
输出为空矩阵
parameter和buffer的区别:
模型中需要保存下来的参数包括两种:
一种是反向传播需要被optimizer更新的,称之为 parameter
一种是反向传播不需要被optimizer更新,称之为 buffer
修剪模块 part1
目标:我们将在conv1
层中名为weight
的参数中随机修剪 30%的连接。
- 从
torch.nn.utils.prune
选择修建技术 - 指定模块和该模块中需要修剪的参数名称
3.使用所选修剪技术所需的适当关键字参数,指定修剪参数。
name=weight
的理解:
之前print(list(module.named_parameters()))
的输出结果是以字典形式保存的,关键字有weight
amount=0.3
的理解:剪掉百分之30的连接。
prune.random_unstructured(module, name="weight", amount=0.3)
修剪函数执行时候的内部原理:
修剪是通过从参数中删除weight并将其替换为名为weight_orig的新参数(即,将"_orig"附加到初始参数name)来进行的。 weight_orig存储未修剪的张量版本。 bias未修剪,因此它将保持完整。
print(list(module.named_parameters()))
输出如下:
[('bias', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021, 0.1425], device='cuda:0',
requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.3161, -0.2212, 0.0417],
[ 0.2488, 0.2415, 0.2071],
[-0.2412, -0.2400, -0.2016]]],
[[[ 0.0419, 0.3322, -0.2106],
[ 0.1776, -0.1845, -0.3134],
[-0.0708, 0.1921, 0.3095]]],
[[[-0.2070, 0.0723, 0.2876],
[ 0.2209, 0.2077, 0.2369],
[ 0.2108, 0.0861, -0.2279]]],
[[[-0.2799, -0.1527, -0.0388],
[-0.2043, 0.1220, 0.1032],
[-0.0755, 0.1281, 0.1077]]],
[[[ 0.2035, 0.2245, -0.1129],
[ 0.3257, -0.0385, -0.0115],
[-0.3146, -0.2145, -0.1947]]],
[[[-0.1426, 0.2370, -0.1089],
[-0.2491, 0.1282, 0.1067],
[ 0.2159, -0.1725, 0.0723]]]], device='cuda:0', requires_grad=True))]
通过以上选择的修剪技术生成的修剪掩码将保存为名为weight_mask的模块缓冲区
print(list(module.named_buffers()))
输出:
[('weight_mask', tensor([[[[0., 1., 0.],
[1., 0., 0.],
[1., 1., 1.]]],
[[[1., 0., 1.],
[1., 1., 0.],
[1., 0., 1.]]],
[[[1., 0., 0.],
[0., 1., 1.],
[1., 1., 1.]]],
[[[1., 0., 0.],
[1., 1., 1.],
[1., 1., 1.]]],
[[[1., 0., 1.],
[1., 1., 1.],
[0., 1., 1.]]],
[[[1., 1., 1.],
[1., 1., 0.],
[1., 1., 0.]]]], device='cuda:0'))]
这里需要注意:
mask里标记为0的位置对应的weight是被pruned掉的,在retrained的时候保持为0。
对比mask和weight可以发现,mask里标记为0的量对应每个3*3矩阵里weight magnitude较小的权重。
这时候打印weight
,会得到掩码和原始参数结合的版本(即pruned的权重变为0)。注意这里的weight
不是一个参数,只是一个属性。
print(module.weight)
tensor([[[[ 0.0000, -0.2212, 0.0000],
[ 0.2488, 0.0000, 0.0000],
[-0.2412, -0.2400, -0.2016]]],
[[[ 0.0419, 0.0000, -0.2106],
[ 0.1776, -0.1845, -0.0000],
[-0.0708, 0.0000, 0.3095]]],
[[[-0.2070, 0.0000, 0.0000],
[ 0.0000, 0.2077, 0.2369],
[ 0.2108, 0.0861, -0.2279]]],
[[[-0.2799, -0.0000, -0.0000],
[-0.2043, 0.1220, 0.1032],
[-0.0755, 0.1281, 0.1077]]],
[[[ 0.2035, 0.0000, -0.1129],
[ 0.3257, -0.0385, -0.0115],
[-0.0000, -0.2145, -0.1947]]],
[[[-0.1426, 0.2370, -0.1089],
[-0.2491, 0.1282, 0.0000],
[ 0.2159, -0.1725, 0.0000]]]], device='cuda:0',
grad_fn=<MulBackward0>)
剪枝需要在每次前向传播之前被应用。通过PyTorch 的forward_pre_hooks
可以应用剪枝。
当模型被剪枝时,它将为与该模型关联的每个参数获取forward_pre_hook进行修剪。(注意,在这里模型不是指整个网络模型,而是指被剪枝的子模型,比如在这里是指conv1
)
在这种情况下,由于到目前为止我们只修剪了名称为weight的原始参数,因此只会出现一个钩子。
print(module._forward_pre_hooks)
输出为:
OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f1e6c425400>)])
修建模块 part2
为了完整起见,我们现在也可以修剪bias
,以查看module
的参数parameter,缓冲区buffer,挂钩hook和属性property如何变化。
在这里我们尝试另一种修剪方法,按 L1 范数修剪掉最小的3个偏差bias
prune.l1_unstructured(module, name="bias", amount=3)
预计目标:
现在,我们希望命名的参数同时包含之前的weight_orig和bias_orig。 缓冲区buffer将包括weight_mask和bias_mask。 两个张量(weight和bias)的修剪版本将作为模块属性存在,并且该模块现在将具有两个forward_pre_hooks。
实际输出:
参数:
[('weight_orig', Parameter containing:
tensor([[[[ 0.3161, -0.2212, 0.0417],
[ 0.2488, 0.2415, 0.2071],
[-0.2412, -0.2400, -0.2016]]],
[[[ 0.0419, 0.3322, -0.2106],
[ 0.1776, -0.1845, -0.3134],
[-0.0708, 0.1921, 0.3095]]],
[[[-0.2070, 0.0723, 0.2876],
[ 0.2209, 0.2077, 0.2369],
[ 0.2108, 0.0861, -0.2279]]],
[[[-0.2799, -0.1527, -0.0388],
[-0.2043, 0.1220, 0.1032],
[-0.0755, 0.1281, 0.1077]]],
[[[ 0.2035, 0.2245, -0.1129],
[ 0.3257, -0.0385, -0.0115],
[-0.3146, -0.2145, -0.1947]]],
[[[-0.1426, 0.2370, -0.1089],
[-0.2491, 0.1282, 0.1067],
[ 0.2159, -0.1725, 0.0723]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021, 0.1425], device='cuda:0',
requires_grad=True))]
缓冲区:
[('weight_mask', tensor([[[[0., 1., 0.],
[1., 0., 0.],
[1., 1., 1.]]],
[[[1., 0., 1.],
[1., 1., 0.],
[1., 0., 1.]]],
[[[1., 0., 0.],
[0., 1., 1.],
[1., 1., 1.]]],
[[[1., 0., 0.],
[1., 1., 1.],
[1., 1., 1.]]],
[[[1., 0., 1.],
[1., 1., 1.],
[0., 1., 1.]]],
[[[1., 1., 1.],
[1., 1., 0.],
[1., 1., 0.]]]], device='cuda:0')), ('bias_mask', tensor([0., 0., 1., 1., 0., 1.], device='cuda:0'))]
属性:
tensor([-0.0000, -0.0000, -0.2656, -0.1519, -0.0000, 0.1425], device='cuda:0',
grad_fn=<MulBackward0>)
钩子:可以看到有两个钩子
OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f1e6c425400>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7f1e6c425550>)])
迭代修剪
一个模块中的同一参数可以被多次修剪。
暂时用不到这一块,因此不深入下去了。
序列化修剪的模型
所有相关的张量,包括掩码缓冲区和用于计算修剪的张量的原始参数,都存储在模型的state_dict中,因此可以根据需要轻松地序列化和保存。
print(model.state_dict().keys())
输出
odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
删除剪枝重新参数化
torch.nn.utils.prune
中的remove
修建模型中的多个参数
下面这段代码对不同的层采用了不同的sparsity percentage
new_model = LeNet()
for name, module in new_model.named_modules():
# prune 20% of connections in all 2D-conv layers
if isinstance(module, torch.nn.Conv2d):
prune.l1_unstructured(module, name='weight', amount=0.2)
# prune 40% of connections in all linear layers
elif isinstance(module, torch.nn.Linear):
prune.l1_unstructured(module, name='weight', amount=0.4)
print(dict(new_model.named_buffers()).keys()) # to verify that all masks exist
全局剪枝(Global Pruning)
之前我们的剪枝方法为“局部剪枝”(local pruning)研究了通常被称为“局部”修剪的方法,即通过比较每个条目的统计信息(weight magnitude, activation, gradient, etc.)来逐一修剪模型中的张量的做法。 但是,一种常见且可能更强大的技术是通过删除(例如)删除整个模型中最低的20%的连接,而不是删除每一层中最低的 20%的连接来一次修剪模型。 这很可能导致每个层的修剪百分比不同。 让我们看看如何使用torch.nn.utils.prune
中的global_unstructured
进行操作。
model = LeNet()
parameters_to_prune = (
(model.conv1, 'weight'),
(model.conv2, 'weight'),
(model.fc1, 'weight'),
(model.fc2, 'weight'),
(model.fc3, 'weight'),
)
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.2,
)
下面我们检查在每个修剪参数中引起的稀疏性,该稀疏性将不等于每层中的 20%。 但是,全局稀疏度大约为 20%。
print(
"Sparsity in conv1.weight: {:.2f}%".format(
100\. * float(torch.sum(model.conv1.weight == 0))
/ float(model.conv1.weight.nelement())
)
)
print(
"Sparsity in conv2.weight: {:.2f}%".format(
100\. * float(torch.sum(model.conv2.weight == 0))
/ float(model.conv2.weight.nelement())
)
)
print(
"Sparsity in fc1.weight: {:.2f}%".format(
100\. * float(torch.sum(model.fc1.weight == 0))
/ float(model.fc1.weight.nelement())
)
)
print(
"Sparsity in fc2.weight: {:.2f}%".format(
100\. * float(torch.sum(model.fc2.weight == 0))
/ float(model.fc2.weight.nelement())
)
)
print(
"Sparsity in fc3.weight: {:.2f}%".format(
100\. * float(torch.sum(model.fc3.weight == 0))
/ float(model.fc3.weight.nelement())
)
)
print(
"Global sparsity: {:.2f}%".format(
100\. * float(
torch.sum(model.conv1.weight == 0)
+ torch.sum(model.conv2.weight == 0)
+ torch.sum(model.fc1.weight == 0)
+ torch.sum(model.fc2.weight == 0)
+ torch.sum(model.fc3.weight == 0)
)
/ float(
model.conv1.weight.nelement()
+ model.conv2.weight.nelement()
+ model.fc1.weight.nelement()
+ model.fc2.weight.nelement()
+ model.fc3.weight.nelement()
)
)
)
使用自定义修剪功能扩展torch.nn.utils.prune
这部分暂时用不到