经典回顾:知识蒸馏(2015)整理中

2022-08-18  本文已影响0人  Valar_Morghulis

提高几乎任何机器学习算法性能的一种非常简单的方法是在相同的数据上训练许多不同的模型,然后对它们的预测进行平均。不幸的是,使用整个模型集合进行预测非常麻烦,并且可能计算成本太高,无法部署到大量用户,特别是如果单个模型是大型神经网络。Caruana和他的合作者已经证明,可以将集成中的知识压缩到一个更易于部署的模型中,我们使用不同的压缩技术进一步开发了这种方法。我们在MNIST上获得了一些令人惊讶的结果,我们表明,通过将模型集合中的知识提取到单个模型中,我们可以显著改进大量使用的商业系统的声学模型。我们还引入了一种新型的集成,由一个或多个完整模型和许多专家模型组成,这些模型学习区分完整模型混淆的细粒度类。与混合专家不同,这些专家模型可以快速并行地训练。

1导言

许多昆虫的幼虫形态优化为从环境中提取能量和营养,而成虫形态则完全不同,优化为满足不同的旅行和繁殖要求。在大规模机器学习中,我们通常在训练阶段和部署阶段使用非常相似的模型,尽管它们的要求非常不同:对于语音和对象识别等任务,训练必须从非常大、高度冗余的数据集中提取结构,但它不需要实时操作,并且可以使用大量计算。然而,部署到大量用户对延迟和计算资源的要求要严格得多。与昆虫的类比表明,如果能够更容易地从数据中提取结构,我们应该愿意训练非常繁琐的模型。笨重的模型可以是单独训练的模型的集合,也可以是使用非常强的正则化器(如dropout)训练的单个非常大的模型[9]。一旦笨重的模型经过训练,我们就可以使用另一种训练,我们称之为“蒸馏”,将知识从笨重模型迁移到更适合部署的小模型。Rich Caruana及其合作者已经首创了这一策略的一个版本[1]。在他们的重要论文中,他们令人信服地证明了由大量模型集成获得的知识可以迁移到单个小模型中。

可能妨碍对这种非常有前途的方法进行更多研究的一个概念性障碍是,我们倾向于用学习的参数值识别训练模型中的知识,这使得我们很难看到如何改变模型的形式,但保持相同的知识。知识的一个更抽象的观点是,它是一个从输入向量到输出向量的学习映射,从而使知识从任何特定的实例化中解放出来。对于学习区分大量类别的笨重模型,正常的训练目标是最大化正确答案的平均对数概率,但学习的副作用是,经过训练的模型将概率分配给所有错误答案,即使这些概率非常小,其中一些概率也比其他概率大得多。不正确答案的相对概率告诉我们很多关于笨重的模型是如何概括的。例如,宝马汽车的图像被误认为垃圾车的可能性很小,但这种错误仍然比误认为胡萝卜的可能性大很多倍。

一般认为,用于训练的目标函数应尽可能准确地反映用户的真实目标。尽管如此,当真正的目标是很好地推广到新数据时,通常训练模型以优化训练数据的性能。显然,训练模型更好地进行归纳,但这需要关于正确归纳方法的信息,而这些信息通常不可用。然而,当我们将知识从一个大模型提取到一个小模型中时,我们可以训练小模型以与大模型相同的方式进行泛化。如果笨重的模型能够很好地进行泛化,例如,因为它是不同模型的大集合的平均值,那么以相同方式进行泛化训练的小模型在测试数据上的表现通常会比在用于训练集合的同一训练集上以正常方式训练的小模式好得多。

将笨重模型的泛化能力转化为小模型的一个明显方法是使用笨重模型产生的类概率作为训练小模型的“软目标”。对于这个迁移阶段,我们可以使用相同的训练集或单独的“迁移”集。当笨重模型是简单模型的大集合时,我们可以使用它们各自预测分布的算术或几何平均值作为软目标。当软目标具有高熵时,它们在每个训练案例中提供的信息比硬目标多得多,训练案例之间的梯度变化也少得多,因此小模型通常可以在比原始笨重模型少得多的数据上进行训练,并使用更高的学习率。

