Non-local_pytorch代码解读
参考:https://zhuanlan.zhihu.com/p/33345791 以及 https://github.com/AlexHex7/Non-local_pytorch
代码总的模型框架
from torch import nn
# from lib.non_local_concatenation import NONLocalBlock2D
# from lib.non_local_gaussian import NONLocalBlock2D
from lib.non_local_embedded_gaussian import NONLocalBlock2D
# from lib.non_local_dot_product import NONLocalBlock2D
class Network(nn.Module):
    """Small MNIST-style classifier with non-local blocks between conv stages.

    Three conv(3x3) -> BatchNorm -> ReLU -> 2x2 max-pool stages with 32, 64
    and 128 channels; a ``NONLocalBlock2D`` follows each of the first two
    stages. The fully connected head expects the 128 x 3 x 3 feature map
    produced from a 1 x 28 x 28 input (28 -> 14 -> 7 -> 3 after three
    poolings) and emits 10 class logits.
    """

    def __init__(self):
        super().__init__()

        stages = []
        # (in_channels, out_channels, append a non-local block after the stage?)
        for c_in, c_out, with_non_local in ((1, 32, True), (32, 64, True), (64, 128, False)):
            stages.extend([
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])
            if with_non_local:
                stages.append(NONLocalBlock2D(in_channels=c_out))
        self.convs = nn.Sequential(*stages)

        self.fc = nn.Sequential(
            nn.Linear(in_features=128 * 3 * 3, out_features=256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(in_features=256, out_features=10),
        )

    def forward(self, x):
        """Return logits of shape (batch, 10) for an input batch ``x``."""
        features = self.convs(x)
        flat = features.view(x.size(0), -1)
        return self.fc(flat)
if __name__ == '__main__':
    import torch

    # Smoke test: three random single-channel 28x28 images -> (3, 10) logits.
    dummy_batch = torch.randn(3, 1, 28, 28)
    logits = Network()(dummy_batch)
    print(logits.shape)
模型主框架是一个针对 MNIST 数据集的分类网络,只是在中间加入了 Non-local 模块;下面进入 Non-local 模块的代码学习。
先从上面图里的框架看起吧
non_local_embedded_gaussian
def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
    """Set up the g / theta / phi embeddings and the output projection W.

    :param in_channels: channel count of the input; W maps back to it so the
        residual ``W(y) + x`` computed in forward() is well-formed.
    :param inter_channels: channel count of the g/theta/phi embeddings;
        ``None`` means ``in_channels // 2`` (bumped to 1 if that is 0).
    :param dimension: 1, 2 or 3; selects the Conv/BatchNorm/MaxPool family.
    :param sub_sample: when true, g and phi are followed by a max-pool layer,
        which shrinks the number of key/value positions.
    :param bn_layer: when true, W is a 1x1 conv followed by BatchNorm whose
        weight and bias are zeroed; otherwise W is a bare 1x1 conv whose own
        weight and bias are zeroed. Either way W outputs zeros right after
        construction, so the residual block starts out returning its input.
    """
    super(_NonLocalBlockND, self).__init__()
    assert dimension in [1, 2, 3]

    self.dimension = dimension
    self.sub_sample = sub_sample
    self.in_channels = in_channels
    self.inter_channels = inter_channels
    if self.inter_channels is None:
        # Halve the channels by default, but never end up with zero channels.
        self.inter_channels = in_channels // 2 or 1

    # Layer family for the requested dimensionality. For 3-D input the pool
    # kernel is (1, 2, 2), so the first spatial axis (t) is not pooled.
    if dimension == 3:
        conv_nd, bn = nn.Conv3d, nn.BatchNorm3d
        max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
    elif dimension == 2:
        conv_nd, bn = nn.Conv2d, nn.BatchNorm2d
        max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
    else:
        conv_nd, bn = nn.Conv1d, nn.BatchNorm1d
        max_pool_layer = nn.MaxPool1d(kernel_size=2)

    def pointwise(c_in, c_out):
        # Kernel-size-1 convolution of the selected dimensionality.
        return conv_nd(in_channels=c_in, out_channels=c_out, kernel_size=1, stride=1, padding=0)

    # Creation order (g, W, theta, phi) is kept so that parameter
    # initialisation consumes the RNG in the same sequence.
    self.g = pointwise(self.in_channels, self.inter_channels)

    if bn_layer:
        self.W = nn.Sequential(pointwise(self.inter_channels, self.in_channels), bn(self.in_channels))
        zeroed = self.W[1]
    else:
        self.W = pointwise(self.inter_channels, self.in_channels)
        zeroed = self.W
    nn.init.constant_(zeroed.weight, 0)
    nn.init.constant_(zeroed.bias, 0)

    self.theta = pointwise(self.in_channels, self.inter_channels)
    self.phi = pointwise(self.in_channels, self.inter_channels)

    if sub_sample:
        # The same parameter-free pooling module is appended to both branches.
        self.g = nn.Sequential(self.g, max_pool_layer)
        self.phi = nn.Sequential(self.phi, max_pool_layer)
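`__init__` 除了常规的参数设置外,主要做了三件事:
1. 定义 g、theta、phi 三个 1*1 卷积。
2. 把 W 中 BN 层(bn_layer=False 时则是 W 卷积本身)的 weight 和 bias 初始化为 0。这样模块刚初始化时 W_y = 0、输出 z = x,相当于恒等映射,可以直接插入已有网络而不破坏其原有行为。
3. 由 sub_sample 参数决定是否在 g 和 phi 之后加入 max pooling 做下采样。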
然后下面是主要的代码
def forward(self, x):
    """Apply embedded-Gaussian non-local attention followed by a residual add.

    :param x: input feature map, (b, c, t, h, w) in the 3-D case; all trailing
              spatial dims are flattened into one position axis.
    :return: ``W(y) + x``.
    """
    n = x.size(0)
    spatial = x.size()[2:]

    # Embed and flatten positions: values/queries (b, N, c), keys (b, c, N').
    value = self.g(x).view(n, self.inter_channels, -1).permute(0, 2, 1)
    query = self.theta(x).view(n, self.inter_channels, -1).permute(0, 2, 1)
    key = self.phi(x).view(n, self.inter_channels, -1)

    # Pairwise similarities (b, N, N'), each row normalised by softmax.
    weights = F.softmax(torch.matmul(query, key), dim=-1)

    # Weighted sum of values, restored to (b, inter_channels, *spatial).
    y = torch.matmul(weights, value).permute(0, 2, 1).contiguous()
    y = y.view(n, self.inter_channels, *spatial)

    # W projects back to in_channels; add the input as a residual.
    return self.W(y) + x
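1. 把 thw(3 维)或 hw(2 维)的空间维度展平,合并为一个维度,下面记作 W。W 表示空间位置数,注意不要和 W 卷积层混淆。
2. g_x = B×W×C,theta_x = B×W×C,phi_x = B×C×W,f = B×W×W,f_div_C = B×W×W。这里的 C 是 inter_channels。若 sub_sample=True,g 和 phi 经过了 max pooling,位置数变为更小的 W',此时 g_x = B×W'×C,phi_x = B×C×W',f 和 f_div_C 为 B×W×W'。
3. y:B×W×C → B×C×W → B×C×(t)×h×w。
4. W 也是 1*1 卷积,把通道数从 inter_channels 恢复为 in_channels,BN 可选加;最后与输入 x 做残差连接。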
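从代码的角度来看,流程是:
1. 先用三次 1*1 卷积(theta、phi、g)。
2. 把 theta 和 phi 的结果相乘,再经过 softmax 归一化,形成一个 W×W 大小的权重。
3. 用这个权重与第三个(g)的结果相乘。

这和 attention 的操作很类似,相当于对所有位置加了一层按相似度计算的系数加权。它也有点类似全连接,因为每个输出位置都与所有位置相连。但参数量并不大,只有几个 1*1 卷积;真正大的是计算量和显存:权重矩阵为 W×W,随空间位置数平方增长。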
然后再来看 non_local_concatenation 的写法,主要罗列与 embedded_gaussian 的区别:
# Excerpt from the concatenation variant's __init__: maps each concatenated
# [theta_i, phi_j] feature pair (2 * inter_channels channels) to a single
# non-negative score (ReLU). nn.Conv2d is used because the pairwise tensor
# built in forward() is always 4-D: (b, 2 * inter_channels, N, N').
self.concat_project = nn.Sequential(
    nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
    nn.ReLU()
)


def forward(self, x):
    """Apply concatenation-style non-local attention followed by a residual add.

    :param x: input feature map, (b, c, t, h, w) in the 3-D case; all trailing
              spatial dims are flattened into one position axis.
    :return: ``W(y) + x``.
    """
    batch_size = x.size(0)

    # Values: (b, N', c)
    g_x = self.g(x).view(batch_size, self.inter_channels, -1)
    g_x = g_x.permute(0, 2, 1)

    # (b, c, N, 1)
    theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)
    # (b, c, 1, N')
    phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1)

    # Broadcast both to (b, c, N, N') so every (i, j) pair can be concatenated
    # along the channel axis. expand() returns views instead of the two full
    # copies repeat() would allocate; torch.cat materialises the result anyway.
    n_query = theta_x.size(2)
    n_key = phi_x.size(3)
    theta_x = theta_x.expand(-1, -1, -1, n_key)
    phi_x = phi_x.expand(-1, -1, n_query, -1)
    concat_feature = torch.cat([theta_x, phi_x], dim=1)

    # (b, 1, N, N') -> (b, N, N')
    f = self.concat_project(concat_feature)
    b, _, h, w = f.size()
    f = f.view(b, h, w)

    # No softmax here: normalise by the number of key positions instead.
    N = f.size(-1)
    f_div_C = f / N

    y = torch.matmul(f_div_C, g_x)
    y = y.permute(0, 2, 1).contiguous()
    y = y.view(batch_size, self.inter_channels, *x.size()[2:])
    W_y = self.W(y)
    z = W_y + x
    return z
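主要不同点在于形成权重的方式:这里把 theta 和 phi 的特征在通道维度上拼接。

由于两者形状不同(theta_x 为 (b, c, N, 1),phi_x 为 (b, c, 1, N)),要先各自 repeat 成 (b, c, N, N) 才能拼接。拼接后用 1*1 卷积把 2c 个通道降为 1 个,再经过 ReLU,得到 N×N 的权重。

f_div_C = f / N 并不是多余的。这里没有 softmax,权重不会自动归一化;除以 N(代码中 N = f.size(-1),即参与求和的位置数,对应论文中的归一化因子 C(x) = N),可以避免对 N 个位置加权求和后数值随输入尺寸变大。正如原来猜测的,W 很大时,不压缩数值会很大。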
然后来看non_local_dot_product
def forward(self, x):
    """Apply dot-product non-local attention followed by a residual add.

    Unlike the embedded-Gaussian variant the pairwise scores are not passed
    through softmax; they are divided by the number of key positions.

    :param x: input feature map, (b, c, t, h, w) in the 3-D case; all trailing
              spatial dims are flattened into one position axis.
    :return: ``W(y) + x``.
    """
    b = x.size(0)
    c = self.inter_channels

    g_x = self.g(x).view(b, c, -1).transpose(1, 2)          # (b, N', c)
    theta_x = self.theta(x).view(b, c, -1).transpose(1, 2)  # (b, N, c)
    phi_x = self.phi(x).view(b, c, -1)                      # (b, c, N')

    scores = theta_x @ phi_x                                # (b, N, N')
    scores = scores / scores.size(-1)                       # scale by N'

    y = (scores @ g_x).transpose(1, 2).contiguous()
    y = y.view(b, c, *x.shape[2:])
    return self.W(y) + x
相比 embedded_gaussian,这里直接用点乘形成 W×W 的权重,没有使用 softmax,而是和 concatenation 一样除以位置数 N 做归一化。
最后看下non_local_gaussian
# Excerpt from the Gaussian variant's __init__: phi is just the max-pool
# layer here (no 1x1 conv). forward() only calls self.phi when
# self.sub_sample is true, so presumably this assignment sits under
# `if sub_sample:` in the full source - TODO confirm.
self.phi = max_pool_layer
def forward(self, x):
    """Apply Gaussian non-local attention on the raw input plus a residual add.

    :param x: input feature map, (b, c, t, h, w) in the 3-D case; all trailing
              spatial dims are flattened into one position axis.
    :return: ``W(y) + x``.
    """
    n = x.size(0)

    value = self.g(x).view(n, self.inter_channels, -1).permute(0, 2, 1)

    # No theta embedding: queries are the raw input, (b, N, in_channels).
    query = x.view(n, self.in_channels, -1).permute(0, 2, 1)

    # Keys are the raw input too, pooled by self.phi only when sub-sampling.
    key_source = self.phi(x) if self.sub_sample else x
    key = key_source.view(n, self.in_channels, -1)

    weights = F.softmax(torch.matmul(query, key), dim=-1)

    y = torch.matmul(weights, value).permute(0, 2, 1).contiguous()
    y = y.view(n, self.inter_channels, *x.size()[2:])
    return self.W(y) + x
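相比 embedded_gaussian,这里去掉了 theta 和 phi 两个 1*1 卷积:
- theta_x 直接使用输入 x。
- phi 在 sub_sample 时只剩 max_pooling 层,否则也直接使用 x。
- g 仍把通道数变换为 inter_channels。

因此,相似度是在原始特征上计算的 softmax(xᵀx),对应论文中的 Gaussian 形式 exp(x_iᵀx_j);而 embedded_gaussian 是先做嵌入,再在 θ(x_i)ᵀφ(x_j) 上计算。更细的区别等看了论文再来补充。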