梯度下降

2017-05-07  本文已影响0人  greymonster

最近在做关于数据的一些东西,就是研究游戏在线人数的变换曲线,看了梯度下降。
  梯度下降的介绍百度一下说的都很好,我也是百度学习的 可以@refer 这篇 http://blog.csdn.net/woxincd/article/details/7040944
  我的理解就是先随意找一个点,然后求出这个点在各个方向的切线,顺着这个方向认为是下降最快的方向,然后根据步长(自己设定的一个值,这里是0.1)调节走的速度。最终找到极值。关键在于迭代,关于迭代,牛顿方法的迭代会更迅速。

下面是我用php写的一个小demo。

预期多元一次方程是 ax1 + bx2 = y , 给定了N组 x 和对应y的值 ,求 a b 分别是多少。
先转换成求极值 => 误差最小。

$dataset = [[1,4],[2,5],[5,1],[4,2]]; //初始的三组x(每组x包括x1 x2)
$dataret = [19,26, 19, 20]; // 对应三组x 的y 值
$expect  = [10, 10]; //随意找到的开始点 这里是指 预测 a = 10 b = 10
$step  = 0.001; //步长
$times = 1000000;
/*
 *梯度下降 * 
 * @auther menmei
 * @date 2017/03/24
 *
 */
/*
 * 梯度下降求多元一次方程的多元参数
 *
 *
 * @param 原始数据 Array
 * @param 原始数据结果 Array
 * @param 初始参数 Arrayθ
 * @param 步长 double
 * @param 循环次数 int
 *
 * @return 参数数组 Arrayθ
 *
 */
function gradientDescent($dataset, $dataret, $expect, $step, $times){
    //check given params
    $setTotal    = count($dataset);
    $paramsTotal = count($dataset[0]);
    if($setTotal < 2 || (count($expect) != $paramsTotal) || count($dataret) != $setTotal )  return False;
    //$deviation = array_fill(0, $paramsTotal, 0);
    for($i = 0; $i < $times; $i ++){
        $h = 0;
        $index = $i % $setTotal;
        for($j = 0; $j < $paramsTotal; $j ++){
            $h += ($expect[$j] * $dataset[$index][$j]);
        }

        $error = $h - $dataret[$index];
        for($k = 0; $k < $paramsTotal; $k ++){
            //这里是关键 这里 $error * $dataset[$index][$k] 是J(θ) 按梯度方向减少的量
            //是对J(θ) 求偏导得到的 => 按梯度每个方向的斜率 *  步长
            $expect[$k] -= $step * $error * $dataset[$index][$k];
        }

        //calculate new deviation
        $deviation = 0 ;
        for($l = 0; $l < $setTotal; $l ++){
            $h = 0;
            for($m = 0; $m < $paramsTotal; $m ++){
                $h += ($expect[$m] * $dataset[$l][$m]);
            }
            $deviation += ($h - $dataret[$l]) * ($h - $dataret[$l]);
        }
        if($deviation < 0.001) break;
    }
    echo "误差是{$deviation}";
    return $expect;
}

/***************** EXAMPLE ******************/
//sample 1
//这里步长设置成 0.1 就会越过最低点 然后继续向上。所以步长选择很 !重 !要 !
$dataset = [[1,4],[2,5],[5,1],[4,2]];
$dataret = [19,26, 19, 20];
$expect  = [10, 10];
$step  = 0.001;
$times = 1000000;

//sample 2
$dataset = [[1, 1, 2], [1, 2, 3], [1, 2, 5], [1, 8, 3], [1, 4, 7]];
$dataret = [13, 19, 27, 31, 39];
$expect = [0, 0, 0];

//sample 3
$dataset = [[1, 2], [2, 3], [2, 5], [8, 3], [4, 7]];
$dataret = [1, 4, 0, 34, 6];
$expect = [0,0];
$ret = gradientDescent($dataset, $dataret, $expect, $step, $times);
var_dump($ret);
上一篇下一篇

猜你喜欢

热点阅读