小样本学习综述
1. 小样本学习
人类非常擅长通过极少量的样本识别一个新物体,比如小孩子只需要书中的一些图片就可以认识什么是“斑马”,什么是“犀牛”。在人类的快速学习能力的启发下,研究人员希望机器学习模型在学习了一定类别的大量数据后,对于新的类别,只需要少量的样本就能快速学习,这就是 Few-shot Learning 要解决的问题。
Few-shot Learning 是 Meta Learning 在监督学习领域的应用。Meta Learning,又称为 learning to learn,在 meta training 阶段将数据集分解为不同的 meta task,去学习类别变化的情况下模型的泛化能力,在 meta testing 阶段,面对全新的类别,不需要变动已有的模型,就可以完成分类。
形式化来说,few-shot 的训练集中包含了很多的类别,每个类别中有多个样本。在训练阶段,会在训练集中随机抽取 C 个类别,每个类别 K 个样本(总共 CK 个数据),构建一个 meta-task,作为模型的支撑集(support set)输入;再从这 C 个类中剩余的数据中抽取一批(batch)样本作为模型的预测对象(batch set)。即要求模型从 C*K 个数据中学会如何区分这 C 个类别,这样的任务被称为 C-way K-shot 问题。
训练过程中,每次训练(episode)都会采样得到不同 meta-task,所以总体来看,训练包含了不同的类别组合,这种机制使得模型学会不同 meta-task 中的共性部分,比如如何提取重要特征及比较样本相似等,忘掉 meta-task 中 task 相关部分。通过这种学习机制学到的模型,在面对新的未见过的 meta-task 时,也能较好地进行分类。
图 1 展示的是一个 2-way 5-shot 的示例,可以看到 meta training 阶段构建了一系列 meta-task 来让模型学习如何根据 support set 预测 batch set 中的样本的标签;meta testing 阶段的输入数据的形式与训练阶段一致(2-way 5-shot),但是会在全新的类别上构建 support set 和 batch。
image.png2. 元学习
3. 方法
早期的 Few-shot Learning 算法研究多集中在图像领域,如图 2 所示,Few-shot Learning模型大致可分为三类:Mode Based,Metric Based 和 Optimization Based。
其中 Model Based 方法旨在通过模型结构的设计快速在少量样本上更新参数,直接建立输入 x 和预测值 P 的映射函数;Metric Based 方法通过度量 batch 集中的样本和 support 集中样本的距离,借助最近邻的思想完成分类;Optimization Based 方法认为普通的梯度下降方法难以在 few-shot 场景下拟合,因此通过调整优化方法来完成小样本分类的任务。
3.1 Model Based方法
3.1.1
3.2 Metric Based方法
如果在 Few-shot Learning 的任务中去训练普通的基于 cross-entropy 的神经网络分类器,那么几乎肯定是会过拟合,因为神经网络分类器中有数以万计的参数需要优化。
相反,很多非参数化的方法(最近邻、K-近邻、Kmeans)是不需要优化参数的,因此可以在 meta-learning 的框架下构造一种可以端到端训练的 few-shot 分类器。该方法是对样本间距离分布进行建模,使得同类样本靠近,异类样本远离。下面介绍相关的方法。
3.2.1 孪生网络(Siamese Network)
如图 4 所示,孪生网络(Siamese Network)[4]通过有监督的方式训练孪生网络来学习,然后重用网络所提取的特征进行 one/few-shot 学习。
3.2.2 匹配网络(Match Network)
相比孪生网络,匹配网络(Match Network)[2]为支撑集和 Batch 集构建不同的编码器,最终分类器的输出是支撑集样本和 query 之间预测值的加权求和。
如图 5 所示,该文章也是在不改变网络模型的前提下能对未知类别生成标签,其主要创新体现在建模过程和训练过程上。对于建模过程的创新,文章提出了基于 memory 和 attention 的 matching nets,使得可以快速学习。
对于训练过程的创新,文章基于传统机器学习的一个原则,即训练和测试是要在同样条件下进行的,提出在训练的时候不断地让网络只看每一类的少量样本,这将和测试的过程是一致的。
3.2.3 原型网络(Prototype Network)
原型网络(Prototype Network)[5]基于这样的想法:每个类别都存在一个原型表达,该类的原型是 support set 在 embedding 空间中的均值。然后,分类问题变成在 embedding 空间中的最近邻。
如图 6 所示,c1、c2、c3 分别是三个类别的均值中心(称 Prototype),将测试样本 x 进行 embedding 后,与这 3 个中心进行距离计算,从而获得 x 的类别。
3.2.4 Relation Network(关系网络)
前面介绍的几个网络结构在最终的距离度量上都使用了固定的度量方式,如 cosine,欧式距离等,这种模型结构下所有的学习过程都发生在样本的 embedding 阶段。
而 Relation Network[6]认为度量方式也是网络中非常重要的一环,需要对其进行建模,所以该网络不满足单一且固定的距离度量方式,而是训练一个网络来学习(例如 CNN)距离的度量方式,在 loss 方面也有所改变,考虑到 relation network 更多的关注 relation score,更像一种回归,而非 0/1 分类,所以使用了 MSE 取代了 cross-entropy。
3.3 Optimization Based方法
3.3.1
参考资料
[1] 小样本学习(Few-shot Learning)综述
[2] 小样本学习遇上机器学习--------随笔记录
[3] Meta-Learning: Learning to Learn Fast
[4] 从 ICLR 2019 一览小样本学习最新进展!