Searching for A Robust Neural Ar

2019-07-25  本文已影响0人  斯文攸归

——百度

将搜索空间视为有向无环图,为该有向无环图设计可微采样器,该采样器可学习,可以由搜索得到的结构在验证集上的损失来优化,因此称之为:Gradient-based search using Differentiable Architecture Sampler,在CIFAR-10数据集上4 GPU hours可以完成一次搜索过程,达到2.82%的测试错误率和2.5M的参数量。

介绍

搜索一个鲁棒的神经单元(cell)而非整个网络,该单元包含许多变换特征的结构,一个神经网络包含许多这样的单元。下图表示了搜索过程,将一个单元的搜索空间表示为一个DAG(有向无环图),每个灰色节点表示为特征张量,由操作顺序命名。不颜色的边代表不同类型的操作,将某一节点转换为中间特征。同时,每个节点是所有前层节点中间转换特征的累加。在训练时,GDAS从整个DAG中采样一个子图,在子图中每个节点只接受所有前层节点的一个中间特征,具体地,在两个相邻节点的所有中间特征中,GDAS以可微的方式采样一种特征。由此,GDAS能端到端地以梯度下降的方式进行训练,来发现一个鲁棒的cell。

GDAS

GDAS的快主要来源于采样操作,一个DAG包含上百种参数化操作,有着上百万的参数量,直接优化整个DAG(DARTS)将带来两个缺点:1、在一个迭代步中更新大量的参数将耗费很长时间,导致搜索时间超过一天。2、同时优化不同的操作会使得它们相互竞争,例如,不同的操作可能会产生相反的结果。这些相反的操作结果会相互抵消而带来弥散,破坏两个相邻节点之间的信息流动和优化过程。为了解决这两个问题,GDAS在一次迭代中只采样一个子图,因此一次迭代只需要优化DAG的一个部分,加速了训练过程。

GDAS相较于先前的基于强化学习的方法(RL-based)和遗传算法的方法(EA-based)使得搜索过程可微,可以使用梯度下降法。对于强化学习和遗传算法,他们反馈的信息是通过长时间训练的轨迹来进行reward的,而GDAS则是通过损失来反馈的,而且在梯度下降法中,损失是一个连续的可以在每次迭代中给出的量。且GDAS中的采样过程是可以学习的。

方法

对于CNN,一个单元是全卷积的,将所有之前单元的输出作为输入,产生输出特征张量。将CNN中的单元表示为DAGG,包含一系列有序计算节点B,每个节点代表一个特征张量,由前面两个特征张量变换而来:

特征变换

其中,I_i,I_j,I_k分别代表第i,j,k个节点,f_{i,j},f_{i,k}分别表示来自候选操作集F中的两个操作函数。当计算节点数量B=4时,整个单元的节点有7个,I_1,I_2代表前面两个单元的输出,I_3,I_4,I_5,I_6代表计算节点。I_7代表该单元的输出张量,表示为I_7=concat(I_3,I_4,I_5,I_6)。在GDAS中,候选操作集合包含8种操作:恒等映射,零操作,3*3 depth-wise卷积,3*3 depth-wise空洞卷积,5*5 depth-wise空洞卷积,3*3 平均池化,3*3 最大池化(一如DARTS)。

同样搜索两种单元:正常单元和降采样单元,每个正常单元的操作步长为1 ,降采样单元的步长为2,一旦搜搜到所有正常单元和降采样单元,就将其堆叠为完整网络。对于CIFAR-10,堆叠N个正常单元作为一个Block。如下图:

网络结构

可微模型采样

定义神经结构为\alpha ,参数为w_{\alpha},NAS的目标是为了找到一个结构\alpha,实现当以最小化训练损失训练参数w_{\alpha}后,使得网络结构在验证集上的准确率最小化。数学表示:

优化问题

w_{\alpha}^*表示网络结构\alpha 的最佳权重,能实现训练损失最小化。将负的对数似然最为训练对象,D_T,D_V分别表示训练集合验证集。

一个网络结构\alpha包含许多同样的神经单元,该单元由搜索空间G中搜索而来,具体地,节点i,j之间,从候选操作集合F中采样一个变换函数,实际上是从一个离散概率分布\Upsilon _{i,j}中采样而来,在搜索过程中,计算单元中每个节点:

节点计算

离散概率分布\Upsilon _{i,j}是被一个可学习的概率质量函数表示的:

A_{i,j}^k是由K维可学习向量中的第k个元素,F_k表示候选操作集合F中第k个操作。因此K=\vert F \vert ,实际上A_{i,j}编码了相邻节点i,j之间的操作采样概率,因此,一个单元的采样分布表示为A_{i,j}的集合。

给定上两式,可以得到\alpha,w,即可计算训练集上的损失,但因为f_{i,j}采样于离散概率分布,因此梯度不能反传至A_{i,j},为了令方向传播能进行,使用Gumbel-Max的思想重新表达上式:

Gumbel-Max

其中,o_k独立同分布于Gumbel(0,1),o_k=-log(-log(u)),其中u服从0到1之间的均匀分布。h_{i,j}^k是向量h_{i,j}的第k个分量,W_{i,j}^k是节点i,j之间的操作F_k的参数权重。然后,以SoftMax函数来放松argmax,实际上就是Gumbel Softmax:

Gumbel Softmax

\tau 为温度系数,当其趋于零时,\tilde{h}_{i,j}^{k} =h_{i,j}^k。本文在前向传递时用argmax函数,在后向传播中用Gumbel softmax函数,这样就可以用梯度后向传播了。

训练

上述损失函数的主要挑战是学习一个结构\alpha ,为了避免计算高阶导,我们应用替代优化策略以迭代方式更新采样分布和所有函数W的权重。

Eq.(8):Loss的一般形式

该采样分布

A_{i,j}的集合编码而得到,参数

W_{i,j}^k的集合,表示所有单元所有操作的参数。

对于一个采样数据,首先采样结构\alpha ,计算网络输出(仅与w_{\alpha}有关)。

算法1:(alternative optimization strategy (AOS))

算法

结构

训练完成之后,需要从分布中得到最后的网络结构。每个节点i都与前T个节点有关,对于CNN,设置T=2,假设\Omega 是候选索引集,定义节点i,j之间的连接重要性:max_{k\in \Omega }Pr(f_{i,j}=F_k),对于每个节点i,保留先前节点中有最大重要性的2个连接,对于已经保留的节点i,j之间的连接,使用函数F_{argmax_{k\in \Omega }Pr(f_{i,j}=F_k)}来确定节点之间的操作。

本文固定降采样单元,仅仅搜索正常单元。设计的降采样单元如下:

实验

识别率基本与DARTS持平的情况下,搜索时间比它快5倍以上。

实验 实验
上一篇下一篇

猜你喜欢

热点阅读