计算机视觉笔记

DARTS代码阅读

2019-08-05  本文已影响0人  高斯纯牛奶

0x00 背景知识

先放上一篇综述文章,对于理解NAS(网络结构搜索)的问题有很大的帮助:https://blog.csdn.net/c9Yv2cf9I06K2A9E/article/details/82321884
另外,DARTS搜索,强烈建议先看下inception的网络结构和nasnet的论文,DARTS的论文基础是建立在之上的,某种程度上可以看做是对nasnet的优化。

0x01 搜索思路

基于前人的经验(inception/nasnet),DARTS使用cell作为模型结构搜索的基础单元,所学习的单元堆叠成卷积网络,也可以递归连接形成递归网络。
cell内节点间先默认所有可能的操作连接,每个连接初始化权重参数值,结构搜索也就是训练这些权重参数,最终两节点间选取权重最大的操作作为最终结构参数。

训练过程中,交替训练网络结构参数和网络参数。

0x02 代码定义

genotype结构定义

normal=[(‘sep_conv_3x3’, 0), (‘sep_conv_3x3’, 1), (‘sep_conv_3x3’, 0), (‘sep_conv_3x3’, 1), (‘sep_conv_3x3’, 1), (‘skip_connect’, 0), (‘skip_connect’, 0), (‘dil_conv_3x3’, 2)], normal_concat=[2, 3, 4, 5]

取了genotype里的一个normal cell的定义及其对应的cell结构图首先说明下,这个定义的解释。DARTS搜索的也就是这个定义。
normal定义里(‘sep_conv_3x3’, 1)的0,1,2,3,4,5对应到图中的红色字体标注的。
从normal文字定义两个元组一组,映射到图中一个蓝色方框的节点(这个是作者搜索出来的结构,结构不一样,对应关系不一定是这样的)
sep_conv_xxxx表示操作,0/1表示输入来源
(‘sep_conv_3x3’, 1), (‘sep_conv_3x3’, 0) —-> 节点0
(‘sep_conv_3x3’, 0), (‘sep_conv_3x3’, 1) —-> 节点1
(‘sep_conv_3x3’, 1), (‘skip_connect’, 0) —-> 节点2
(‘skip_connect’, 0), (‘dil_conv_3x3’, 2) —-> 节点3
normal_concat=[2, 3, 4, 5] —-> cell输出c_{k}

DARTS搜索NOTE

首先明确,DARTS搜索实际只搜cell内结构,整个模型的网络结构是预定好的,比如多少层,网络宽度,cell内几个节点等;
在构建搜索的网络结构时,有几个特别的地方:
1.预构建cell时,采用的一个MixedOp:包含了两个节点所有可能的连接(genotype中的PRIMITIVES);
2.初始化了一个alphas矩阵,网络做forward时,参数传入,在cell里使用,搜索过程中所有可能连接都在时,计算mixedOp的输出,采用加权的形式。
3.训练过程对train数据每个step又切成两份: train和validate, train用来训练网络参数,validate用来训练结构参数。

0x03 关键代码片段

以下把代码中一些关键的,影响到理解DARTS的地方说明一下:

  logits = model(input)
  loss = criterion(logits, target)
  loss.backward()
  nn.utils.clip_grad_norm(model.parameters(), args.grad_clip)
  optimizer.step()

这里就是论文里近似后的交叉梯度下降,其中architect.step()是结构参数weights的梯度下降,optimizer.step()是网络参数的梯度下降。

class MixedOp(nn.Module):
  def __init__(self, C, stride):
    super(MixedOp, self).__init__()
    self._ops = nn.ModuleList()
    for primitive in PRIMITIVES:
      op = OPS[primitive](C, stride, False)
      if 'pool' in primitive:
        op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))
      self._ops.append(op)
  def forward(self, x, weights):
    return sum(w * op(x) for w, op in zip(weights, self._ops)) # weighted op