对于像MNIST这样的任务,笨重的模型几乎总是以非常高的置信度生成正确的答案,关于学习函数的大部分信息存在于软目标中非常小的概率比中。例如,2的一个版本的概率为10−6是3和10−9是7,而对于另一个版本,它可能是相反的。这是有价值的信息,定义了数据上丰富的相似性结构(即,它表示哪些2看起来像3,哪些看起来像7),但在迁移阶段,它对交叉熵成本函数的影响非常小,因为概率非常接近于零。卡鲁阿纳和他的合作者通过使用Logit(最终softmax的输入)而不是softmax产生的概率作为学习小模型的目标来规避这个问题,他们最小化了繁琐模型产生的Logit和小模型产生的logit之间的平方差。我们更通用的解决方案,称为“蒸馏”,是提高最终softmax的温度,直到笨重的模型产生一组合适的软目标。然后,我们在训练小模型以匹配这些软目标时使用相同的高温。我们稍后将说明,匹配繁琐模型的Logit实际上是蒸馏的一个特例。

用于训练小模型的传递集可以完全由未标记数据[1]组成,或者我们可以使用原始训练集。我们发现,使用原始训练集效果良好,特别是如果我们在目标函数中添加一个小项,鼓励小模型预测真实目标,并匹配笨重模型提供的软目标。通常情况下,小模型无法精确匹配软目标,而在正确答案的方向上犯错会有帮助。

2.蒸馏

神经网络通常通过使用“softmax”输出层来产生类概率,该输出层通过将zi与其他logit进行比较,将为每个类计算的logit zi转换为概率qi。

式中,T是通常设置为1的温度。使用更高的T值会在不同类别上产生更软的概率分布。

在最简单的蒸馏形式中,通过在迁移集中对蒸馏模型进行训练,并在迁移集中为每种情况使用软目标分布,将知识迁移到蒸馏模型,该软目标分布是通过使用在其softmax中具有高温的笨重模型生成的。训练蒸馏模型时使用相同的高温,但训练后使用的温度为1。

当已知所有或部分传递集的正确标签时,还可以通过训练提取的模型来生成正确的标签,从而显著改进该方法。一种方法是使用正确的标签来修改软目标,但我们发现更好的方法是简单地使用两个不同目标函数的加权平均值。第一个目标函数是与软目标的交叉熵,该交叉熵是使用蒸馏模型的softmax中的高温计算的,该高温与用于从笨重模型生成软目标的高温相同。第二个目标函数是具有正确标签的交叉熵。这是使用蒸馏模型的softmax中完全相同的logits计算的,但温度为1。我们发现,通常通过对第二个目标函数使用条件较低的权重来获得最佳结果。由于软目标产生的梯度大小为1/t2,因此在使用硬目标和软目标时,将其乘以t2非常重要。这确保了如果在实验元参数时改变蒸馏温度,硬目标和软目标的相对贡献大致保持不变。

2.1匹配逻辑是蒸馏的特例

迁移集中的每种情况都对提取模型的每个logit,zi贡献了交叉熵梯度dC/dzi。如果笨重模型具有logits vi,该logits vi产生软目标概率pi,并且迁移训练在温度T下进行,则该梯度由下式给出:

如果温度与logits值相比较高,我们可以近似:

如果我们现在假设Logit对于每个迁移情况分别为零均值

因此,在高温极限下,蒸馏相当于最小化1/2(zi− vi)2,前提是每个分动箱的Logit分别为零均值。在较低的温度下,蒸馏很少关注比平均值负得多的匹配对数。这可能是有利的,因为这些Logit几乎完全不受用于训练笨重模型的成本函数的约束,因此它们可能非常嘈杂。另一方面,非常消极的Logit可能传达有关通过繁琐模型获得的知识的有用信息。这些效应中哪一个占主导地位是一个经验问题。我们表明,当蒸馏模型太小而无法捕获繁琐模型中的所有知识时,中间温度工作得最好,这强烈表明忽略较大的负对数可能会有所帮助。

3.MNIST的初步实验

为了观察蒸馏效果如何,我们在所有60000个训练案例中训练了一个具有两个隐藏层的大型神经网络,其中包含1200个校正线性隐藏单元。如[5]中所述,使用衰减和权重约束对网络进行了强正则化。dropout可以被视为一种训练共享权重的指数级大模型集合的方法。此外,输入图像在任何方向上抖动多达两个像素。该网络实现了67个测试误差,而具有800个校正线性隐藏单元的两个隐藏层的较小网络实现了146个误差。但是,如果仅通过添加额外的任务来匹配大网络在20℃温度下产生的软目标,从而使较小的网络正则化,那么它将获得74个测试误差。这表明,软目标可以将大量知识迁移到提取的模型中,包括如何从翻译后的训练数据中学习概括的知识,即使迁移集不包含任何翻译。

