算法小白菜AITensorFlow

深入理解FTRL

2019-02-15  本文已影响78人  HorningFeng

FTRL算法是吸取了FOBOS算法和RDA算法的两者优点形成的Online Learning算法。读懂这篇文章,你需要理解LR、SGD、L1正则。

FOBOS算法

前向后向切分(FOBOS,Forward Backward Splitting)是 John Duchi 和 Yoran Singer 提出的。在该算法中,权重的更新分成两个步骤,其中t是迭代次数,\eta^t是当前迭代的学习率,G^t是loss func的梯度,\Psi(W)是正则项,如下:

W^{t+0.5}=W^t-\eta^tG^t
W^{t+1}=argmin_W\{ \frac{1}{2} \Vert W-W^{t+0.5} \Vert^2_2 + \eta^{t+0.5}\Psi(W) \}

权重更新的另外一种形式:
对上式argmin部分求导,令导数等于0可得:
W^{t+1}=W^t-\eta^tG^t-\eta^{t+0.5}\partial\Psi(W^{t+1})
这就是权重更新的另外一种形式,可以看到W^{t+1}的更新不仅与W^{t}有关,还与自己本身有关,有人猜测这就是前向后向的来源。

L1-FOBOS,正则项为L1范数,其中\lambda>0
W^{t+0.5}=W^t-\eta^tG^t
W^{t+1}=argmin_W\{ \frac{1}{2} \Vert W-W^{t+0.5} \Vert^2_2 + \eta^{t+0.5}\lambda\Vert W \Vert_1 \}

合并为一步:
\eta^{t+0.5}=\eta^t,将二次项乘开,消去常数项得
W^{t+1}=argmin_W\{ G^t W + \frac{1}{2\eta^t}\Vert W-W^t\Vert^2_2 + \lambda \Vert W \Vert_1\}

闭式解:
w^{t+1}_i= \begin{cases} 0, & \mathcal{if\ \vert w^t_i-\eta^t g^t_i\vert\leq\eta^{t+0.5}\lambda}\\ (w^t_i-\eta^t g^t_i)-\eta^{t+0.5}\lambda\cdot sgn(w^t_i-\eta^t g^t_i), & \mathcal{otherwise} \end{cases}
推导过程略,思路同下方FTRL闭式解的推导过程。


RDA算法

RDA(Regularized Dual Averaging Algorithm)叫做正则对偶平均算法,特征权重的更新策略如下,只有一步,其中
累积梯度G^{(1:t)}=\sum_{s=1}^t G^s
累积梯度平均值g^{(1:t)}=\frac1t\sum_{s=1}^t G^s=\frac{G^{(1:t)}}{t}
\Psi(W)是正则项,h(W)是一个严格的凸函数,\beta^{(t)}是一个关于t的非负递增序列:

W^{t+1}=argmin_W\{ g^{(1:t)}W + \Psi(W) + \frac{\beta^{(t)}}{t}h(W) \}

L1-RDA:
\Psi(W)=\lambda\Vert W \Vert_1,令h(W)=\frac{1}{2}\Vert W \Vert^2_2,令\beta^{(t)}=\gamma\sqrt{t},其中\lambda>0\gamma>0,并且各项同时乘以t,得:
W^{t+1}=argmin_W\{ g^{(1:t)}W + \lambda \Vert W \Vert_1 + \frac{\gamma}{2\sqrt{t}}\Vert W \Vert^2_2\}


y=\vert x \vert
x=0, \partial y \in (-1,1)


闭式解:
w^{t+1}_i= \begin{cases} 0, & \mathcal{if\ \vert g^{(1:t)}_i\vert<\lambda}\\ -\frac{\sqrt t}{\gamma}(g^{(1:t)}-\lambda sgn(g^{(1:t)})), & \mathcal{otherwise} \end{cases}
推导过程略,思路同下方FTRL闭式解的推导过程。


FTRL算法

FTRL 算法综合考虑了 FOBOS 和 RDA 对于梯度和正则项的优势和不足,其中累积梯度G^{(1:t)}=\sum_{r=1}^t G^r\sigma^s=\frac{1}{\eta^s}-\frac{1}{\eta^{s-1}}\sigma^{(1:t)}=\frac{1}{\eta_t}=\sum_{s=1}^t \sigma^s\lambda_1>0\lambda_2>0,特征权重的更新公式是:
W^{t+1}=argmin_w\{ G^{(1:t)}W + \lambda_1 \Vert W \Vert_1 + \frac{\lambda_2}{2}\Vert W \Vert^2_2 + \frac{1}{2}\sum_{s=1}^t\sigma^s\Vert W-W^s\Vert^2_2 \}
维度i的学习率设置为\eta^t_i=\frac{\alpha}{\beta+\sqrt{\sum_{s=1}^t (g^{(s)})^2}},随着迭代次数增加而减小

