求解根号2

2022-07-05  本文已影响0人  leon0514

梯度下降求解\sqrt{2}

f(t) = (t^2 - Y)^2 , f'(t) = 4t(t^2 - Y)

梯度下降法
从数学的角度看,梯度的方向时函数增长最快的方向,梯度的反方向就是函数减少最快的方向,通过梯度下降法可以找到凸函数的极小值点,对于存在多个极小值点的函数,梯度下降法可能会陷入局部最优,从而得不到全局最优解。
f(t) = (t^2 - 2)^2来求解\sqrt{2}举例来说,如果将t初始化为\frac{x}{2},会逐步逼近函数右边的最小值,如果将t初始化为-\frac{x}{2},则会逼近函数左边的最小值,得到错误的答案。
所以在初始化参数的时候,按照一定的经验可以帮助模型快速收敛,按照正确的方向收敛。通过梯度下降法求\sqrt{2}的过程中得知,初始化的参数和学习率的设定不当还会导致爆炸,权重值为NaN,程序无法继续迭代。

#include <iostream>
using namespace std;

double gd_sqrt(double x)
{
    int iter = 0;
    double lr = .01; // 学习率会影响梯度下降的快慢、能否跳跃出局部最优解
    double t = x / 2; // 初始化对函数运行的时间和正确性都有影响
    double loss = (t*t -x)*(t*t-x);
    while (loss > 1e-5)
    {
        ++iter;
        double delta_t = 2*(t*t - x)*2*t;
        t = t - lr * delta_t;
        loss = (t*t -x)*(t*t - x);
    }
    cout << "梯度下降法迭代了" << iter << "次" << endl;
    return t;  
}

int main()
{
    double x = 2;
    double p = gd_sqrt(x);
    cout << "gd_sqrt("<<x<<") = "<<p<<endl;
}
// 编译 g++ sqrt.cpp
// 输出
// (pytorch) [leon@ubuntu algocpp]$ ./a.out 
// 梯度下降法迭代了38次
// gd_sqrt(2) = 1.41318

牛顿法1求解\sqrt{2}

#include <iostream>
using namespace std;
double N_sqrt1(double x)
{
    int iter = 0;
    double t = x / 2;
    double ans = t*t - x;
    while (abs(ans) > 1e-5)
    {
    ++iter;
    t = t - ans / (2*t);
        ans = t*t - x;    
    }
    cout << "牛顿法1迭代了" << iter << "次" << endl;
    return t;
}

int main()
{
    double x = 2;
    double p =N_sqrt1(x);
    cout << "N_sqrt1("<<x<<") = "<<p<<endl;
}
// 输出
// 牛顿法1迭代了3次
// N_sqrt1(2) = 1.41422

牛顿法2求解\sqrt{2}

#include <iostream>
using namespace std;
double N_sqrt2(double x)
{
    int iter  = 0;
    double  t   = x / 2;
    double ans = (t*t - x)*(t*t - x);
    double dt  = 0;
    double ddt = 0;
    while (abs(ans) > 1e-5)
    {
        ++iter;
        dt = 4*t*(t*t - x);
        ddt = 12*t*t - 4*x;
        t = t - dt / ddt;
        ans = t*t - x;
    }
    cout << "牛顿法2迭代了" << iter << "次" << endl;
    return t;
}
int main()
{
    double x = 2;
    double p = N_sqrt2(x);
    cout << "N_sqrt2("<<x<<") = "<<p<<endl;
}
// 输出
// 牛顿法2迭代了5次
// N_sqrt2(2) = 1.41421
上一篇 下一篇

猜你喜欢

热点阅读