当蒸馏网在其两个隐藏层中的每一层都有300个或更多单位时,所有高于8的温度都给出了相当相似的结果。但当这一温度从根本上降低到每层30个单位时,2.5到4的温度比更高或更低的温度工作得好得多。

然后,我们尝试从迁移集中省略数字3的所有示例。因此,从蒸馏模型的角度来看,3是一个从未见过的神话数字。尽管如此,蒸馏模型仅产生206个测试误差,其中133个在测试集中的1010个三上。大多数错误是由3类学生的学习偏差太低这一事实造成的。如果该偏差增加3.5(这优化了测试集的整体性能),则提取的模型会产生109个误差,其中14个误差在3秒内。因此,在正确的偏差下,尽管在训练过程中从未见过3,但蒸馏模型在测试3中获得了98.6%的正确率。如果迁移集仅包含训练集的7和8,则提取的模型会产生47.3%的测试误差,但当7和8的偏差减少7.6以优化测试性能时,这将下降到13.2%的测试错误。

4个语音识别实验

在本节中,我们研究了在自动语音识别(ASR)中使用的深度神经网络(DNN)声学模型的置乱效果。我们表明,我们在本文中提出的蒸馏策略实现了理想的效果,即将模型集合蒸馏成单个模型,该模型的工作性能明显优于直接从相同训练数据学习的相同大小的模型。

目前,最先进的ASR系统使用DNN将波形特征的(短)时间上下文映射到隐马尔可夫模型(HMM)离散状态的概率分布[4]。更具体地说,DNN每次在三个电话状态的集群上产生概率分布,然后解码器找到通过HMM状态的路径,这是使用高概率状态和产生语言模型下可能的转录之间的最佳折衷。

虽然可以(并且期望)以这样的方式训练DNN,即通过在所有可能路径上边缘化来考虑解码器(以及因此语言模型),通常,通过(局部)最小化网络预测和标签之间的交叉熵,训练DNN执行逐帧分类,通过强制对齐每个观测的基本真值状态序列:

其中θ是我们的声学模型P的参数,该模型将时间t,st的声学观测值映射到“正确”HMM状态ht的概率P(ht|st;θ′),该概率由与正确单词序列的强制对齐确定。该模型采用分布式随机梯度下降法进行训练。

我们使用一个具有8个隐藏层的架构,每个隐藏层包含2560个校正线性单元,最后一个softmax层包含14000个标签(HMM)。输入是26帧40个Mel缩放滤波器组系数,每帧提前10毫秒,我们预测第21帧的HMM状态。参数总数约为85M。这是Android语音搜索使用的声学模型的一个稍微过时的版本,应该被视为一个非常强大的基线。为了训练DNN声学模型,我们使用了约2000小时的英语口语数据,这产生了约7亿个训练示例。该系统在我们的开发集上实现了58.9%的帧精度和10.9%字错误率

4.1结果

我们训练了10个单独的模型来预测P(ht|st;θ),使用与基线完全相同的架构和训练过程。用不同的初始参数值随机初始化模型,我们发现这在训练模型中创造了足够的多样性,使集合的平均预测显著优于单个模型。我们探索了通过改变每个模型看到的数据集来增加模型的多样性,但我们发现这不会显著改变我们的结果,因此我们选择了更简单的方法。对于蒸馏,我们尝试了[1,2,5,10]的温度,并对硬目标的交叉熵使用了相对权重0.5,其中粗体字体表示表1中使用的最佳值。

表1显示,事实上,我们的蒸馏方法能够从训练集中提取比简单使用硬标签训练单个模型更有用的信息。通过使用10个模型的集成实现的帧分类精度的提高中超过80%被迁移到蒸馏模型,这类似于我们在MNIST的初步实验中观察到的改进。由于目标函数不匹配,集成对WER的最终目标(在23K字测试集上)的改善较小,但集成实现的WER改善再次迁移到蒸馏模型。

我们最近注意到了通过匹配已训练的较大模型的类别概率来学习小声学模型的相关工作[8]。然而,他们使用一个大的未标记数据集在1的温度下进行蒸馏,他们的最佳蒸馏模型仅将小模型的错误率减少了大模型和小模型错误率之间差距的28%。