这个是MixedOp,两节点间操作把PRIMITIVES里定义的所有操作都连接上,计算输出时利用传入的weights进行加权。

def forward(self, s0, s1, weights):
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)
    states = [s0, s1]
    offset = 0
    for i in range(self._steps):
      s = sum(self._ops[offset+j](h, weights[offset+j]) for j, h in enumerate(states)) # all nodes before can be input, mixop.
      offset += len(states) #0, 2, 5, 9
      states.append(s)
    return torch.cat(states[-self._multiplier:], dim=1)

self.ops[], 实际是14(2+3+4+5)个MixedOp,2+3+4+5的解释,对于第一个内部节点,有两个可能的输入(c{k-1}, c_{k-2}),对于第二个内部节点,有三个可能的输入(两个同节点1,另加上第一个节点)……
代码里,weights[],也是一个长度14的list,前2个对应到第一个节点的两个输入的权重,第3~5这3个元素对应到第二个节点的三个输入的权重……这就是上面代码里offset的作用

class Architect(object):
  def __init__(self, model, args):
    self.network_momentum = args.momentum
    self.network_weight_decay = args.weight_decay
    self.model = model
    self.optimizer = torch.optim.Adam(self.model.arch_parameters(),   #arch_parameters, 
        lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay) 

需要注意的是Architect里optimizer优化器的参数是model.arch_parameters(), 这个对应到的是model_search.py里定义的._arch_parameters,及初始化的各节点连接的权重。
def _initialize_alphas(self):
k = sum(1 for i in range(self._steps) for n in range(2+i)) # 2+i, 2 for two inputs, i=0,1,2,3, nodes before this. 2+3+4+5
num_ops = len(PRIMITIVES)

self.alphas_normal = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
    self.alphas_reduce = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
    self._arch_parameters = [
      self.alphas_normal,
      self.alphas_reduce,
    ]

def _parse(weights):
      #  weights: [2 + 3 + 4 + 5][len(PRIMITIVES)]
      gene = []
      n = 2
      start = 0
      for i in range(self._steps): #ch: steps = 4
        end = start + n 
        print('start=', start, 'end=', end, 'n=', n)
        W = weights[start:end].copy()
        print(W) # ch: add
        # chenhua: for x, -max(W[x][...]), W[][] is the parameters for architect. lambda elect out the OP weights most.
        edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2]
        print(edges)
        for j in edges: #ch: j, edges mean op, all possible ops between two node
          print(j)
          k_best = None
          for k in range(len(W[j])):  #ch: k, the weights for possible connection?
            if k != PRIMITIVES.index('none'):
              if k_best is None or W[j][k] > W[j][k_best]:
                print('W[j][k]=', W[j][k], 'W[j][k_best]=', W[j][k_best])
                k_best = k
          gene.append((PRIMITIVES[k_best], j))  #ch: find ????
        start = end
        n += 1
      return gene
    # ch: alphas_xxx, parameters for architect??
    gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).data.cpu().numpy())
    gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).data.cpu().numpy())
    concat = range(2+self._steps-self._multiplier, self._steps+2) #ch: step=4, mltiplier=3
    print('concat', concat)
    genotype = Genotype(
      normal=gene_normal, normal_concat=concat,
      reduce=gene_reduce, reduce_concat=concat
    )
    print('genotype=', genotype)
    return genotype

搜索过程中搜索出的结果(节点间的op)的打印,就是靠这个函数。
核心是找出两个节点间不为none的所有ops中权重最大的,就是最终的结果。
注意:weights[][]的size是[2 + 3 + 4 + 5][len(PRIMITIVES)]

参考链接

  1. https://cloud.tencent.com/developer/article/1348049
  2. https://blog.csdn.net/srdlaplace/article/details/80863346
  3. https://www.jiqizhixin.com/articles/2018-06-27-6
上一篇 下一篇

猜你喜欢

热点阅读