使用\sigma替换学习率可将L1-FOBOS、L1-RDA、FTRL写成类似的形式,如下:
W^{t+1}_{(L1-FOBOS)}=argmin_W\{ G^t W + \lambda \Vert W \Vert_1 + \frac{1}{2}\sigma^{(1:t)}\Vert W-W^t\Vert^2_2 \}
W^{t+1}_{(L1-RDA)}=argmin_W\{ G^{(1:t)}W + t\lambda \Vert W \Vert_1 + \frac{1}{2}\sigma^{(1:t)}\Vert W-0 \Vert^2_2\}
W^{t+1}_{(FTRL)}=argmin_W\{ G^{(1:t)}W + \lambda_1 \Vert W \Vert_1 + \frac{\lambda_2}{2}\Vert W \Vert^2_2 + \frac{1}{2}\sum_{s=1}^t\sigma^s\Vert W-W^s\Vert^2_2 \}
各项解释todo

闭式解及其推导过程
将二次项乘开,消去常数项,得:
W^{t+1}=argmin_W\{ (G^{(1:t)}-\sum_{s=1}^t\sigma^sW^s)W + \lambda_1 \Vert W \Vert_1 + \frac{1}{2}(\lambda_2+\sum_{s=1}^t\sigma^s)\Vert W \Vert^2_2\}
Z^t=G^{(1:t)}-\sum_{s=1}^t\sigma^sW^s,则Z^t=Z^{t-1}+G^t-\sigma^t W^t,得:
W^{t+1}=argmin_W\{Z^tW+\lambda_1\Vert W \Vert_1 + \frac{1}{2}(\lambda_2+\sum_{s=1}^t\sigma^s)\Vert W \Vert^2_2\}
对于单个维度i来说:
w^{t+1}_i=argmin_w\{z^t_iw_i+\lambda_1\vert w_i \vert + \frac{1}{2}(\lambda_2+\sum_{s=1}^t\sigma^s)w^2_i\}
对上式,假设w^*_i是最优解,令上式导数等于0可得:
z^t_i+\lambda_1sgn(w^*_i)+(\lambda_2+\sum_{s=1}^t\sigma^s)w^*_i=0
我们分三种情况进行讨论

  1. \vert z^t_i\vert\leq\lambda_1时:
    1. w^*_i=0时,满足sgn(0) \in (-1,1),成立
    2. w^*_i>0时,z^t_i+\lambda_1sgn(w^*_i)=z^t_i+\lambda_1\geq0(\lambda_2+\sum_{s=1}^t\sigma^s)w^*_i>0上式不成立
    3. w^*_i<0时,z^t_i+\lambda_1sgn(w^*_i)=z^t_i-\lambda_1\leq0(\lambda_2+\sum_{s=1}^t\sigma^s)w^*_i<0上式不成立
  2. z^t_i>\lambda_1时:
    1. w^*_i=0时,不满足sgn(0) \in (-1,1),不成立
    2. w^*_i>0时,z^t_i+\lambda_1sgn(w^*_i)=z^t_i+\lambda_1>0(\lambda_2+\sum_{s=1}^t\sigma^s)w^*_i>0,上式不成立
    3. w^*_i<0时,z^t_i+\lambda_1sgn(w^*_i)=z^t_i-\lambda_1>0(\lambda_2+\sum_{s=1}^t\sigma^s)w^*_i<0w^*_t有解,w^*_t=-(\frac{\beta+\sqrt{\sum_{s=1}^t (g^{(s)})^2}}{\alpha}+\lambda_2)^{-1}(z^t_i-\lambda_1)
  3. z^t_i<-\lambda_1时:
    1. w^*_i=0时,不满足sgn(0) \in (-1,1),不成立
    2. w^*_i>0时,z^t_i+\lambda_1sgn(w^*_i)=z^t_i+\lambda_1<0(\lambda_2+\sum_{s=1}^t\sigma^s)w^*_i>0w^*_t有解,w^*_t=-(\frac{\beta+\sqrt{\sum_{s=1}^t (g^{(s)})^2}}{\alpha}+\lambda_2)^{-1}(z^t_i+\lambda_1)
    3. w^*_i<0时,z^t_i+\lambda_1sgn(w^*_i)=z^t_i-\lambda_1<0(\lambda_2+\sum_{s=1}^t\sigma^s)w^*_i<0,上式不成立

综上,可得分段函数形式的闭式解:
w^{t+1}_i= \begin{cases} 0, & \mathcal{if\ \vert z^t_i\vert<\lambda_1}\\ -(\frac{\beta+\sqrt{\sum_{s=1}^t (g^{(s)})^2}}{\alpha}+\lambda_2)^{-1}(z^t_i-sgn(z^t_i)\lambda_1), & \mathcal{otherwise} \end{cases}

论文内的伪代码

FTRL工程实现上的trick

如果不理解,回去仔细研究LR的公式。

[1] McMahan, H. Brendan, et al. "Ad click prediction: a view from the trenches." Proceedings of the 19th ACM SIGKDD international conference on Knowledge discovery and data mining. ACM, 2013.
[2] 张戎 FOLLOW THE REGULARIZED LEADER (FTRL) 算法总结 https://zhuanlan.zhihu.com/p/32903540

上一篇 下一篇

猜你喜欢

热点阅读