6Batch Norm (手动实现)

2018-10-22  本文已影响0人  Rooooooooong

1Batch Norm能做什么—

(1)解决covariate shift 问题

我们知道在中间层,每各神经元的输出又是下一次的输入。既然我们为了加速梯度下降算法,要对原始数据做归一化。那么有什么理由不对中间层的输出数据做归一化?这种对中间层的输出数据做归一化的过程就叫做Batch Norm。进一步来说,如果不做Batch Norm,那么输出的数据会服从某种分布,下一层的网络会学习到这种分布。以此类推,蝴蝶效应出现了。为了防止这种情况,我们做Batch Norm,这样每一层网络学到的都是既定的分布。
那么问题来了,这个既定的分布是多少呢?测试集和训练集的分布有区别吗?
就训练集而言,数据服从N(beta,gamma^2)
测试集则有:
N(running\_mean,running\_var)
running_mean是根据指数平滑法将mini-batch不同的batch的均值做的平均。

(2)起到略微的正则化效果

2Batch Norm 实现思路

def Batchnorm_simple_for_train(x,gamma,beta,bn_param):
    running_mean = bn_param['running_mean'] 
    running_var = bn_parma['running_var']
    results = 0.
    
    x_mean = x.mean(axis=0)
    x_var = x.var(axis=0)
    x_normalized = (x - x_mean)/np.sqrt(x_var + eps)
    results = gamma * x_normalized + beta
    
    running_mean = momentum * running_mean + (1 - momentum) * x_mean
    running_var = momentum * running_var + (1-momentum) * x_var
    
    bn_param['running_mean'] = running_mean
    bn_param['running_var'] = running_var
    
    return results , bn_param
    
def Batchnorm_simple_for_test(x,gamma,beta,bn_param):
    running_mean = bn_param['running_mean']
    running_var = bn_param['running_var']
    results = 0.
    x_normalized = (x - running_mean)/(np.sart(running_var))
    results = gamma * x_normalized + beta
    
    return results , bn_param
上一篇 下一篇

猜你喜欢

热点阅读