论文 | NeurIPS2019 Meta-Weight-Net
一 写在前面
未经允许,不得转载,谢谢~~~
今天这篇paper是NeurIPS2019的一篇paper,虽然时间有点久了,但是看完paper还是有觉得值得借鉴的地方,还是简单记录一下📝。
- 出处:NeurIPS2019
- title: Meta-Weight-Net: Learning an Explicit Mapping For SampleWeighting
- link: https://arxiv.org/pdf/1902.07379.pdf](https://arxiv.org/pdf/1902.07379.pdf
二 主要内容
2.1 backgrounds
deep learning容易对biased data产生过拟合的现象。
这里作者重点归纳了两种biased data情况:
- noisy data 标签有噪声数据
- long-tail data 长尾分布数据
这种过拟合自然会导致模型的生成泛化能力受到影响,而为了解决这个问题的一个思路就是进行sample reweighting,也就是对不同的样本设置不同的权重。 那reweighting的方法本质要学习的就是从不同样本到权重之间的映射关系,然后通过最小化加权之后的损失函数来优化模型参数。
2.2 related work
目前主要的sample reweighting方法可以分为两大类:
- 以focal loss为代表:
- 单样本的loss越大 --> 认为这个样本更难分辨 --> 增加这个样本的loss权重;
- 经典方法包括focal loss,AdaBoost,hard negative mining;
- 这类方法主要适合用于解决long-tail数据,使得分布少的类别能拥有更高的权重;
- 以SPL为代表:
- 单样本的loss越小 --> 认为该样本的标签可信度更高 --> 增加这个样本的loss权重;
- 经典方法包括SPL,iterative reweighting,以及其他变种方法;
- 这类方法适合用于解决noisy data问题,使得标签正确的样本拥有更高的权重;
下图以focal loss和SPL为例,直观给出了两类方法的差别,focal loss递增,SPL递减。
2.3 motivation
作者首先总结了现有方法的两大缺点:
1) 在现实无法预知data具体分布(long-tail还是noisy)的情况下,不知道要选递增型还是递减型。更何况,现实中可能出现的是long-tail并且noisy的数据分布;
2) 不管是哪一类方法,都需要超参数。
针对以上两点,该文的motivation就是能否设计一个自适应的且不需要超参数的reweighting方法,即找到一种从loss到weight的映射关系。
三 文章方法 Meta-Weighting-Net (MW-Net)
3.1 key idea
为了提出这样一个自适应的且不需要超参数的reweighting方法,文章的主要想法是用MLP来充当weight fucntion的作用,即让MLP自动学习从loss到weight之间的映射关系。然后用unbiased meta data来引导MLP的参数学习。
如下图所示,文章确实可以做到可以同时处理不同分布的数据(long-tail/noisy)。
3.2 具体方法
记整个分类网络为, 用于预测样本loss权重的MLP网络为, 网络的整体训练过程如下图:
可以重点关注箭头的颜色,红色的表示的是meta-weight-net的参数更新过程,而黑色的表示的整体分类网络的参数更新过程。对于时间t而言,最重要的几个步骤如下:
1) 对于分类网络的参数, 用从训练集中采出的minibatch data进行网络参数的更新,得到, 注意这里是暂时更新的,并没有替换原来的参数,可以理解为是一个临时变量。(图中step5)
2) 对于MLP网络为,用当前的预测得到的loss作为MLP网络的输入,得到输出的loss weights,用meta-dataset构建出来的minibatch data更新参数, 得到t+1时刻的,替换原来中的网络参数。(图中step6)
3) 用t+1时刻的和t时刻的, 再次用训练集中采出的minibatch data进行网络参数的更新,得到,这次的才真正作为t+1时刻的, 替换原来中的网络参数。(图中step7)
具体的公式可能看起来稍微有点复杂,但其实就是SGD在mini-batch上的优化。
最终的伪代码如下所示:
四 写在最后
整个思路还是比较巧妙的,而且之前的实验结果图也确实验证了方法能对不同分布的数据都有效。
目前还存在两点问题:
1) meta-dataset具体什么怎么构造的,为什么在更新MLP的时候不能用正常的mini-batch,而要用meta-dataset;
2) 参数更新为什么一定要分3步,直接a)更新;2)更新是不可以的吗
太细节的地方可能没有get到,欢迎知道的小伙伴多多交流。