5.在非常大的数据集上训练专家团队

训练模型集合是利用并行计算的一种非常简单的方法,通常的反对意见是,集合在测试时需要太多计算,可以通过使用蒸馏来解决。然而,对集成还有另一个重要的反对意见:如果单个模型是大型神经网络,且数据集非常大,则训练时所需的计算量过大,即使很容易并行化。

在本节中,我们给出了这样一个数据集的示例,并展示了学习专家模型如何减少学习集成所需的总计算量,每个模型都关注于类的不同可混淆子集。专注于进行细粒度区分的专家的主要问题是它们很容易过拟合,我们描述了如何通过使用软目标来防止这种过拟合。

5.1 JFT数据集

JFT是一个内部谷歌数据集,拥有1亿张带有15000个标签的标签图像。当我们进行这项工作时,谷歌的JFT基线模型是一个深度卷积神经网络[7],已经在大量核心上使用异步随机梯度下降法训练了大约六个月。该训练使用了两种类型的并行[2]。首先,有许多神经网络的复制品运行在不同的核集上,并从训练集中处理不同的小批量。每个副本计算其当前小批量上的平均梯度,并将此梯度发送到分片参数服务器,该服务器将返回参数的新值。这些新值反映了参数服务器自上次向复制副本发送参数以来接收到的所有梯度。其次,通过在每个核上放置不同的神经元子集,将每个复制品分散在多个核上。集成训练仍然是第三种类型的并行,可以围绕其他两种类型进行,但前提是有更多的核心可用。等待几年来训练一组模型不是一种选择,因此我们需要一种更快的方法来改进基线模型。

5.2专业模型

当类的数量非常大时,笨重的模型是一个集合,其中包含一个对所有数据进行训练的通才模型和许多“专家”模型,每个模型都是对数据进行训练,这些数据高度丰富,来自非常容易混淆的类子集(如不同类型的蘑菇)。通过将它不关心的所有类合并到一个垃圾箱类中,可以使此类专家的softmax更小。

为了减少过度拟合并分担学习低级特征检测器的工作,每个专家模型都用通才模型的权重初始化。然后,通过训练专家对这些权重进行轻微修改,其中一半样本来自其特殊子集,一半样本来自训练集的其余部分。训练后,我们可以通过将垃圾箱类的logit增加专家类过采样比例的对数来校正有偏差的训练集。

5.3为专家分配课程

为了为专家导出对象类别的分组,我们决定将重点放在我们整个网络经常混淆的类别上。尽管我们可以计算混淆矩阵并将其用作找到此类聚类的方法,但我们选择了一种更简单的方法,即不需要真正的标签来构造聚类。

特别是,我们将聚类算法应用于我们的通才模型预测的协方差矩阵,以便经常一起预测的一组类S m将用作我们的一个专业模型m的目标。我们将在线版本的K均值算法应用于协方差矩阵的列,并获得合理的聚类(如表2所示)。我们尝试了几种产生类似结果的聚类算法。

5.4通过专家团队进行推理

在研究提取专家模型时会发生什么之前,我们想看看包含专家的团队表现如何。除了专家模型,我们总是有一个通才模型,这样我们可以处理没有专家的类,这样我们就可以决定使用哪些专家。给定输入图像x,我们分两步进行top one分类:

步骤1:对于每个测试用例,我们根据通才模型找到n个最可能的类。将这组类称为k。在我们的实验中,我们使用n=1。

步骤2:然后我们取所有专家模型m,其可混淆类的特殊子集S m与k有一个非空交集,并将其称为专家活动集Ak(请注意,该集可能为空)。然后,我们找到所有类上的全概率分布q,其最小化:

其中KL表示KL散度,Pg表示专家模型或通才全模型的概率分布。分布pm是m的所有专业类加上单个垃圾箱类的分布,因此当计算其与全q分布的KL发散时,我们将全q分布分配给m垃圾箱中所有类的所有概率求和。

等式5没有一般的闭式解,尽管当所有模型为每一类产生单一概率时,解是算术平均值或几何平均值,这取决于我们是使用KL(p,q)还是KL(q,p))。我们将q=Softmax(z)(T=1)参数化,并使用梯度下降优化logits z w.r.T.等式5。注意,必须对每个图像执行此优化。

5.5结果

