Ada_grad自适应梯度下降

2023-03-13  本文已影响0人  Co酱_秋乏术

def sgd_adagrad(parameters, sqrs, lr):

eps = 1e-10

for param, sqr in zip(parameters, sqrs):

  sqr[:] = sqr + param.grad.data ** 2

  div = lr / torch.sqrt(sqr + eps) * param.grad.data

  param.data = param.data - div

上一篇下一篇

猜你喜欢

热点阅读