pytorch 拆分已定义的网络结构(slicing netwo

2019-03-24  本文已影响0人  Zeke_Wang

在pytorch中,常会load已有模型甚至pretrained的模型,用其中几层作为特征提取(feature extraction)。比如用pytorch内置的pretrained ResNet作为特征提取器,需要把fully connected layer去掉。可以用children()方法提出需要的层

import torch.nn as nn
from torchvision import models

model = models.resnet50(pretrained=True)
truncated_model = nn.Sequential(*list(model.children())[:8])
print(truncated_model)

truncted_model可作为feature extractor,需要注意输入输出大小即可。
PS: *list可以达到以下效果

l = ["./foo", "bar", "quux"]

funcXXX(*l)
# 等价于
funcXXX("./foo", "bar", "quux")

也即是,iterate 提取list中的内容,并以逗号分隔。满足nn.Sequential()的输入条件

上一篇 下一篇

猜你喜欢

热点阅读