PyTorch提取中间层特征

2019-10-04  本文已影响0人  顾北向南

https://mp.weixin.qq.com/s/U80uqeP-_nRJTjJZ3MfQ4g
本文仅作学术分享之用，如有侵权，将删文处理。

作者：涩醉
https://www.zhihu.com/question/68384370/answer/751212803

import torch
from torchvision.models import resnet18
import torch.nn as nn
from torchvision import transforms

import matplotlib.pyplot as plt


def viz(module, input, max_maps=4):
    """Forward pre-hook that plots the first few channels fed into ``module``.

    Intended for ``register_forward_pre_hook``, so ``input`` is the tuple of
    positional arguments passed to the module's forward call. Up to
    ``max_maps`` channels of the first sample are drawn side by side in one
    matplotlib figure, which is then shown.

    Args:
        module: The module the hook is attached to (unused).
        input: Tuple of forward inputs. ``input[0]`` is indexed as
            ``[sample][channel]``, so it is assumed to be a batched
            (N, C, H, W) tensor (TODO: confirm for non-conv modules).
        max_maps: Maximum number of channels to draw. The default of 4
            matches the original hard-coded limit.

    Returns:
        None.
    """
    # Detach and move to CPU: matplotlib cannot convert CUDA tensors or
    # tensors that require grad into arrays.
    x = input[0][0].detach().cpu()
    # Draw at most ``max_maps`` channels, fewer if the tensor has fewer.
    num_maps = min(max_maps, x.size(0))
    for i in range(num_maps):
        plt.subplot(1, max_maps, i + 1)
        plt.imshow(x[i].numpy())
    plt.show()


import cv2
import numpy as np
def main(image_path='/Users/edgar/Desktop/cat.jpeg'):
    """Run a pretrained ResNet-18 on one image and plot each conv layer's input.

    Every ``nn.Conv2d`` in the model gets ``viz`` as a forward pre-hook. A
    single forward pass therefore shows one figure per convolution, each
    displaying the feature maps that layer receives.

    Args:
        image_path: Path of the image to load with OpenCV. The default is the
            original author's local file, so pass your own path.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    t = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = resnet18(pretrained=True).to(device)
    # Inference mode: BatchNorm must use its stored running statistics rather
    # than the statistics of this one-image batch, and must not update them.
    model.eval()
    for m in model.modules():
        # Only convolution layers are visualized. Loosen this check (e.g. skip
        # only ModuleList/Sequential containers) to inspect other layers too.
        if isinstance(m, torch.nn.Conv2d):
            m.register_forward_pre_hook(viz)

    img = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        raise FileNotFoundError(f"could not read image: {image_path}")
    # OpenCV loads channels in BGR order; the pretrained torchvision weights
    # expect RGB input.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = t(img).unsqueeze(0).to(device)
    with torch.no_grad():
        model(img)


if __name__ == '__main__':
    main()

作者：袁坤
https://www.zhihu.com/question/68384370/answer/419741762

# Values captured by the hooks, keyed by the ``name`` passed to make_hook.
inter_feature = {}
inter_gradient = {}


def make_hook(name, flag):
    """Build a hook that records a module's forward input or backward gradient.

    Args:
        name: Key under which the captured value is stored.
        flag: ``'forward'`` returns a forward hook that stores the module's
            input tuple in ``inter_feature[name]``. ``'backward'`` returns a
            backward hook that stores its ``grad_output`` argument in
            ``inter_gradient[name]``.

    Returns:
        A hook callable for ``register_forward_hook`` (forward) or a
        backward-hook registration method (backward).

    Raises:
        ValueError: If ``flag`` is neither ``'forward'`` nor ``'backward'``.
    """
    if flag == 'forward':
        def hook(module, inputs, output):
            # Stores what the module *received*, not what it produced.
            inter_feature[name] = inputs
        return hook
    if flag == 'backward':
        # PyTorch calls backward hooks as hook(module, grad_input, grad_output).
        def hook(module, grad_input, grad_output):
            inter_gradient[name] = grad_output
        return hook
    # An explicit exception survives ``python -O``, unlike ``assert False``.
    raise ValueError(f"flag must be 'forward' or 'backward', got {flag!r}")
# Usage sketch: attach both hooks to one module `m`, stored under key `name`.
# NOTE(review): `m`, `name`, `model`, `input`, `criterion` and `target` are not
# defined in this snippet. They are assumed to come from surrounding code
# (e.g. a loop over model.named_modules()); confirm before running.
# NOTE(review): register_backward_hook is deprecated in recent PyTorch releases
# in favor of register_full_backward_hook; confirm against the installed version.
m.register_forward_hook(make_hook(name, 'forward'))
m.register_backward_hook(make_hook(name, 'backward'))
output = model(input)  # forward pass: the forward hook fills inter_feature[name]
loss = criterion(output, target)
loss.backward()  # backward pass: the backward hook fills inter_gradient[name]

作者：罗一成
https://www.zhihu.com/question/68384370/answer/263120790

上一篇 下一篇

猜你喜欢

热点阅读