JAX Quick Start

2022-09-17  辘轳鹿鹿

JAX brings together Autograd (automatic differentiation) and XLA (Accelerated Linear Algebra, a compiler for numerical code).
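As a minimal sketch of how the two halves show up in the API (assuming a recent JAX install): jax.grad is the automatic-differentiation side, and jax.jit compiles a function with XLA. The function f below is only an illustrative example, not part of the original post.

```python
import math

import jax
import jax.numpy as jnp


def f(x):
    # illustrative function: f(x) = x * sin(x)
    return x * jnp.sin(x)


df = jax.grad(f)       # automatic differentiation: f'(x) = sin(x) + x*cos(x)
fast_df = jax.jit(df)  # the same derivative, compiled with XLA

print(fast_df(1.0))
print(math.sin(1.0) + 1.0 * math.cos(1.0))  # analytic value for comparison
```

Both lines should print (almost) the same number; the jitted version is traced and compiled on its first call and reused afterwards.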

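A linear model y = a*x + b and its squared-error loss:

import numpy as np
import jax.numpy as jnp

def func(x, a, b):
    # linear model: y = a*x + b
    y = x * a + b
    return y

def loss_function(weights, x, y):
    a, b = weights
    y_hat = func(x, a, b)
    # grad needs a scalar output, so average the squared errors (MSE)
    return jnp.mean((y_hat - y) ** 2)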
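grad only differentiates functions that return a scalar, so the per-sample squared errors have to be reduced to one number before differentiating. A small sketch of the difference; the names per_sample_loss and mean_loss are hypothetical, and the mean is an assumed choice of reduction (a sum would also give a scalar):

```python
import jax
import jax.numpy as jnp


def per_sample_loss(weights, x, y):
    a, b = weights
    return (x * a + b - y) ** 2  # shape (n,): one squared error per sample


def mean_loss(weights, x, y):
    return jnp.mean(per_sample_loss(weights, x, y))  # scalar


x = jnp.array([0.0, 1.0, 2.0])
y = 3 * x + 4
weights = [1.0, 1.0]

try:
    jax.grad(per_sample_loss)(weights, x, y)
    vector_grad_ok = True
except TypeError as err:  # grad rejects non-scalar outputs with a TypeError
    vector_grad_ok = False
    print("TypeError:", err)

print(jax.grad(mean_loss)(weights, x, y))  # gradient w.r.t. [a, b]
```

With these inputs the residuals are -3, -5, -7, so the gradient is [-38/3, -10].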

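Here JAX's job is to supply the gradient: grad(f) takes a Python function and returns a new function that computes its derivative (with respect to the first argument by default).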

from jax import grad

def f(x):
    return x**2

df = grad(f)
df(3.0)  # returns 6.0, since d(x^2)/dx = 2x

# training data for the linear model: y = 3x + 4, with random initial a, b
a = np.random.random()
b = np.random.random()
weights = [a, b]
x = np.random.random(1000)
y = 3 * x + 4
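Because grad differentiates with respect to the first positional argument by default (argnums=0), the loss takes weights first and the data x, y after it. The argnums parameter selects other arguments; a quick sketch on the same linear func:

```python
from jax import grad


def func(x, a, b):
    return x * a + b


dfdx = grad(func)                   # default argnums=0: d/dx (x*a + b) = a
dfdab = grad(func, argnums=(1, 2))  # tuple of (d/da, d/db) = (x, 1)

print(dfdx(2.0, 3.0, 4.0))
print(dfdab(2.0, 3.0, 4.0))
```

Note the inputs are floats: grad does not accept integer arguments for the variables it differentiates.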


Differentiate the loss with respect to weights, its first argument:

grad_func = grad(loss_function)
grad_func(weights, x, y)  # returns [dL/da, dL/db]
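For this loss the gradient can also be derived by hand. Assuming the mean-squared-error form L(a, b) = (1/n) Σ (a*x_i + b - y_i)^2, the partial derivatives are dL/da = (2/n) Σ (a*x_i + b - y_i)*x_i and dL/db = (2/n) Σ (a*x_i + b - y_i). This sketch checks grad against those formulas; the seed and variable names are illustrative:

```python
import jax
import jax.numpy as jnp
import numpy as np


def loss_function(weights, x, y):
    a, b = weights
    return jnp.mean((x * a + b - y) ** 2)


rng = np.random.default_rng(0)
x = rng.random(1000)
y = 3 * x + 4
weights = [float(rng.random()), float(rng.random())]

# gradient from JAX autodiff
da, db = jax.grad(loss_function)(weights, x, y)

# gradient from the hand-derived formulas (computed in float64 NumPy)
a, b = weights
residual = a * x + b - y
da_manual = np.mean(2 * residual * x)
db_manual = np.mean(2 * residual)

print(float(da), da_manual)
print(float(db), db_manual)
```

The two columns agree up to float32 rounding, since JAX computes in float32 by default.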



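Finally, train with plain gradient descent:

learning_rate = 0.001
for i in range(100):
    loss = loss_function(weights, x, y)
    da, db = grad_func(weights, x, y)
    # gradient-descent step; rebuild weights so the next iteration uses the new a, b
    a = a - learning_rate * da
    b = b - learning_rate * db
    weights = [a, b]
    if i % 10 == 0:
        print(i, loss)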
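The loop above evaluates the loss and the gradient in two separate calls. A variant sketch, not the author's code, uses jax.value_and_grad to get both in one call and jax.jit so the whole step is compiled by XLA. The learning rate 0.1, the 1000 steps, the seed, and the name train_step are assumptions chosen so the fit converges on this data:

```python
import jax
import jax.numpy as jnp
import numpy as np


def loss_function(weights, x, y):
    a, b = weights
    y_hat = x * a + b
    return jnp.mean((y_hat - y) ** 2)


@jax.jit
def train_step(weights, x, y, learning_rate):
    # value_and_grad returns the loss and its gradient in one pass
    loss, grads = jax.value_and_grad(loss_function)(weights, x, y)
    # apply the update to every leaf of the weights pytree (here a tuple (a, b))
    new_weights = jax.tree_util.tree_map(
        lambda w, g: w - learning_rate * g, weights, grads
    )
    return new_weights, loss


rng = np.random.default_rng(0)
x = rng.random(1000)
y = 3 * x + 4
weights = (jnp.float32(rng.random()), jnp.float32(rng.random()))

for i in range(1000):
    weights, loss = train_step(weights, x, y, 0.1)

a, b = weights
print(float(a), float(b), float(loss))  # a and b should approach 3 and 4
```

Returning the new weights from the step function, rather than mutating a and b, fits JAX's functional style and keeps train_step a pure function that jit can compile.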