JAX快速入门
2022-09-17 本文已影响0人
辘轳鹿鹿
JAX 由 Autograd（自动微分）和 XLA（Accelerated Linear Algebra，加速线性代数编译器）两部分组成。
- 示例：用梯度下降做函数优化——拟合线性函数 y = a·x + b（可看作不带激活函数的单层感知机）
import numpy as np
def func(x, a, b):
    """Evaluate the linear model ``x * a + b``.

    Args:
        x: Input value(s); works element-wise when given an array.
        a: Slope.
        b: Intercept.

    Returns:
        The model prediction for ``x``.
    """
    scaled = x * a
    return scaled + b
def loss_function(weights, x, y):
    """Per-sample squared error of the linear model ``func``.

    Args:
        weights: A pair ``(a, b)`` holding the slope and intercept.
        x: Inputs fed to ``func``.
        y: Target values matching ``x``.

    Returns:
        ``(func(x, a, b) - y) ** 2`` element-wise. For array inputs this is
        not a scalar, so it must be reduced (e.g. averaged) before being
        passed to ``jax.grad``.
    """
    slope, intercept = weights
    residual = func(x, slope, intercept) - y
    return residual ** 2
JAX 在这里的作用是自动求梯度：grad(f) 返回一个新函数，用来计算 f 对其第一个参数的导数。
from jax import grad
def f(x):
    """Toy function x squared, used to demonstrate ``jax.grad``."""
    return pow(x, 2)


# grad(f) builds a new function that evaluates df/dx = 2x.
df = grad(f)
df(3.0)  # returns 6.0
# Random initial parameters for the linear model y = a * x + b.
a = np.random.random()
b = np.random.random()
weights = [a, b]

# Synthetic training data sampled from the target line y = 3x + 4.
x = np.random.random(1000)
y = 3 * x + 4


def loss_func(weights, x, y):
    """Mean squared error of the linear model over the whole dataset.

    ``jax.grad`` only differentiates functions with a scalar output, so the
    per-sample squared errors returned by ``loss_function`` are averaged.
    """
    return loss_function(weights, x, y).mean()


# Differentiates w.r.t. the first argument (weights); since weights is a
# list, the gradient comes back as a list [d_loss/da, d_loss/db].
grad_func = grad(loss_func)
grad_func(weights, x, y)

learning_rate = 0.001
for i in range(100):
    loss = loss_func(weights, x, y)
    if i % 10 == 0:
        print(f"step {i}: loss={float(loss):.6f}")
    da, db = grad_func(weights, x, y)
    a = a - learning_rate * da
    b = b - learning_rate * db
    # Write the updated parameters back; without this every iteration would
    # evaluate the loss and gradient at the same initial point.
    weights = [a, b]