简书元学习,谐韵五十行
翻译自Meta-Learning in 50 Lines of JAX
作者:Eric Jang
译者:尹肖贻
本文全部资料,位于Github地址:https://github.com/ericjang/maml-jax
人和动物常表现出适应环境的行为。从时间长短而论,大致分为两种。短时间的适应,如使用淋浴阀,我得摆弄一会儿,才能调出舒服的温度;如乍读新领域的文章,我慢慢学着从中汲取信息。长时间的适应,如学习某种乐器的演奏,需要耗费一生,刻意练习。
这样的学习行为,不仅见于动物,还见于各类生物【译者按:甚至非生物】。培养多细胞的项目中,常可观察到高度灵活的适应行为,甚至在代际间积累了表观遗传学意义上的“记忆”。在较长的时间尺度上,进化本身可以被认为是种群级别的“学习”,即优势基因的代际传承;在较短的时间尺度上,粒子的能级跃迁也可被认为是激励下的“学习”,即在为了适应环境而做出反应。生物学家常故意模糊“行为”(对环境的反应)、“学习”(从外界获取信息以提高适应度)和“优化”(提高适应度)的界限。
机器学习(ML)的核心问题,在于教会计算机获得自主使用数据的能力,去完成人类难以明确说明的任务。然而大多数机器学习专家所谓的“学习”,只是生物适应环境的行为里的很小的子集。深度学习模型很强大,但常须车载斗量的数据、擢发难数的梯度反传迭代。虽然学习过程旷日持久,但是模型的行为能力呆板。在开发周期中,若想改变系统的输出(如更改一个错误),非得使用昂贵的重启手续不可。【划重点】能不能设计一个训练更快、训练数据更少的系统咧?
“元学习”正是解决这个问题的热点课题。该课题的目的在于,不仅仅要模型“预测得好”,还要“学习得好”。虽然元学习在最近几年吸引了大量研究者,相关的问题和算法早已有之。(对于相关概念感兴趣的读者,请参看Hugo Larochelle的PPT和Lilian Weng的博文,二者对此有精彩的梳理)
本文不是介绍元学习系统方方面面的综述,而是一份开启你元学习研究的实用教程。特别是我要教你利用谷歌卓越的JAX库,用50行python代码,搭建元学习算法系统MAML。
读者可以自行下载Jupyter notebook版本的内容自洽的教程,复现本教程的结果。
从学习算子(learning operator)的角度,理解元学习
“元学习”这个词被研究人员严重滥用,以至于我在跟其他同行提到“元学习”的时候,很难在交流中保持统一。这一滥用现象的源头,在于一些术语定义含混,如“优化”、“学习”、“适应”、“记忆”,更不要说这些术语在应用场景下的泛滥使用。
这一节我试着用数学的角度定义“学习”和“元学习”,并解释一下为什么最近一大票不同的算法都打上“元学习”的牌子。要是你想要直接上手MAML+JAX代码的学习,请直接跳过这一节。
我们定义学习算子为针对某种场景的函数的算法,用以提升的表现效果。普通的学习算子,一般应用在深度学习和增强学习中,定义为针对某个损失函数的梯度下降算法。在典型的深度学习场景中,学习步骤往往持续几千乃至几百万次梯度更新;但在更一般的场景中,“学习”既可以发生在较短的时间尺度上(如求条件概率),也可以发生在更长的时间尺度上(如超参数的搜索)。除了显式的优化过程,学习还可以隐式地出现在动态系统中(如自回归神经网络RNN学习当下的条件概率)或概率推断中。
元学习算子(meta-learning operator)定义为两个学习算子嵌套的组合算子:“内环”和“外环”。进一步来说,是模型本身,为内环学习的算子。或者这么说,学习的学习规则,学习具体任务的规则。我们定义“任务”为这样一族逻辑自洽的问题,它们的可以充分更新。在元学习训练中,在的参数中选择对于一众任务最合适的一组;在元学习测试中,我们评估和的泛化能力是否支持不同的任务。
对于和的选择要依具体问题而定。在架构搜索(architecture search)的语境下(也称“学习如何学习”),的网络从零开始训练的过程相对缓慢,这时可以做神经控制器(neural controler),或者叫随机搜索算法,或者高斯过程搜索(Gaussian Process Bandit)。
可被称为元学习算子的机器学习问题有很多。在模仿式(元)学习(meta imitation learning)(或称为条件概率的目标强化学习(goal-conditioned reinforcement learning))中,的选择依据强化学习的代理人(agent)的操作反馈,比如针对在某项任务的数学表示(task embedding)的条件下、或在某种人为设定的场景(human demonstration)条件下的概率反馈。在元强化学习(meta reinforcement learning MRL)中,实现的手段是“快速强化学习(fast reinforcement learning)”算法,在其中代理人通过数次试错来优化自身策略。这里值得重申,强化学习的场景下,“学习(learning)”与“条件概率优化(conditioning)”没有区别,因为二者都要依赖测试时的输入(或称“环境提供的新信息”)。
MAML是一类通过随机梯度下降实现的元学习算法。形式化为:。随机梯度下降更新对于来说是可导的,这样就不必在优化的梯度反传时,使用额外的参数表示。
探索JAX:梯度
我们以JAX的即来即用的numpy库和梯度算符grad开始这个教程吧。
import jax.numpy as np
from jax import grad
梯度算符grad将一个python函数转化为另一个可求梯度的函数。这里,我们演示如何计算和的一阶导、二阶导、三阶导。
f = lambda x : np.exp(x)
g = lambda x : np.square(x)
print(grad(f)(1.)) # = e^{1}
print(grad(grad(f))(1.))
print(grad(grad(grad(f)))(1.))
print(grad(g)(2.)) # 2x=4
print(grad(grad(g))(2.)) # x=2
print(grad(grad(grad(g)))(2.)) #x=0
探索JAX:自向量化函数vmap
现在我们考虑一个简单的回归案例,我们去拟合函数,目的是熟悉怎样定义和训练网络。JAX设置了一些轻量级的函数,方便搭建简单的网络
from jax imort vmap # for auto-vectorizing functions
from functools import partial # for use with vmap
from jax import jit # for compiling fuctions for speedup
from jax.experimental import stax # neural network library
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax # neural network layers
import matplotlib.pyplot as plt # visualization
我们将定义一个两层隐含层的神经网络,定义 in_shape为(-1,1),意思是可变的batchsize,而特征的维度是1(因为这是一个一维的回归问题)。JAX的工具箱提供的API全部是泛函形式的(这与TensorFlow不同,后者保持了图结构),所以我们返回了一个初始化参数的函数和一个前传网络。这些函数都是可调用的numpy序列的元组的列表(lists of tuples of numpy arrays)——一种存储网络参数的简单易用的数据结构。
# 使用stax来初始化或评估网络参数
net_init, net_apply = stax.serial(
Dense(40), Relu(),
Dense(40), Relu(),
Dense(1)
)
in_shape = (-1, 1, )
out_shape, net_params = net_init(in_shape)
然后,我们定义模型在一个batch的输入数据的平均平方误差(Mean-square Error MSE)损失。
def loss(params, inputs, targets):
# 计算一个batch的平均损失
predictions = net_apply(params, inputs)
return np.mean((targets - predictions)**2)
我们评估未初始化的网络在输入下的结果:
# 将K=100个输入成批推断
xrange_inputs = np.linspace(-5, 5, 100).reshape(100, 1) #(k,1)
targets = np.sin(xrange_inputs)
predictions = vmap(partial(net_apply, net_params))(xrange_inputs)
losses = vmap(partial(loss, net_params))(xrange_inputs, targets) # per-input loss
plt.plot(xrange_inputs, predictions, label='prediction')
plt.plot(xrange_inputs, losses, label='loss')
plt.plot(xrange_inputs, targets, label='targets')
plt.legend()
正如预期的那样,在随机初始化下,模型的预测(蓝线)完全偏离了目标函数(绿线)。
我们用梯度下降算法来更新参数。JAX的随机函数的产生器和numpy的随机函数产生器不同,所以用numpy的产生器(onp)来随机化网络参数。我们要引入tree_multimap函数来管理参数的梯度(对于TensorFlow的用户,这个函数类似于nest.map_stucture的张量的函数)
import numpy as onp
from jax.experimental import optimizers
from jax.tree_util import tree_multimap
# 对numpy的array的集合进行Element-wise级别的操作
我们初始化参数和优化器,将曲线拟合操作循环100次。值得注意,@jit这个修饰器,能将一整个训练函数(和优化器、内存和代码优化一起)都用上XLA编译成机器码。TensorFlow也使用XLA来加速统计类定义的网络。XLA使得计算非常快,和硬件的兼容性很强,因为它不需要返回一个Python解释器(或者在没有XLA时TensorFlow返回计算图解释器)。这里的代码只能运行在CPU、GPU或者TPU上。
opt_init, opt_update = optimizers.adm(step_size=1e-2)
opt_state = opt_init(net_params)
# 定义一个编译的更新步骤
@jit
def step(i, opt_state, x1, y1):
p = optimizers.get_params(opt_state)
g = grad(loss)(p, x1, y1)
return opt_update(i, g, opt_state)
for i in range(100):
opt_state = step(i, opt_state, xrange_inputs, targets)
net_params = optimizers.get_params(opt_state)
重新执行绘图代码
predictions = vmap(partial(net_apply, net_params))(xrange_inputs)
losses = vmap(partial(loss, net_params))(xrange_inputs, targets) # per-input loss
plt.plot(xrange_inputs, predictions, label='prediction')
plt.plot(xrange_inputs, losses, label='loss')
plt.plot(xrange_inputs, targets, label='target')
plt.legend()
在下面MAML的代码里,我们将反复用到上文提到的函数。
探索JAX:用数值检查MAML
在完成机器学习算法的代码时,一定要通过单元测试,测试案例的结果必须可以通过分析的方法得出真值。下面的例子对于toy目标函数做了测试代码。值得注意的是,默认情况下,JAX会对函数的第一个变量计算梯度。
# MAML的梯度检查
# 检查数值
g = lambda x, y: np.square(x) + y
x0 = 2.
y0 = 1.
print('grad(g)(x0) = {}'.format(grad(g)(x0, y0))) # 2x = 4
print('x0 - grad(g)(x0) = {}'.format(x0 - grad(g)(x0, y0))) # x - 2x = -2
def maml_objective(x, y):
return g(x - grad(g)(x, y), y)
# x**2 + 1 = 5
print('maml_objective(x,y)={}'.format(maml_objective(x0, y0)))
# x - (2x) = -2.
print('x0 - maml_objective(x,y) = {}'.format(x0 - grad(maml_objective)(x0, y0)))
用JAX编写MAML
现在我们拓展一下回归正弦曲线的例子,让正弦函数的相位和幅度都可以变化。这个例子是MAML论文里提到的简单示例。下图是从两个不同的任务中采样的点,每一个任务都有训练集(用来计算内部损失)和验证集(用来计算同一任务的外部损失)。
设 MAML
在元学习过程中,网络学着怎样快速地匹配 batch maml
结论
在本教程中,我们研究了MAML算法,并用大约50行Python代码重现了原文中的正弦回归任务。我很高兴地发现,grad、vmap和jit实现MAML非常容易,它们将继续用于我的元学习研究。
那么,“优化”、“学习”、“适应”和“记忆”之间有什么区别呢?我认为它们是等效的,因为使用优化技术(MAML)实现记忆功能是可能的,反之亦然(例如基于RNN的元增强学习)。在强化学习中,模仿教师网络、或根据用户指定的目标进行调节、或从失败中恢复,都可以使用相同的机制。
思考“学习”和“元学习”的精确定义,并尝试将它们与生物智能相对应,使我认识到生命活动的每一个过程,都可归结于不同层面的学习行为:从分子层面的化学反应,到物种层面的遗传进化,行为适应存在于每个时间尺度。在未来我将对人造生命和机器学习的话题做更多阐述,但现在,是时候结束了。感谢您阅读本篇拟合正弦函数的简单教程!
致谢
感谢Matthew Johnson帮助校对本文,并解决了一些有关JAX的问题。