Spatial Transform Networks学习笔记及p
卷积神经网络是一种很强大的神经网络,但是它的一个不足的之处在于,当输入数据在空间上发生变化时,结果很容易受到影响。如果神经网络可以动态地变换输入图片,包括平移、旋转、缩放等,那么网络就不仅能找到最相关的区域,还能把它变换到最合适的位置。这就是Jaderberg等人提出的 Spatial Transform Networks。
1. 图像变换
放射变换指的是一个向量通过一次线性变换和一次平移,变换到了另一个向量,公式为:
其中:
原向量
变换之后的向量
位移向量
仿射变换矩阵
2D变换:
缩放:
拉伸:
水平方向(y轴不变)
垂直方向(x轴不变)
旋转:以原点为中心旋转
翻转:
以x为轴翻转:
以y为轴翻转:
2D仿射变换矩阵加上平移,可以组成一个6个参数的矩阵
齐次坐标:加上1作为第三个轴
变换矩阵则变为:
使用齐次坐标的优点是,可以通过乘以各个矩阵将任意数量的仿射变换组合为一个。
2. 双线性插值
当图片经过变换之后,变换得到的坐标可能不是整数,所以要把坐标通过双线性插值变为整数。
假设想要知道P点的颜色,已知P点附近的四个点的值,先在x轴上进行插值得到R1,R2,然后在y轴进行插值,得到P点处的值(公式以后再补)。
3. Spatial Transform Networks
对于一个图片分类的神经网络,我们希望它能识别出同一类中不同大小、不同视角、不同形状/形变的图片,简单粗暴一点的想法是在训练的时候加入各种大小角度的图片,让模型学会辨别不同角度里的同一类型。另一个想法就是在模型学习分类图片之前,这些图片就都被旋转缩放到了合适的大小,这样就能减少分类时的工作量,让模型更加快乐。
Spatial Transformer是一个可微分的模块,在前向过程中把一个输入特征图经过特定的变换,得到输出特征图。主要分为三部分:定位网络、网格生成器和采样器。
定位网络得到宽W、高H、通道C的特征图U,输出为, 是被用在输入特征图上的变换的参数,对于仿射变换来说,就是之前提过的有6个参数的矩阵M。
, 定位网络的函数可以说全连接层或者卷积层或者其他的,只要能保证输出的是。
网格生成器的输出是一个参数化的采样网格,这个网格是输出特征图V的网格,首先生成一个基于输入特征图的网格,然后应用上一步得到的变换得到输出特征图对应的网格。
采样器在输入特征图上都能得到对应到输出特征图的值,这里就可以用到双线性插值,因为采样和插值都是可微分的(也一定要用可微分的采样公式),loss的梯度就不仅能流回输入特征图,还有采样的网格坐标,由坐标可以得到关于放射变换参数的偏微分,从而使放射变换得到优化。
这个spatial transformer可以放在一个CNN网络里的任何地方,从而得到一个spatial transform network。
优点:据说运行的很快,不会影响速度,还可以提高效率,对特征图进行向下或者向上采样。
4. 代码实现[4]
这里用了pytorch的VGG模型,加上spatial transformer,做一个FashionMNIST的分类。
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib as mpl
首先定义网络模型。
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.features = models.vgg16_bn(pretrained=True).features
# FashionMNIST是灰度图,通道数为1
self.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
self.classifier = nn.Linear(512, 10)
# Spatial transformer localization-network
self.localization = nn.Sequential(
nn.Conv2d(3, 6, kernel_size=5),
nn.MaxPool2d(2, stride=2),
nn.ReLU(True),
nn.Conv2d(6,10, kernel_size=5),
nn.MaxPool2d(2, stride=2),
nn.ReLU(True),
nn.AdaptiveAvgPool2d(output_size=(1, 1)),
)
# Regressor for the 3 * 2 affine matrix
self.fc_loc = nn.Sequential(
nn.Linear(10, 32),
nn.ReLU(True),
nn.Linear(32, 3 * 2)
)
# Initialize the weights/bias with identity transformation
self.fc_loc[2].weight.data.zero_()
self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
# Spatial transformer network forward function
def stn(self, x):
xs = self.localization(x)
xs = xs.view(-1, xs.size()[1] * xs.size()[2] * xs.size()[3])
theta = self.fc_loc(xs)
theta = theta.view(-1, 2, 3)
# affine_grid输入是仿射矩阵(Nx2x3)和
# the target output image size. (N×C×H×W for 2D or N×C×D×H×W for 3D)
# 输出Tensor的尺寸(Tensor.Size(NxHxWx2)),输出的是归一化的二维网格。
grid = F.affine_grid(theta, x.size())
# grid_sample函数中将图像坐标归一化到[−1,1],其中0对应-1,width-1对应1。
x = F.grid_sample(x, grid)
return x, theta
def forward(self, x):
# transform the input
x = self.stn(x)
# Perform the usual forward pass
x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return F.log_softmax(x, dim=1)
加载数据
# Training dataset
train_loader = torch.utils.data.DataLoader(
datasets.FashionMNIST(root='./data/FashionMNIST', train=True,
transform=transforms.Compose([
transforms.Resize(64),
transforms.RandomRotation(90),
transforms.ToTensor(),
])), batch_size=64, shuffle=True, num_workers=4)
# Test dataset
test_loader = torch.utils.data.DataLoader(
datasets.FashionMNIST(root='./data/FashionMNIST', train=False,
transform=transforms.Compose([
transforms.Resize(64),
transforms.RandomRotation(90),
transforms.ToTensor(),
])), batch_size=64, shuffle=True, num_workers=4)
设置optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
训练以及测试函数
def train(epoch):
model.train()
correct = 0
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
with torch.set_grad_enabled(True):
output = model(data)
loss = F.nll_loss(output, target)
pred = output.max(1, keepdim=True)[1]
correct += pred.eq(target.view_as(pred)).sum().item()
loss.backward()
optimizer.step()
accuracy = correct / len(train_loader.dataset)
print('Train Epoch: {} \nLoss: {:.6f} \tAccuracy: {:.4f}'.format(
epoch, loss.item(), accuracy))
def test():
with torch.no_grad():
model.eval()
test_loss = 0
correct = 0
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output, theta = model(data)
# sum up batch loss
test_loss += F.nll_loss(output, target, size_average=False).item()
# get the index of the max log-probability
pred = output.max(1, keepdim=True)[1]
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
accuracy = correct / len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
.format(test_loss, correct, len(test_loader.dataset),
100. * accuracy)
图片转换以及可视化
def convert_image_np(inp):
"""Convert a Tensor to numpy image."""
inp = inp.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = np.clip(inp, 0, 1)
return inp
def visualize_stn():
with torch.no_grad():
# Get a batch of training data
model.load_state_dict(best_model)
data = next(iter(test_loader))[0].to(device)
input_tensor = data.cpu()
transformed_input_tensor, theta = model.stn(data)
transformed_input_tensor = transformed_input_tensor.cpu()
in_grid = convert_image_np(
torchvision.utils.make_grid(input_tensor))
out_grid = convert_image_np(
torchvision.utils.make_grid(transformed_input_tensor))
# Plot the results side-by-side
f, axarr = plt.subplots(1, 2)
axarr[0].imshow(in_grid)
axarr[0].set_title('Dataset Images')
axarr[1].imshow(out_grid)
axarr[1].set_title('Transformed Images')
主函数
for epoch in range(1, 50 + 1):
loss, accuracy = train(epoch)
test()
# Visualize the STN transformation on some input batch
visualize_stn()
plt.ioff()
plt.show()
结果参考:
test loss: 0.231391
test accuracy: 0.9174
stn_fashionmnist_rotated.png
参考
[1] Spatial Transformer Networks. arXiv:1506.02025
[2] Deep Learning Paper Implementations: Spatial Transformer Networks
[3] Bilinear interpolation,Wikipedia
[4] pytorch stn实现