PyTorch: Computing the FLOPs of a PyTorch Model with thop

2019-09-29  wzNote

Installing thop

pip install thop

Basic usage

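import torch
from torchvision.models import resnet50
from thop import profile

model = resnet50()
input = torch.randn(1, 3, 224, 224)
# profile() runs one forward pass; the first value it returns counts
# multiply-accumulate operations (MACs), which are often doubled to get FLOPs
flops, params = profile(model, inputs=(input, ))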
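To make the number concrete, the count for a single standard layer can be reproduced by hand. The sketch below uses the usual multiply-accumulate formulas (each output element of a convolution needs (c_in / groups) × k × k multiply-adds; a linear layer needs in_features × out_features per sample). The helper names are hypothetical, and bias additions are ignored here; thop versions may differ slightly in how they treat bias, so treat this as a sanity check rather than thop's exact code.

```python
def conv2d_output_size(size, kernel, stride=1, padding=0, dilation=1):
    """Spatial output size of a 2-D convolution along one axis (PyTorch's formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv2d_macs(c_in, c_out, kernel, h_in, w_in, stride=1, padding=0, groups=1):
    """Hypothetical helper: MACs of a square-kernel Conv2d, bias ignored."""
    h_out = conv2d_output_size(h_in, kernel, stride, padding)
    w_out = conv2d_output_size(w_in, kernel, stride, padding)
    # every output element needs (c_in / groups) * k * k multiply-adds
    return c_out * h_out * w_out * (c_in // groups) * kernel * kernel


def linear_macs(in_features, out_features, batch=1):
    """Hypothetical helper: MACs of a Linear layer, bias ignored."""
    return batch * in_features * out_features


# ResNet-50 stem: 7x7 conv, 3 -> 64 channels, stride 2, padding 3, 224x224 input
stem = conv2d_macs(3, 64, 7, 224, 224, stride=2, padding=3)
# ResNet-50 classifier: 2048 -> 1000
fc = linear_macs(2048, 1000)
print(stem, fc)  # 118013952 2048000
```

The stem convolution alone accounts for roughly 0.118G MACs, which gives a feel for the scale of the total that profile() reports for the whole network.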
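For a module type thop does not know, pass your own counting rule through custom_ops:

import torch
import torch.nn as nn

class YourModule(nn.Module):
    ...  # your definition (__init__ and forward)

def count_your_model(module, x, y):
    # your rule here: x is the tuple of forward inputs, y is the output;
    # add the operation count to module.total_ops
    ...

input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ),
                        custom_ops={YourModule: count_your_model})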
clever_format turns the raw counts into readable strings with K/M/G units:

from thop import clever_format
flops, params = clever_format([flops, params], "%.3f")
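To illustrate the kind of unit scaling clever_format performs, here is a hypothetical stand-alone helper, format_count. It is not thop's own code, only a sketch of the same idea: pick the largest unit the number reaches, divide, and apply the format string.

```python
def format_count(n, fmt="%.3f"):
    """Hypothetical helper: scale a raw count to T/G/M/K units."""
    for threshold, suffix in ((1e12, "T"), (1e9, "G"), (1e6, "M"), (1e3, "K")):
        if n >= threshold:
            return fmt % (n / threshold) + suffix
    # counts below one thousand are printed without a unit
    return fmt % n


print(format_count(1_234_567_890))  # 1.235G
print(format_count(25_557_032))     # 25.557M
```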

Reference: https://github.com/Lyken17/pytorch-OpCounter