从训练有素的基线全网络开始,专家们的训练速度非常快(JFT需要几天而不是几周)。此外,所有专家都经过完全独立的训练。表3显示了基线系统和与专家模型相结合的基线系统的绝对测试精度。有了61个专家模型,总体测试精度相对提高了4.4%。我们还报告了条件测试精度,即仅考虑属于专家类的示例,并将我们的预测限制在该类子集上的精度。

对于我们的JFT专家实验,我们训练了61名专家模型,每个模型有300个班(加上垃圾箱班)。因为专家的类集合不是不相交的,所以我们经常有多个专家覆盖一个特定的图像类。表4显示了测试集示例的数量、使用专家时位置1处正确示例数量的变化以及JFT数据集top1精度的相对提高百分比,按涵盖该类的专家数量细分。当我们有更多的专家覆盖一个特定的类时,准确性的提高会更大,这一总体趋势让我们感到鼓舞,因为训练独立的专家模型非常容易并行化。

6个软目标作为正则化器

我们关于使用软目标而不是硬目标的主要主张之一是,许多有用的信息可以在软目标中携带,而这些信息不可能用单个硬目标编码。在本节中,我们通过使用少得多的数据来拟合前面描述的基线语音模型的85M参数,证明了这是一个非常大的影响。表5显示,只有3%的数据(约2000万个示例),使用硬目标训练基线模型会导致严重的过度拟合(我们提前停止了,因为在达到44.5%后精度会急剧下降),而使用软目标训练的同一模型能够恢复整个训练集中的几乎所有信息(约2%)。更值得注意的是,我们不必提前停止:具有软目标的系统简单地“收敛”到57%。这表明,软目标是一种非常有效的方式,可以将根据所有数据训练的模型发现的规律传递给另一个模型。

6.1使用软目标防止专家过度拟合

我们在JFT数据集上的实验中使用的专家将他们所有的非专家类合并为一个垃圾箱类。如果我们允许专家在所有课程中都有一个完整的softmax,那幺可能有一个比提前停止更好的方法来防止他们过度拟合。专家在其特殊课程中接受了高度丰富的数据训练。这意味着它的训练集的有效大小要小得多,并且它有很强的倾向于过度适应它的特殊课程。这个问题不能通过使专家更小来解决,因为这样我们就失去了建模所有非专家类所获得的非常有用的传递效应。

我们使用3%的语音数据进行的实验强烈表明,如果专家是用通才的权重初始化的,那么除了用硬目标进行训练外,我们还可以通过用非特殊类的软目标对其进行训练,使其保留几乎所有关于非特殊类的知识。软目标可以由通才提供。我们目前正在探索这种方法。

7.与专家混合的关系

使用经过数据子集训练的专家与混合专家[6]有一些相似之处,后者使用门控网络计算将每个示例分配给每个专家的概率。在专家学习处理分配给他们的示例的同时,选通网络正在学习根据专家对该示例的相对鉴别性能选择将每个示例分配给哪个专家。使用专家的鉴别性能来确定学习分配比简单地对输入向量进行聚类并为每个聚类分配一名专家要好得多,但这使得训练难以并行化:首先,每个专家的加权训练集不断变化,这取决于所有其他专家;其次,选通网络需要比较同一示例中不同专家的性能,以了解如何修改其分配概率。这些困难意味着专家的混合很少用于可能最有益的领域:包含明显不同子集的庞大数据集的任务。

将多个专家的训练并行化要容易得多。我们首先训练一个通才模型,然后使用混淆矩阵来定义专家训练的子集。一旦定义了这些子集,专家就可以完全独立地训练。在测试时,我们可以使用来自通才模型的预测来决定哪些专家是相关的,并且只需要运行这些专家。

8讨论

我们已经证明,提取对于将知识从集合或大型高度正则化模型迁移到较小的提取模型非常有效。在MNIST上,即使用于训练蒸馏模型的传递集缺少一个或多个类的示例,蒸馏也能非常好地工作。对于Android语音搜索所使用的深度声学模型,我们已经证明,通过训练深度神经网络集合实现的几乎所有改进都可以提炼为相同大小的单个神经网络,这更容易部署。

对于真正的大型神经网络,即使训练一个完整的集成也是不可行的,但我们已经证明,通过学习大量的专家网络,每个专家网络都可以学习区分高度易混淆的集群中的类,可以显著提高经过长时间训练的单个真正大型网络的性能。我们还没有证明我们可以将专家中的知识提取回单个大网络中。

上一篇下一篇

猜你喜欢

热点阅读