Freeze BN in PyTorch

2021-03-08  Birdy潇
def set_bn_eval(m):
    # Put every BatchNorm layer (BatchNorm1d/2d/3d, SyncBatchNorm, ...) into eval mode
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()
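The name check matches any class whose name contains "BatchNorm", including user-defined classes that only share the name. A minimal sketch of a type-based alternative, assuming only the public nn.BatchNorm1d/2d/3d and nn.SyncBatchNorm classes need to be matched. BN_TYPES and set_bn_eval_by_type are hypothetical names, not part of the original post:

```python
import torch.nn as nn

# Hypothetical tuple of the public BatchNorm layer types to match.
BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)

def set_bn_eval_by_type(m: nn.Module) -> None:
    # Same effect as set_bn_eval, but matches by type instead of by class name.
    if isinstance(m, BN_TYPES):
        m.eval()

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model.train()                     # every submodule starts in training mode
model.apply(set_bn_eval_by_type)  # only the BatchNorm2d layer is switched to eval mode
print([type(m).__name__ for m in model if not m.training])  # ['BatchNorm2d']
```

The isinstance check is stricter; the string check is more permissive and also catches third-party BN variants whose names contain "BatchNorm".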

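Use model.apply() to freeze BN. apply() calls set_bn_eval on every submodule recursively. Call it after model.train(), because model.train() puts every layer, BN included, back into training mode.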

def train(model, data_loader, criterion, epoch):
    model.train()             # switch to train mode (this also sets every BN layer to training mode)
    model.apply(set_bn_eval)  # freeze BN: normalize with running_mean/running_var and stop updating them
    ###
    # training code
    ###
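Note that m.eval() only changes how BN computes and tracks statistics: batch statistics are replaced by running_mean/running_var, and those buffers stop updating. The affine parameters (weight and bias) are ordinary parameters and still receive gradients. A small check of both effects, on an assumed toy model, plus a hypothetical freeze_bn_params helper (not from the original post) for the case where weight and bias should also stay fixed:

```python
import torch
import torch.nn as nn

def set_bn_eval(m):
    # Same helper as in the post: put every BatchNorm* module into eval mode.
    if m.__class__.__name__.find('BatchNorm') != -1:
        m.eval()

def freeze_bn_params(m):
    # Hypothetical extra step: also stop gradients to BN's affine weight and bias.
    if m.__class__.__name__.find('BatchNorm') != -1:
        for p in m.parameters():
            p.requires_grad_(False)

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 1))
bn = model[1]

model.train()
model.apply(set_bn_eval)

mean_before = bn.running_mean.clone()
loss = model(torch.randn(16, 4)).sum()
loss.backward()

# The forward pass did not touch the running statistics...
print(torch.equal(bn.running_mean, mean_before))  # True
# ...but weight (gamma) still got a gradient, so an optimizer step would change it.
print(bn.weight.grad is not None)  # True

# Optionally freeze weight and bias as well.
model.apply(freeze_bn_params)
print(bn.weight.requires_grad, bn.bias.requires_grad)  # False False
```

Whether weight and bias should also be frozen depends on the use case; the post's set_bn_eval alone leaves them trainable.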

Wrapping up, the commonly used pattern (train() runs every epoch, so BN is re-frozen each time):

def main():
    # ...
    for epoch in epochs:
        train(model, train_loader, criterion, epoch)
        test(model, eval_loader, epoch)
    # ...