SSAST Model Structure
Model Architecture Analysis
The first step of model pruning is to understand the model's structure.
- First, the following code was used to list the model's parameter names (bias entries filtered out):
```python
for name, param in audio_model.named_parameters():
    if 'bias' not in name:
        print(name)
```
```
module.v.cls_token
module.v.pos_embed
module.v.dist_token
module.v.patch_embed.proj.weight
module.v.blocks.0.norm1.weight
module.v.blocks.0.attn.qkv.weight
module.v.blocks.0.attn.proj.weight
module.v.blocks.0.norm2.weight
module.v.blocks.0.mlp.fc1.weight
module.v.blocks.0.mlp.fc2.weight
module.v.blocks.1.norm1.weight
module.v.blocks.1.attn.qkv.weight
module.v.blocks.1.attn.proj.weight
module.v.blocks.1.norm2.weight
module.v.blocks.1.mlp.fc1.weight
module.v.blocks.1.mlp.fc2.weight
module.v.blocks.2.norm1.weight
module.v.blocks.2.attn.qkv.weight
module.v.blocks.2.attn.proj.weight
module.v.blocks.2.norm2.weight
module.v.blocks.2.mlp.fc1.weight
module.v.blocks.2.mlp.fc2.weight
module.v.blocks.3.norm1.weight
module.v.blocks.3.attn.qkv.weight
module.v.blocks.3.attn.proj.weight
module.v.blocks.3.norm2.weight
module.v.blocks.3.mlp.fc1.weight
module.v.blocks.3.mlp.fc2.weight
module.v.blocks.4.norm1.weight
module.v.blocks.4.attn.qkv.weight
module.v.blocks.4.attn.proj.weight
module.v.blocks.4.norm2.weight
module.v.blocks.4.mlp.fc1.weight
module.v.blocks.4.mlp.fc2.weight
module.v.blocks.5.norm1.weight
module.v.blocks.5.attn.qkv.weight
module.v.blocks.5.attn.proj.weight
module.v.blocks.5.norm2.weight
module.v.blocks.5.mlp.fc1.weight
module.v.blocks.5.mlp.fc2.weight
module.v.blocks.6.norm1.weight
module.v.blocks.6.attn.qkv.weight
module.v.blocks.6.attn.proj.weight
module.v.blocks.6.norm2.weight
module.v.blocks.6.mlp.fc1.weight
module.v.blocks.6.mlp.fc2.weight
module.v.blocks.7.norm1.weight
module.v.blocks.7.attn.qkv.weight
module.v.blocks.7.attn.proj.weight
module.v.blocks.7.norm2.weight
module.v.blocks.7.mlp.fc1.weight
module.v.blocks.7.mlp.fc2.weight
module.v.blocks.8.norm1.weight
module.v.blocks.8.attn.qkv.weight
module.v.blocks.8.attn.proj.weight
module.v.blocks.8.norm2.weight
module.v.blocks.8.mlp.fc1.weight
module.v.blocks.8.mlp.fc2.weight
module.v.blocks.9.norm1.weight
module.v.blocks.9.attn.qkv.weight
module.v.blocks.9.attn.proj.weight
module.v.blocks.9.norm2.weight
module.v.blocks.9.mlp.fc1.weight
module.v.blocks.9.mlp.fc2.weight
module.v.blocks.10.norm1.weight
module.v.blocks.10.attn.qkv.weight
module.v.blocks.10.attn.proj.weight
module.v.blocks.10.norm2.weight
module.v.blocks.10.mlp.fc1.weight
module.v.blocks.10.mlp.fc2.weight
module.v.blocks.11.norm1.weight
module.v.blocks.11.attn.qkv.weight
module.v.blocks.11.attn.proj.weight
module.v.blocks.11.norm2.weight
module.v.blocks.11.mlp.fc1.weight
module.v.blocks.11.mlp.fc2.weight
module.v.norm.weight
module.v.head.weight
module.v.head_dist.weight
module.mlp_head.0.weight
module.mlp_head.1.weight
```
As the listing shows, the architecture is simply a ViT backbone plus an MLP head.
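For reference, the total number of trainable parameters can be recorded before any pruning; a minimal sketch:

```python
# Sum over all trainable parameters of the (still wrapped) model.
n_params = sum(p.numel() for p in audio_model.parameters() if p.requires_grad)
print(f"trainable parameters: {n_params}")
```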
- Next, the following code was used to inspect the model's top-level children:
```python
for name, layer in audio_model.named_children():
    print(f"name:{name},type:{type(layer)}")
```
```
name:module,type:<class 'models.ast_models.ASTModel'>
```
This shows that the model is wrapped in DataParallel, which is why every parameter name carries the `module.` prefix.
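To address submodules without the `module.` prefix, the DataParallel container can be unwrapped first. A minimal sketch; the `unwrap` helper is purely illustrative, not part of the original code:

```python
import torch

def unwrap(model):
    # DataParallel stores the real network under .module.
    if isinstance(model, torch.nn.DataParallel):
        return model.module
    return model

base_model = unwrap(audio_model)  # same object as audio_model.module here
```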
- Unwrap it and go one level deeper:
```python
for name, layer in audio_model.module.named_children():
    print(f"name:{name},type:{type(layer)}")
```
```
name:v,type:<class 'timm.models.vision_transformer.DistilledVisionTransformer'>
name:mlp_head,type:<class 'torch.nn.modules.container.Sequential'>
```
Consistent with the parameter listing above: a ViT plus a Sequential container.
- Drill further into the ViT:
```python
for name, layer in audio_model.module.v.named_children():
    print(f"name:{name},type:{type(layer)}")
```
```
name:patch_embed,type:<class 'models.ast_models.PatchEmbed'>
name:pos_drop,type:<class 'torch.nn.modules.dropout.Dropout'>
name:blocks,type:<class 'torch.nn.modules.container.ModuleList'>
name:norm,type:<class 'torch.nn.modules.normalization.LayerNorm'>
name:pre_logits,type:<class 'torch.nn.modules.linear.Identity'>
name:head,type:<class 'torch.nn.modules.linear.Linear'>
name:head_dist,type:<class 'torch.nn.modules.linear.Linear'>
```
Apart from patch_embed, which is a custom embedding block modified in advance (note its class is models.ast_models.PatchEmbed rather than timm's), every child here is a standard torch.nn module. Iterating over patch_embed's children in the same way shows that the embedding block contains only a single convolution layer:
```
name:proj,type:<class 'torch.nn.modules.conv.Conv2d'>
```
- Continue by traversing the ModuleList:
```python
for name, layer in audio_model.module.v.blocks.named_children():
    print(f"name:{name},type:{type(layer)}")
```
```
name:0,type:<class 'timm.models.vision_transformer.Block'>
name:1,type:<class 'timm.models.vision_transformer.Block'>
name:2,type:<class 'timm.models.vision_transformer.Block'>
name:3,type:<class 'timm.models.vision_transformer.Block'>
name:4,type:<class 'timm.models.vision_transformer.Block'>
name:5,type:<class 'timm.models.vision_transformer.Block'>
name:6,type:<class 'timm.models.vision_transformer.Block'>
name:7,type:<class 'timm.models.vision_transformer.Block'>
name:8,type:<class 'timm.models.vision_transformer.Block'>
name:9,type:<class 'timm.models.vision_transformer.Block'>
name:10,type:<class 'timm.models.vision_transformer.Block'>
name:11,type:<class 'timm.models.vision_transformer.Block'>
```
- Inspect the internal structure of each Block:
```python
for name, layer in audio_model.module.v.blocks[0].named_children():
    print(f"name:{name},type:{type(layer)}")
```
```
name:norm1,type:<class 'torch.nn.modules.normalization.LayerNorm'>
name:attn,type:<class 'timm.models.vision_transformer.Attention'>
name:drop_path,type:<class 'torch.nn.modules.linear.Identity'>
name:norm2,type:<class 'torch.nn.modules.normalization.LayerNorm'>
name:mlp,type:<class 'timm.models.vision_transformer.Mlp'>
```
- Drill into the attention module:
```python
for name, layer in audio_model.module.v.blocks[0].attn.named_children():
    print(f"name:{name},type:{type(layer)}")
```
```
name:qkv,type:<class 'torch.nn.modules.linear.Linear'>
name:attn_drop,type:<class 'torch.nn.modules.dropout.Dropout'>
name:proj,type:<class 'torch.nn.modules.linear.Linear'>
name:proj_drop,type:<class 'torch.nn.modules.dropout.Dropout'>
```
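Note that `qkv` is a single fused projection producing query, key, and value at once, so its output width is three times the embedding dimension; a quick check:

```python
qkv = audio_model.module.v.blocks[0].attn.qkv
# 768 in, 3 * 768 = 2304 out: Q, K and V stacked in one Linear layer.
print(qkv.in_features, qkv.out_features)
```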
- And the Mlp at the same level:
```python
for name, layer in audio_model.module.v.blocks[0].mlp.named_children():
    print(f"name:{name},type:{type(layer)}")
```
```
name:fc1,type:<class 'torch.nn.modules.linear.Linear'>
name:act,type:<class 'torch.nn.modules.activation.GELU'>
name:fc2,type:<class 'torch.nn.modules.linear.Linear'>
name:drop,type:<class 'torch.nn.modules.dropout.Dropout'>
```
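Instead of drilling down one level at a time, `named_modules()` walks the entire tree in a single call; the following sketch prints every submodule with its full dotted path:

```python
# Recursively visit every submodule, including the model itself.
for name, layer in audio_model.named_modules():
    print(f"name:{name},type:{type(layer)}")
```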
Pruning with prune
`prune` here is PyTorch's built-in pruning module (`torch.nn.utils.prune`); its methods divide into unstructured and structured pruning.
- Unstructured pruning (zeroes individual weight entries, regardless of position)
  - random_unstructured
```python
from torch.nn.utils import prune

# Randomly zero 50% of conv1's weight entries, then make the pruning permanent.
prune.random_unstructured(model.conv1, name="weight", amount=0.5)
prune.remove(model.conv1, 'weight')
print(model.conv1.weight)
```
```
Parameter containing:
tensor([[[[-0.0000, -0.0000, 0.0000],
          [-0.1978, -0.2261, -0.0860],
          [ 0.1037, 0.0000, 0.0000]]],
        [[[-0.2053, 0.0410, 0.0000],
          [-0.0000, -0.0000, -0.0000],
          [-0.1972, 0.3187, -0.0000]]],
        [[[-0.0000, -0.0000, -0.1223],
          [ 0.1347, 0.3205, 0.0000],
          [-0.0682, -0.0000, 0.2382]]]], device='cuda:0', requires_grad=True)
```
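Between the two calls above, pruning is stored as a reparametrization rather than an in-place change: `random_unstructured` adds a `weight_mask` buffer and a `weight_orig` parameter, and `remove` bakes the mask into `weight` permanently. A minimal sketch on a fresh module:

```python
import torch.nn as nn
from torch.nn.utils import prune

conv = nn.Conv2d(1, 3, kernel_size=3)
prune.random_unstructured(conv, name="weight", amount=0.5)
print([n for n, _ in conv.named_buffers()])     # ['weight_mask']
print([n for n, _ in conv.named_parameters()])  # ['bias', 'weight_orig']
prune.remove(conv, "weight")                    # bake mask into weight
print([n for n, _ in conv.named_parameters()])  # ['bias', 'weight']
```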
  - l1_unstructured (zeroes the entries with the smallest absolute value)
```python
# Zero the 50% of weight entries with the smallest L1 magnitude.
prune.l1_unstructured(model.conv1, name="weight", amount=0.5)
prune.remove(model.conv1, 'weight')
print(model.conv1.weight)
```
```
Parameter containing:
tensor([[[[-0.3320, 0.2106, 0.3279],
          [-0.3195, 0.3222, 0.0000],
          [-0.0000, -0.2038, -0.0000]]],
        [[[-0.0000, 0.3175, 0.0000],
          [-0.0000, -0.2928, -0.0000],
          [ 0.0000, -0.0000, -0.0000]]],
        [[[ 0.2634, -0.2414, 0.0000],
          [ 0.0000, -0.3128, 0.0000],
          [-0.2402, -0.2144, 0.0000]]]], device='cuda:0', requires_grad=True)
```
- Structured pruning (removes entire slices along a chosen dimension; for a Conv2d weight of shape (out_channels, in_channels, kH, kW), dim=2 below zeroes whole kernel rows, as the outputs show)
  - random_structured
```python
# Randomly zero 33% of the slices along dim=2 (kernel rows).
prune.random_structured(model.conv1, name="weight", amount=0.33, dim=2)
prune.remove(model.conv1, 'weight')
print(model.conv1.weight)
```
```
Parameter containing:
tensor([[[[ 0.0048, 0.1459, -0.0502],
          [-0.0000, -0.0000, 0.0000],
          [ 0.2687, -0.1137, 0.1034]]],
        [[[ 0.1801, -0.2711, -0.0819],
          [ 0.0000, 0.0000, 0.0000],
          [-0.2404, -0.3188, 0.3194]]],
        [[[ 0.0434, -0.0618, 0.0368],
          [ 0.0000, 0.0000, 0.0000],
          [ 0.1729, 0.2978, -0.3020]]]], device='cuda:0', requires_grad=True)
```
  - ln_structured (same idea, but slices are ranked by their Ln norm; n=2 below selects the L2 norm)
```python
# Zero the 33% of slices along dim=2 with the smallest L2 norm.
prune.ln_structured(model.conv1, name="weight", amount=0.33, n=2, dim=2)
prune.remove(model.conv1, 'weight')
print(model.conv1.weight)
```
```
Parameter containing:
tensor([[[[-0.1829, 0.2475, 0.2816],
          [ 0.0694, -0.1366, 0.0740],
          [-0.0000, -0.0000, 0.0000]]],
        [[[-0.2966, -0.2881, 0.2974],
          [ 0.3074, 0.2858, 0.1990],
          [-0.0000, 0.0000, -0.0000]]],
        [[[-0.0558, -0.3072, -0.0674],
          [-0.0860, 0.2881, 0.1865],
          [ 0.0000, 0.0000, 0.0000]]]], device='cuda:0', requires_grad=True)
```
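To prune the SSAST model itself rather than a single conv layer, one option is to apply a single global criterion across all Linear layers at once; a sketch using `prune.global_unstructured` (the 30% amount is just an example):

```python
import torch.nn as nn
from torch.nn.utils import prune

# Collect every Linear weight in the model and prune them jointly, so the
# 30% smallest-magnitude weights are selected across all layers combined.
parameters_to_prune = [
    (m, "weight") for m in audio_model.modules() if isinstance(m, nn.Linear)
]
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.3,
)
```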
Tools for Measuring Model Performance
Computing the parameter count
- Theoretical formulas:
  - Convolution layer: params = out_channels * (in_channels * kernel_size^2) + out_channels
  - Normalization layer: params = 2 * out_channels (weight and bias)
  - Fully connected layer: params = in_features * out_features + out_features
  - Other layers: no parameters
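A quick sanity check of these formulas against PyTorch's own counts, using layer sizes that appear in this model:

```python
import torch.nn as nn

# Convolution: out_channels * (in_channels * kernel_size^2) + out_channels
conv = nn.Conv2d(1, 768, kernel_size=16, stride=10)
assert sum(p.numel() for p in conv.parameters()) == 768 * (1 * 16**2) + 768

# Normalization: 2 * out_channels (weight and bias)
norm = nn.LayerNorm(768)
assert sum(p.numel() for p in norm.parameters()) == 2 * 768

# Fully connected: in_features * out_features + out_features
fc = nn.Linear(768, 3072)
assert sum(p.numel() for p in fc.parameters()) == 768 * 3072 + 3072
```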
- Using the thop tool
thop profiles MACs and the parameter count; note that one MAC corresponds to two FLOPs (a multiply plus an add).
```python
import torch
from thop import profile

inp = torch.randn(48, 512, 128).to('cuda')
macs, params = profile(audio_model, inputs=(inp,))
print(f"MACs: {macs}, Parameters: {params}")
```
```
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
MACs: 2462153674752.0, Parameters: 85258757.0
```
Results for SSAST-Base-Patch-400:

| Pruning ratio (%) | Params | Accuracy (%) | Inference speed | MACs |
|---|---|---|---|---|
| 0 | 85258757 | 87.61 | 1.104047 | 2462153674752 |
| 10 | 85258757 | | | 2462153674752 |
| 30 | 85258757 | 87.60 | 1.096145 | 2462153674752 |
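The parameter count and MACs stay constant across pruning ratios because `prune`-style unstructured pruning only zeroes weights; the tensor shapes are unchanged, so thop counts them all. To see the effect of pruning, measure the actual weight sparsity instead; a minimal sketch:

```python
# Fraction of exactly-zero entries across all weight tensors.
total = zeros = 0
for name, p in audio_model.named_parameters():
    if "weight" in name:
        total += p.numel()
        zeros += (p == 0).sum().item()
print(f"global weight sparsity: {zeros / total:.2%}")
```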
The structure comparison below (left => right) shows the model before and after shrinking the embedding dimension from 768 to 384, together with the resulting MACs and parameter counts:

```
ASTModel(
  (v): DistilledVisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10)) => (proj): Conv2d(1, 384, kernel_size=(16, 16), stride=(10, 10))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True) => (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True) => (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True) => (proj): Linear(in_features=384, out_features=384, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True) => (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True) => (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=3072, out_features=768, bias=True) => (fc2): Linear(in_features=1536, out_features=384, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True) => (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
    (pre_logits): Identity()
    (head): Linear(in_features=768, out_features=1000, bias=True)
    (head_dist): Linear(in_features=768, out_features=1000, bias=True)
  )
  (mlp_head): Sequential(
    (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True) => (0): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=768, out_features=5, bias=True) => (1): Linear(in_features=384, out_features=5, bias=True)
  )
)
MACs: 58.0542 G => 16.2705 G
Params: 87.2606 M => 23.1657 M
```
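The human-readable G/M figures above can be produced from raw thop outputs with its `clever_format` helper; a minimal sketch, reusing the `macs` and `params` values from the earlier profile call:

```python
from thop import clever_format

# Convert raw counts into "G"/"M"-suffixed strings.
macs, params = clever_format([macs, params], "%.4f")
print(f"MACs: {macs}, Params: {params}")
```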