meta learning

Meta Learning——MAML

2019-09-29  本文已影响0人  单调不减

Meta Learning就是元学习,所谓元学习,就是学习如何去学习。这个概念放在机器学习当中就是我们希望找到一个算法可以帮助我们找到一个好的学习算法去解决问题

传统机器学习算法是从一个函数集合中找到一个好的函数来拟合数据,而元学习算法是从一个学习算法集合中找到一个好的学习算法来解决问题。

听上去好像很厉害的样子?元学习居然可以自己找到最适合当前问题的算法?

也别抱太高期望,因为元学习做的事并不是自己构建一个算法,而只是从我们给出的算法集合中挑一个出来。听起来似乎也很厉害了?我们先来看一看我们提供的算法集合会是什么样子:

拿神经网络模型来说,需要人为指定的包括上图中的红框部分——网络架构(决定了函数集合)、权重初始值、更新参数的算法(选用何种梯度下降法以及梯度下降法中学习率变化的算法)。

那么我们现在希望机器自己做到红框中的某些部分,从而进一步减少人为参与。注意,上图中整个灰色框框内的内容都是学习算法F的一部分,若红色框框内的部分都由人来指定,那么这就是一个由人设计好的算法F,喂给F一些data,就可以得到一个不错的模型f,这就是机器学习做的事情;若我们把某个红框内的内容让机器自己去学,那么此时的F就可以视为一个元学习函数了。

我们先看一下元学习算法和机器学习算法在训练上的差别。机器学习算法有训练集和测试集,两者都是同一类型的数据;元学习算法则有一系列的训练Task和测试Task,每个Task又包含了自己的训练集和测试集

因此我们要评判一个元学习算法F的好坏,要看它在训练集中的各个Task上的平均表现,具体来说,当我们评判F在Task 1上的表现时,先让F在Task 1的训练集上进行训练,然后输出一个模型f注意,这里输出的f是训练好的模型,也就是说,F一共完成了两件事,第一件事,找到某个红框里的合适的值(比如参数初始值),第二件事,以学到的参数初始值初始化模型然后进行训练得到最终模型)。接下来检验f在Task 1的测试集上的表现,计算出损失l^1。类似的,我们可以计算出F在各个Task上的损失l^n,然后计算它们的平均值。

定义好F的损失函数,接下来我们只需要找到令损失函数L(F)最小的F即可。

接下来我们可以看一下MAML算法是如何做的:

MAML把参数初始化的任务交给机器去学习,因此这里把损失函数写成L(\phi)的形式,有趣的是,L(\phi)=\sum_{n=1}^{N} l^{n}\left(\hat{\theta}^{n}\right)中的\hat{\theta}^{n}是由\phi经过梯度下降得到的。下图显示了\hat{\theta}^{n}\phi的关系:

而优化L(\phi)的方法依然是梯度下降:

然而还有一个问题亟待解决——\hat{\theta}^{n}是由\phi经过梯度下降得到的,那么迭代几次合适呢?要迭代至收敛吗?

MAML给出的做法是,只做一次!

为什么只进行一次迭代呢?一次迭代得到的\hat{\theta}^n表现不好真的可以说明初始值\phi不好而不是训练不足吗?

这样做的原因有以下几个方面:

首先是快,节约计算成本。

然后是我们这样调整\phi最后说不定真能得到一个迭代一两次就能得到好结果的\phi呢。

然后是当我们真的用这个初始值来训练最终的模型f的时候,还是会迭代多次的,所以不用担心。

最后就是如果我们的Task中的数据很少,那么迭代多次很容易过拟合,所以迭代一次是不错的选择。

总之,只进行一次迭代的梯度下降过程如下,经过推导和近似,可以得到:

也就是说,损失函数对\phi的偏导可以近似看作对\hat{\theta}的偏导:

下图是整个MAML的操作流程示意图,可以看到,整个算法对\phi的更新过程很简单,首先,先由\phi_0初始化的网络进行一步梯度下降得到更新后的参数\hat{\theta}^m,然后求损失函数关于\hat{\theta}^m的梯度方向,并让\phi_0沿此方向移动一个步长得到\phi_1,依次进行。

Meta Learning是机器学习中比较新的一个领域,它的愿景是激动人心的,让机器学会如何学习这件事本身也很酷,但是目前为止我们能做到的还很有限,我们仍然需要把大框架搭建好,然后试着把其中的一部分工作交给机器,而且这部分工作如何做依然要人类指定,从这个角度来说,Learn to Learn的目标还任重道远。

上一篇下一篇

猜你喜欢

热点阅读