PyTorch Matting Algorithm: Deep Image Matting
This post implements the second-stage model, M-Net, of the matting algorithm Semantic Human Matting, which is essentially Deep Image Matting. Note that the model implemented here deviates slightly from the original paper: besides a minor difference in the input layer, the loss function is also simplified (with no essential difference).
The training data comes from the dataset open-sourced by the company 爱分割 (aisegment), which contains 34,427 images with their corresponding alpha channels. The dataset is very large, and making it public deserves real credit.
The Semantic Human Matting model

Overall, the automatic matting pipeline proposed in the Semantic Human Matting paper is refreshingly clear (see the figure above). For an image to be matted, a semantic segmentation model (T-Net) first segments it into foreground, background, and unknown regions ($F_s$, $B_s$, $U_s$); the foreground ($F_s$) plus the unknown region ($U_s$) are then loosely taken to form a trimap, at which point Deep Image Matting (M-Net) can complete the matting with high quality. The complete model will be implemented step by step in the next few posts; this post only covers the second stage (M-Net).

M-Net takes a 6-channel input obtained by concatenating the image to be matted (a 3-channel RGB composite of foreground and background) with the segmentation model's 3-channel prediction ($F_s$, $B_s$, $U_s$). An encoder extracts image features, and a decoder produces the prediction $\alpha_r$. If the segmentation is sufficiently accurate, the regions covered by $F_s$ and $B_s$ can be assumed to already separate most of the foreground and background well, and the only place where accuracy needs improving is the boundary region of the target object. The purpose of the second stage, M-Net, is therefore to refine the prediction around the boundary (which is exactly what Deep Image Matting does). Combining the two parts gives the final prediction:

$$\alpha_p = F_s + U_s \, \alpha_r$$
This formula can be understood as follows: treat $\alpha$ as the probability that a pixel belongs to the foreground. A pixel falls into the foreground, unknown, or background region with probabilities $F_s$, $U_s$, $B_s$; conditioned on the region, its probability of being foreground is $1$, $\alpha_r$, or $0$, respectively. That is:

$$P(\text{foreground}) = 1 \cdot P(F) + \alpha_r \cdot P(U) + 0 \cdot P(B)$$

By the law of total probability, written in symbols this becomes:

$$\alpha_p = 1 \cdot F_s + \alpha_r \cdot U_s + 0 \cdot B_s = F_s + U_s \, \alpha_r$$
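To make the fusion concrete, here is a minimal numeric sketch (my own illustration, not code from the paper); the tensors are toy 1×1×1×3 probability maps:

import torch

# Toy probability maps with F_s + U_s + B_s = 1 at every pixel.
F_s = torch.tensor([[[[1.0, 0.2, 0.0]]]])      # confident foreground
U_s = torch.tensor([[[[0.0, 0.8, 0.1]]]])      # unknown (boundary) region
alpha_r = torch.tensor([[[[0.5, 0.9, 0.3]]]])  # refined matte from M-Net

# alpha_p = F_s + U_s * alpha_r: keep the segmentation where it is
# confident, and defer to alpha_r inside the unknown band.
alpha_p = F_s + U_s * alpha_r
print(alpha_p)  # tensor([[[[1.0000, 0.9200, 0.0300]]]])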
However, the formula above has a flaw: if there is a large blob of noise outside the target object, the final prediction cannot eliminate it, as illustrated in the figure below:
To remove external noise potentially introduced by the first stage, this post makes a small change when implementing M-Net: the second-stage input becomes a 4-channel image composed of the image to be matted plus the first stage's prediction (in effect, treating the first-stage prediction as the trimap), and the second-stage prediction $\alpha_r$ is used directly as the final prediction.
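Since the modified M-Net consumes the image plus a single trimap-like channel, training needs such a channel for every sample. One common way to obtain it (an assumption here; this post does not show its data pipeline) is to erode and dilate the ground-truth alpha to carve out an unknown band; a minimal sketch with illustrative parameters:

import cv2
import numpy as np

def generate_trimap(alpha, kernel_size=10):
    """Builds a 0 / 0.5 / 1 trimap from a ground-truth alpha matte.

    Args:
        alpha: Float array in [0, 1], shape [H, W].
        kernel_size: Width of the unknown band (illustrative value).
    """
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    fg = (alpha > 0.95).astype(np.uint8)        # confident foreground
    fg_or_edge = (alpha > 0.05).astype(np.uint8)
    fg = cv2.erode(fg, kernel)                  # shrink the foreground
    unknown = cv2.dilate(fg_or_edge, kernel)    # grow across the edge
    trimap = fg.astype(np.float32)              # 1 inside the foreground
    trimap[(unknown == 1) & (fg == 0)] = 0.5    # 0.5 in the unknown band
    return trimap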
1. Model
The model file model.py is as follows:
# -*- coding: utf-8 -*-
"""
Created on Sun Jul 21 07:08:58 2019
@author: shirhe-lyh
Implementation of paper:
Deep Image Matting, Ning Xu, et al., arXiv:1703.03872
"""
import torch
import torchvision as tv
VGG16_BN_MODEL_URL = 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'
VGG16_BN_CONFIGS = {
'13conv':
[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512,
'M', 512, 512, 512],
'10conv':
[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
}
def make_layers(cfg, batch_norm=False):
"""Copy from: torchvision/models/vgg.
    Changes return_indices in MaxPool2d from False to True.
"""
layers = []
in_channels = 3
for v in cfg:
if v == 'M':
layers += [torch.nn.MaxPool2d(kernel_size=2, stride=2,
return_indices=True)]
else:
conv2d = torch.nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
if batch_norm:
layers += [conv2d, torch.nn.BatchNorm2d(v),
torch.nn.ReLU(inplace=True)]
else:
layers += [conv2d, torch.nn.ReLU(inplace=True)]
in_channels = v
return torch.nn.Sequential(*layers)
class VGGFeatureExtractor(torch.nn.Module):
"""Feature extractor by VGG network."""
def __init__(self, config=None, batch_norm=True):
"""Constructor.
Args:
config: The convolutional architecture of VGG network.
batch_norm: A boolean indicating whether the architecture
include Batch Normalization layers or not.
"""
super(VGGFeatureExtractor, self).__init__()
self._config = config
if self._config is None:
self._config = VGG16_BN_CONFIGS.get('10conv')
self.features = make_layers(self._config, batch_norm=batch_norm)
self._indices = None
def forward(self, x):
self._indices = []
for layer in self.features:
if isinstance(layer, torch.nn.modules.pooling.MaxPool2d):
x, indices = layer(x)
self._indices.append(indices)
else:
x = layer(x)
return x
def vgg16_bn_feature_extractor(config=None, pretrained=True, progress=True):
model = VGGFeatureExtractor(config, batch_norm=True)
if pretrained:
        # torchvision.models.utils was removed in newer torchvision
        # releases; torch.hub provides the same loader.
        state_dict = torch.hub.load_state_dict_from_url(
            VGG16_BN_MODEL_URL, progress=progress)
model.load_state_dict(state_dict, strict=False)
return model
class DIM(torch.nn.Module):
"""Deep Image Matting."""
def __init__(self, feature_extractor):
"""Constructor.
Args:
feature_extractor: Feature extractor, such as VGGFeatureExtractor.
"""
super(DIM, self).__init__()
# Head convolution layer, number of channels: 4 -> 3
self._head_conv = torch.nn.Conv2d(in_channels=4, out_channels=3,
kernel_size=5, padding=2)
# Encoder
self._feature_extractor = feature_extractor
self._feature_extract_config = self._feature_extractor._config
# Decoder
self._decode_layers = self.decode_layers()
# Prediction
self._final_conv = torch.nn.Conv2d(self._feature_extract_config[0], 1,
kernel_size=5, padding=2)
self._sigmoid = torch.nn.Sigmoid()
def forward(self, x):
x = self._head_conv(x)
x = self._feature_extractor(x)
indices = self._feature_extractor._indices[::-1]
index = 0
for layer in self._decode_layers:
if isinstance(layer, torch.nn.modules.pooling.MaxUnpool2d):
x = layer(x, indices[index])
index += 1
else:
x = layer(x)
x = self._final_conv(x)
x = self._sigmoid(x)
return x
def decode_layers(self):
layers = []
strides = [1]
channels = []
config_reversed = self._feature_extract_config[::-1]
for i, v in enumerate(config_reversed):
if v == 'M':
strides.append(2)
channels.append(config_reversed[i+1])
channels.append(channels[-1])
in_channels = self._feature_extract_config[-1]
for stride, out_channels in zip(strides, channels):
if stride == 2:
layers += [torch.nn.MaxUnpool2d(kernel_size=2, stride=2)]
layers += [torch.nn.Conv2d(in_channels, out_channels,
kernel_size=5, padding=2),
torch.nn.BatchNorm2d(num_features=out_channels),
torch.nn.ReLU(inplace=True)]
in_channels = out_channels
return torch.nn.Sequential(*layers)
    def loss(self, alphas_pred, alphas_gt, images=None, epsilon=1e-12):
        """Alpha-prediction loss, optionally plus a compositional-style term.

        Args:
            alphas_pred: Predicted alpha mattes, shape [N, 1, H, W].
            alphas_gt: Ground-truth alpha mattes, shape [N, 1, H, W].
            images: Optional RGB images for the compositional loss term.
            epsilon: Small constant keeping the square root differentiable.

        Returns:
            A scalar loss tensor.
        """
losses = torch.sqrt(
torch.mul(alphas_pred - alphas_gt, alphas_pred - alphas_gt) +
epsilon)
loss = torch.mean(losses)
if images is not None:
images_fg_gt = torch.mul(images, alphas_gt)
images_fg_pred = torch.mul(images, alphas_pred)
images_fg_error = images_fg_pred - images_fg_gt
losses_image = torch.sqrt(
torch.mul(images_fg_error, images_fg_error) + epsilon)
loss += torch.mean(losses_image)
return loss
if __name__ == '__main__':
feature_extractor = vgg16_bn_feature_extractor(config=VGG16_BN_CONFIGS['13conv'])
    print(feature_extractor)
dim = DIM(feature_extractor=feature_extractor)
images = torch.rand((2, 4, 320, 320))
alphas = dim(images)
print(alphas.shape)
alphas_gt = torch.rand((2, 1, 320, 320))
loss = dim.loss(alphas, alphas_gt)
print(loss.item())
2. Training
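A minimal training-loop sketch, under stated assumptions: the DIM model comes from model.py above, the dataloader below is a random-tensor stand-in for a real dataset of (4-channel input, ground-truth alpha) pairs, and the Adam learning rate and epoch count are illustrative, not this post's actual settings.

import torch

from model import DIM, VGG16_BN_CONFIGS, vgg16_bn_feature_extractor

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
feature_extractor = vgg16_bn_feature_extractor(
    config=VGG16_BN_CONFIGS['13conv'])
model = DIM(feature_extractor=feature_extractor).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # illustrative

# Stand-in for a real dataset of (4-channel input, alpha) pairs.
dataloader = [(torch.rand(2, 4, 320, 320), torch.rand(2, 1, 320, 320))]
num_epochs = 10  # illustrative

model.train()
for epoch in range(num_epochs):
    for inputs, alphas_gt in dataloader:
        inputs, alphas_gt = inputs.to(device), alphas_gt.to(device)
        alphas_pred = model(inputs)
        # The RGB channels of the input feed the second loss term.
        loss = model.loss(alphas_pred, alphas_gt, images=inputs[:, :3])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()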
[Appendix]
A good resource for getting started with PyTorch quickly: PyTorch Tutorial for Deep Learning Researchers.