SSAST Model Structure

2025-05-24  小草_a484

Model Architecture Analysis

The first step in model pruning is understanding the model's structure.

# List the names of all weight (i.e., non-bias) parameters
for name,param in audio_model.named_parameters():
    if 'bias' not in name:
        print(name)
module.v.cls_token
module.v.pos_embed
module.v.dist_token
module.v.patch_embed.proj.weight
module.v.blocks.0.norm1.weight
module.v.blocks.0.attn.qkv.weight
module.v.blocks.0.attn.proj.weight
module.v.blocks.0.norm2.weight
module.v.blocks.0.mlp.fc1.weight
module.v.blocks.0.mlp.fc2.weight
module.v.blocks.1.norm1.weight
module.v.blocks.1.attn.qkv.weight
module.v.blocks.1.attn.proj.weight
module.v.blocks.1.norm2.weight
module.v.blocks.1.mlp.fc1.weight
module.v.blocks.1.mlp.fc2.weight
module.v.blocks.2.norm1.weight
module.v.blocks.2.attn.qkv.weight
module.v.blocks.2.attn.proj.weight
module.v.blocks.2.norm2.weight
module.v.blocks.2.mlp.fc1.weight
module.v.blocks.2.mlp.fc2.weight
module.v.blocks.3.norm1.weight
module.v.blocks.3.attn.qkv.weight
module.v.blocks.3.attn.proj.weight
module.v.blocks.3.norm2.weight
module.v.blocks.3.mlp.fc1.weight
module.v.blocks.3.mlp.fc2.weight
module.v.blocks.4.norm1.weight
module.v.blocks.4.attn.qkv.weight
module.v.blocks.4.attn.proj.weight
module.v.blocks.4.norm2.weight
module.v.blocks.4.mlp.fc1.weight
module.v.blocks.4.mlp.fc2.weight
module.v.blocks.5.norm1.weight
module.v.blocks.5.attn.qkv.weight
module.v.blocks.5.attn.proj.weight
module.v.blocks.5.norm2.weight
module.v.blocks.5.mlp.fc1.weight
module.v.blocks.5.mlp.fc2.weight
module.v.blocks.6.norm1.weight
module.v.blocks.6.attn.qkv.weight
module.v.blocks.6.attn.proj.weight
module.v.blocks.6.norm2.weight
module.v.blocks.6.mlp.fc1.weight
module.v.blocks.6.mlp.fc2.weight
module.v.blocks.7.norm1.weight
module.v.blocks.7.attn.qkv.weight
module.v.blocks.7.attn.proj.weight
module.v.blocks.7.norm2.weight
module.v.blocks.7.mlp.fc1.weight
module.v.blocks.7.mlp.fc2.weight
module.v.blocks.8.norm1.weight
module.v.blocks.8.attn.qkv.weight
module.v.blocks.8.attn.proj.weight
module.v.blocks.8.norm2.weight
module.v.blocks.8.mlp.fc1.weight
module.v.blocks.8.mlp.fc2.weight
module.v.blocks.9.norm1.weight
module.v.blocks.9.attn.qkv.weight
module.v.blocks.9.attn.proj.weight
module.v.blocks.9.norm2.weight
module.v.blocks.9.mlp.fc1.weight
module.v.blocks.9.mlp.fc2.weight
module.v.blocks.10.norm1.weight
module.v.blocks.10.attn.qkv.weight
module.v.blocks.10.attn.proj.weight
module.v.blocks.10.norm2.weight
module.v.blocks.10.mlp.fc1.weight
module.v.blocks.10.mlp.fc2.weight
module.v.blocks.11.norm1.weight
module.v.blocks.11.attn.qkv.weight
module.v.blocks.11.attn.proj.weight
module.v.blocks.11.norm2.weight
module.v.blocks.11.mlp.fc1.weight
module.v.blocks.11.mlp.fc2.weight
module.v.norm.weight
module.v.head.weight
module.v.head_dist.weight
module.mlp_head.0.weight
module.mlp_head.1.weight

As the parameter names show, the architecture is essentially a ViT plus an MLP head.
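As a quick check, we can group the parameter counts by the second name component, since every name printed above starts with module. followed by v or mlp_head. A minimal sketch, assuming audio_model is the same variable as in the snippet above:

from collections import Counter

# Group parameter counts by the second name component:
# 'v' = the ViT backbone, 'mlp_head' = the MLP head
counts = Counter()
for name,p in audio_model.named_parameters():
    counts[name.split('.')[1]] += p.numel()
print(counts)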

for name,layer in audio_model.named_children():
    print(f"name:{name},type:{type(layer)}")
name:module,type:<class 'models.ast_models.ASTModel'>

We can see that the whole thing is wrapped in nn.DataParallel, so the actual ASTModel sits under the module attribute.

for name,layer in audio_model.module.named_children():
    print(f"name:{name},type:{type(layer)}")
name:v,type:<class 'timm.models.vision_transformer.DistilledVisionTransformer'>
name:mlp_head,type:<class 'torch.nn.modules.container.Sequential'>

This matches the parameter listing above: a ViT (DistilledVisionTransformer) plus a Sequential container.
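From the parameter names above (mlp_head.0, mlp_head.1) and the model printout at the end of this post, the mlp_head container should hold a LayerNorm followed by a Linear classifier; it can be confirmed the same way:

# Inspect the classification head (LayerNorm + Linear, per the printout below)
for name,layer in audio_model.module.mlp_head.named_children():
    print(f"name:{name},type:{type(layer)}")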

for name,layer in audio_model.module.v.named_children():
    print(f"name:{name},type:{type(layer)}")
name:patch_embed,type:<class 'models.ast_models.PatchEmbed'>
name:pos_drop,type:<class 'torch.nn.modules.dropout.Dropout'>
name:blocks,type:<class 'torch.nn.modules.container.ModuleList'>
name:norm,type:<class 'torch.nn.modules.normalization.LayerNorm'>
name:pre_logits,type:<class 'torch.nn.modules.linear.Identity'>
name:head,type:<class 'torch.nn.modules.linear.Linear'>
name:head_dist,type:<class 'torch.nn.modules.linear.Linear'>

Apart from the first component, the patch-embedding block, which was customized in advance (models.ast_models.PatchEmbed), everything else is a standard torch.nn module.

for name,layer in audio_model.module.v.patch_embed.named_children():
    print(f"name:{name},type:{type(layer)}")
name:proj,type:<class 'torch.nn.modules.conv.Conv2d'>

The embedding block contains only a single convolution layer.
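As a standalone illustration of what this layer does, here is a sketch; the kernel size and stride are taken from the model printout at the end of this post, and the flatten/transpose step mirrors how timm's PatchEmbed produces tokens. The Conv2d slices the 1-channel spectrogram into overlapping 16x16 patches with stride 10 and projects each to a 768-dim token:

import torch
import torch.nn as nn

# AST-style patch embedding sketch: overlapping 16x16 patches, stride 10
proj = nn.Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10))
spec = torch.randn(1, 1, 512, 128)              # (batch, 1, frames, mel bins)
tokens = proj(spec).flatten(2).transpose(1, 2)  # (batch, num_patches, embed_dim)
print(tokens.shape)                             # torch.Size([1, 600, 768])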

for name,layer in audio_model.module.v.blocks.named_children():
    print(f"name:{name},type:{type(layer)}")
name:0,type:<class 'timm.models.vision_transformer.Block'>
name:1,type:<class 'timm.models.vision_transformer.Block'>
name:2,type:<class 'timm.models.vision_transformer.Block'>
name:3,type:<class 'timm.models.vision_transformer.Block'>
name:4,type:<class 'timm.models.vision_transformer.Block'>
name:5,type:<class 'timm.models.vision_transformer.Block'>
name:6,type:<class 'timm.models.vision_transformer.Block'>
name:7,type:<class 'timm.models.vision_transformer.Block'>
name:8,type:<class 'timm.models.vision_transformer.Block'>
name:9,type:<class 'timm.models.vision_transformer.Block'>
name:10,type:<class 'timm.models.vision_transformer.Block'>
name:11,type:<class 'timm.models.vision_transformer.Block'>
for name,layer in audio_model.module.v.blocks[0].named_children():
    print(f"name:{name},type:{type(layer)}")
name:norm1,type:<class 'torch.nn.modules.normalization.LayerNorm'>
name:attn,type:<class 'timm.models.vision_transformer.Attention'>
name:drop_path,type:<class 'torch.nn.modules.linear.Identity'>
name:norm2,type:<class 'torch.nn.modules.normalization.LayerNorm'>
name:mlp,type:<class 'timm.models.vision_transformer.Mlp'>
for name,layer in audio_model.module.v.blocks[0].attn.named_children():
    print(f"name:{name},type:{type(layer)}")
name:qkv,type:<class 'torch.nn.modules.linear.Linear'>
name:attn_drop,type:<class 'torch.nn.modules.dropout.Dropout'>
name:proj,type:<class 'torch.nn.modules.linear.Linear'>
name:proj_drop,type:<class 'torch.nn.modules.dropout.Dropout'>
for name,layer in audio_model.module.v.blocks[0].mlp.named_children():
    print(f"name:{name},type:{type(layer)}")
name:fc1,type:<class 'torch.nn.modules.linear.Linear'>
name:act,type:<class 'torch.nn.modules.activation.GELU'>
name:fc2,type:<class 'torch.nn.modules.linear.Linear'>
name:drop,type:<class 'torch.nn.modules.dropout.Dropout'>
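Because all 12 blocks have identical shapes, one block's parameter count gives the per-block contribution; with the formulas from the parameter-count section below and an embedding dim of 768, this works out to about 7.09 M per block. A quick sketch:

# Per-block parameter count; all 12 transformer blocks are identical in shape
block = audio_model.module.v.blocks[0]
print(sum(p.numel() for p in block.parameters()))  # ~7.09M for embed dim 768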

Pruning with prune

prune is PyTorch's built-in pruning module (torch.nn.utils.prune); it provides both unstructured and structured pruning.
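The snippets below operate on model.conv1, whose definition the original post does not show. Here is a minimal setup sketch; the 1-in/3-out Conv2d is an assumption chosen to match the 3x1x3x3 weight tensors in the printouts:

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Hypothetical toy model: conv1's weight is 3x1x3x3, matching the printouts below
class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)

model = ToyNet().to('cuda')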

  1. Unstructured pruning
# Randomly zero 50% of conv1's individual weight entries
prune.random_unstructured(model.conv1,name="weight",amount=0.5)
prune.remove(model.conv1,'weight')   # bake the mask into the weight tensor
print(model.conv1.weight)
Parameter containing:
tensor([[[[-0.0000, -0.0000,  0.0000],
          [-0.1978, -0.2261, -0.0860],
          [ 0.1037,  0.0000,  0.0000]]],


        [[[-0.2053,  0.0410,  0.0000],
          [-0.0000, -0.0000, -0.0000],
          [-0.1972,  0.3187, -0.0000]]],


        [[[-0.0000, -0.0000, -0.1223],
          [ 0.1347,  0.3205,  0.0000],
          [-0.0682, -0.0000,  0.2382]]]], device='cuda:0', requires_grad=True)
# Zero the 50% of weights with the smallest L1 magnitude
prune.l1_unstructured(model.conv1,name="weight",amount=0.5)
prune.remove(model.conv1,'weight')
print(model.conv1.weight)
Parameter containing:
tensor([[[[-0.3320,  0.2106,  0.3279],
          [-0.3195,  0.3222,  0.0000],
          [-0.0000, -0.2038, -0.0000]]],


        [[[-0.0000,  0.3175,  0.0000],
          [-0.0000, -0.2928, -0.0000],
          [ 0.0000, -0.0000, -0.0000]]],


        [[[ 0.2634, -0.2414,  0.0000],
          [ 0.0000, -0.3128,  0.0000],
          [-0.2402, -0.2144,  0.0000]]]], device='cuda:0', requires_grad=True)
  2. Structured pruning
# Randomly zero 33% of the slices along dim=2 (here: kernel rows)
prune.random_structured(model.conv1,name="weight",amount=0.33,dim=2)
prune.remove(model.conv1,'weight')
print(model.conv1.weight)
Parameter containing:
tensor([[[[ 0.0048,  0.1459, -0.0502],
          [-0.0000, -0.0000,  0.0000],
          [ 0.2687, -0.1137,  0.1034]]],


        [[[ 0.1801, -0.2711, -0.0819],
          [ 0.0000,  0.0000,  0.0000],
          [-0.2404, -0.3188,  0.3194]]],


        [[[ 0.0434, -0.0618,  0.0368],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.1729,  0.2978, -0.3020]]]], device='cuda:0', requires_grad=True)
# Zero the 33% of dim=2 slices with the smallest L2 norm (n=2)
prune.ln_structured(model.conv1,name="weight",amount=0.33,n=2,dim=2)
prune.remove(model.conv1,'weight')
print(model.conv1.weight)
Parameter containing:
tensor([[[[-0.1829,  0.2475,  0.2816],
          [ 0.0694, -0.1366,  0.0740],
          [-0.0000, -0.0000,  0.0000]]],


        [[[-0.2966, -0.2881,  0.2974],
          [ 0.3074,  0.2858,  0.1990],
          [-0.0000,  0.0000, -0.0000]]],


        [[[-0.0558, -0.3072, -0.0674],
          [-0.0860,  0.2881,  0.1865],
          [ 0.0000,  0.0000,  0.0000]]]], device='cuda:0', requires_grad=True)
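Note that prune.remove only bakes the zero mask into the weight tensor; its shape is unchanged, so masked pruning by itself reduces neither the parameter count nor the MACs that profiling tools report (as the table in the next section shows). A quick sparsity check:

# The weight is still 3x1x3x3; pruning only zeroed some of its entries
w = model.conv1.weight
print(w.shape)
print(f"sparsity: {float((w == 0).sum()) / w.numel():.2%}")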

Tools for Measuring Model Performance

Parameter Count Calculation

  1. Convolutional layer: params = out_channels * (in_channels * kernel_size^2) + out_channels
  2. Normalization layer: params = 2 * num_features (one weight and one bias per feature)
  3. Fully connected layer: params = in_features * out_features + out_features
  4. Other layers: no parameters
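A quick sanity check of these formulas against two layers from the model printout at the end of this post:

# patch_embed.proj: Conv2d(1, 768, kernel_size=(16, 16))
print(768 * (1 * 16**2) + 768)   # 197376
# blocks.*.attn.qkv: Linear(in_features=768, out_features=2304)
print(768 * 2304 + 2304)         # 1771776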
import torch
from thop import profile

# Profile MACs and parameter count with a dummy batch of 48 spectrograms,
# each 512 frames x 128 mel bins
inp=torch.randn(48,512,128).to('cuda')
macs,params=profile(audio_model,inputs=(inp,))
print(f"MACs {macs}, Parameters: {params}")
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
MACs 2462153674752.0, Parameters: 85258757.0
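The inference-speed column in the table below presumably comes from timing a forward pass; one minimal way to measure it (the exact timing method is an assumption, not shown in the original):

import time

audio_model.eval()
with torch.no_grad():
    torch.cuda.synchronize()   # finish any pending GPU work first
    t0 = time.time()
    audio_model(inp)           # same dummy batch used for profiling above
    torch.cuda.synchronize()   # wait for the forward pass to complete
    print(f"inference time: {time.time() - t0:.6f} s")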
SSAST-Base-Patch-400:

Pruning ratio (%)   Params      Accuracy (%)   Inference speed   MACs
0                   85258757    87.61          1.104047          2462153674752
10                  85258757    —              —                 2462153674752
30                  85258757    87.60          1.096145          2462153674752

Because masked pruning keeps every tensor's shape, the parameter count and MACs stay constant across pruning ratios. Actually shrinking the model requires structural changes; below, the embedding width is halved (768 => 384) throughout the backbone, shown as before => after:
ASTModel(
  (v): DistilledVisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10)) => (proj): Conv2d(1, 384, kernel_size=(16, 16), stride=(10, 10))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True) => (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True) => (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True) => (proj): Linear(in_features=384, out_features=384, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True) => (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True) => (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=3072, out_features=768, bias=True) => (fc2): Linear(in_features=1536, out_features=384, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True) => (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
    (pre_logits): Identity()
    (head): Linear(in_features=768, out_features=1000, bias=True)
    (head_dist): Linear(in_features=768, out_features=1000, bias=True)
  )
  (mlp_head): Sequential(
    (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True) => (0): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=768, out_features=5, bias=True) => (1): Linear(in_features=384, out_features=5, bias=True)
  )
)

MACs: 58.0542 G => 16.2705 G
Params: 87.2606 M => 23.1657 M

Halving the embedding width from 768 to 384 roughly quarters the weights and compute of each Linear layer ((768/384)^2 = 4), which is consistent with the ~3.6x reduction in MACs and ~3.8x reduction in parameters shown here.