
PyTorch Matting Algorithm: Deep Image Matting

2019-07-21  公输睚信

This article implements the second-stage model, M-Net, of the matting algorithm Semantic Human Matting, which is essentially Deep Image Matting. Note that the implementation here departs slightly from the original paper: besides a minor difference in the input layer, the loss function is also simplified (though not in any essential way).

The training data come from the dataset open-sourced by the company 爱分割 (aisegment), which contains 34,427 images with their corresponding alpha channels. The dataset is very large, and releasing it publicly deserves real credit.

The Semantic Human Matting model

Overall, the automatic matting pipeline proposed in the Semantic Human Matting paper is clear and intuitive (see the figure above). Given an image to be matted, a semantic segmentation model (T-Net) first separates it into a foreground F_s, a background B_s, and an unknown region U_s (with F_s + B_s + U_s = 1). The foreground (F_s) plus the unknown region (U_s) is then loosely treated as a trimap, and Deep Image Matting (M-Net) uses it to complete the matting with high quality. The complete model will be implemented step by step over the next few articles; this one only covers the second stage (M-Net).

M-Net takes a 6-channel input: the image to be matted (a 3-channel RGB composite of foreground and background) concatenated with the 3-channel prediction (F_s, B_s, U_s) from the semantic segmentation model. An encoder extracts image features, and a decoder produces the prediction \alpha_r. If the segmentation is reasonably accurate, the regions covered by F_s and B_s already capture most of the foreground and background; the only place where accuracy still needs improving is the boundary region of the target. The purpose of the second stage, M-Net, is therefore to refine the prediction in that boundary region (which is exactly what Deep Image Matting does), and combining the two parts yields the final prediction:
\alpha = F_s + U_s\alpha_r .
This formula can be read as follows:
\textrm{predicted foreground} = \textrm{foreground over the known region} + \textrm{foreground over the unknown region}
that is,
P(\textrm{foreground}) = P(\textrm{foreground over the known region}) + P(\textrm{foreground over the unknown region})
By the law of total probability, in symbols:
\begin{align} \alpha &= P(F) \\ &= P(F|\mathrm{known})P(\mathrm{known}) + P(F|\mathrm{unknown})P(\mathrm{unknown}) \\ &= \frac{F_s}{F_s + B_s}(F_s + B_s) + \alpha_rU_s\\ &= F_s + U_s\alpha_r \end{align}
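In PyTorch this fusion is a single line. Below is a minimal sketch; the tensor names fg_s, unknown_s and alpha_r are illustrative stand-ins for the T-Net and M-Net outputs, each of shape (N, 1, H, W) with values in [0, 1]:

import torch

# Illustrative stand-ins for F_s, U_s (from T-Net) and alpha_r (from M-Net).
fg_s = torch.rand((2, 1, 320, 320))       # F_s: foreground probability
unknown_s = torch.rand((2, 1, 320, 320))  # U_s: unknown-region probability
alpha_r = torch.rand((2, 1, 320, 320))    # alpha_r: refined matte from M-Net

# Final matte: alpha = F_s + U_s * alpha_r
alpha = fg_s + unknown_s * alpha_r
print(alpha.shape)  # torch.Size([2, 1, 320, 320])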

However, this formula has a flaw: if there is a large patch of noise outside the target object, the final prediction cannot remove it, as in the figure below:

The foreground produced by semantic segmentation carries external noise (the small photo to the left of the garment)
To eliminate external noise that the first stage may introduce, this article makes a small change when implementing M-Net: the input to the second stage becomes a 4-channel image made of the image to be matted plus F_s (so F_s effectively plays the role of the trimap), and the prediction of the second stage is taken directly as the final prediction.
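Concretely, the 4-channel input is just the RGB image concatenated with F_s along the channel dimension. A minimal sketch (image and fg_s are illustrative names):

import torch

image = torch.rand((2, 3, 320, 320))  # RGB image to be matted
fg_s = torch.rand((2, 1, 320, 320))   # F_s from the first stage, used as the trimap

# 4-channel M-Net input: concatenate along the channel dimension
inputs = torch.cat([image, fg_s], dim=1)
print(inputs.shape)  # torch.Size([2, 4, 320, 320])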

1. Model

The model file model.py is as follows:

# -*- coding: utf-8 -*-
"""
Created on Sun Jul 21 07:08:58 2019

@author: shirhe-lyh

Implementation of paper:
    Deep Image Matting, Ning Xu, eta., arxiv:1703.03872
"""

import torch
import torchvision as tv

VGG16_BN_MODEL_URL = 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'

VGG16_BN_CONFIGS = {
    '13conv':
        [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 
         'M', 512, 512, 512],
    '10conv':
        [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
    }


def make_layers(cfg, batch_norm=False):
    """Copy from: torchvision/models/vgg.
    
    Changs retrue_indices in MaxPool2d from False to True.
    """
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [torch.nn.MaxPool2d(kernel_size=2, stride=2, 
                                          return_indices=True)]
        else:
            conv2d = torch.nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, torch.nn.BatchNorm2d(v), 
                           torch.nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, torch.nn.ReLU(inplace=True)]
            in_channels = v
    return torch.nn.Sequential(*layers)


class VGGFeatureExtractor(torch.nn.Module):
    """Feature extractor by VGG network."""
    
    def __init__(self, config=None, batch_norm=True):
        """Constructor.
        
        Args:
            config: The convolutional architecture of VGG network.
            batch_norm: A boolean indicating whether the architecture 
                include Batch Normalization layers or not.
        """
        super(VGGFeatureExtractor, self).__init__()
        self._config = config
        if self._config is None:
            self._config = VGG16_BN_CONFIGS.get('10conv')
        self.features = make_layers(self._config, batch_norm=batch_norm)
        self._indices = None
        
    def forward(self, x):
        self._indices = []
        for layer in self.features:
            if isinstance(layer, torch.nn.modules.pooling.MaxPool2d):
                x, indices = layer(x)
                self._indices.append(indices)
            else:
                x = layer(x)
        return x
    
    
def vgg16_bn_feature_extractor(config=None, pretrained=True, progress=True):
    model = VGGFeatureExtractor(config, batch_norm=True)
    if pretrained:
        # Load ImageNet-pretrained VGG16-BN weights; strict=False because the
        # classifier (and any truncated conv layers) are absent from this model.
        state_dict = tv.models.utils.load_state_dict_from_url(
            VGG16_BN_MODEL_URL, progress=progress)
        model.load_state_dict(state_dict, strict=False)
    return model


class DIM(torch.nn.Module):
    """Deep Image Matting."""
    
    def __init__(self, feature_extractor):
        """Constructor.
        
        Args:
            feature_extractor: Feature extractor, such as VGGFeatureExtractor.
        """
        super(DIM, self).__init__()
        # Head convolution layer, number of channels: 4 -> 3
        self._head_conv = torch.nn.Conv2d(in_channels=4, out_channels=3,
                                          kernel_size=5, padding=2)
        # Encoder
        self._feature_extractor = feature_extractor
        self._feature_extract_config = self._feature_extractor._config
        # Decoder
        self._decode_layers = self.decode_layers()
        # Prediction
        self._final_conv = torch.nn.Conv2d(self._feature_extract_config[0], 1,
                                           kernel_size=5, padding=2)
        self._sigmoid = torch.nn.Sigmoid()
        
    def forward(self, x):
        x = self._head_conv(x)
        x = self._feature_extractor(x)
        indices = self._feature_extractor._indices[::-1]
        index = 0
        for layer in self._decode_layers:
            if isinstance(layer, torch.nn.modules.pooling.MaxUnpool2d):
                x = layer(x, indices[index])
                index += 1
            else:
                x = layer(x)
        x = self._final_conv(x)
        x = self._sigmoid(x)
        return x
    
    def decode_layers(self):
        layers = []
        strides = [1]
        channels = []
        config_reversed = self._feature_extract_config[::-1]
        for i, v in enumerate(config_reversed):
            if v == 'M':
                strides.append(2)
                channels.append(config_reversed[i+1])
        channels.append(channels[-1])
        in_channels = self._feature_extract_config[-1]
        for stride, out_channels in zip(strides, channels):
            if stride == 2:
                layers += [torch.nn.MaxUnpool2d(kernel_size=2, stride=2)]
            layers += [torch.nn.Conv2d(in_channels, out_channels,
                                       kernel_size=5, padding=2),
                       torch.nn.BatchNorm2d(num_features=out_channels),
                       torch.nn.ReLU(inplace=True)]
            in_channels = out_channels
        return torch.nn.Sequential(*layers)
    
    def loss(self, alphas_pred, alphas_gt, images=None, epsilon=1e-12):
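        """Simplified matting loss.

        Computes a smooth (epsilon-stabilized) absolute difference between
        predicted and ground-truth alphas; if `images` is given, an analogous
        term on the masked foregrounds (images * alpha) is added.

        Args:
            alphas_pred: Predicted alpha mattes, shape (N, 1, H, W).
            alphas_gt: Ground-truth alpha mattes, shape (N, 1, H, W).
            images: Optional input RGB images, shape (N, 3, H, W).
            epsilon: Small constant for numerical stability inside the sqrt.

        Returns:
            A scalar loss tensor.
        """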
        losses = torch.sqrt(
            torch.mul(alphas_pred - alphas_gt, alphas_pred - alphas_gt) + 
            epsilon)
        loss = torch.mean(losses)
        if images is not None:
            images_fg_gt = torch.mul(images, alphas_gt)
            images_fg_pred = torch.mul(images, alphas_pred)
            images_fg_error = images_fg_pred - images_fg_gt
            losses_image = torch.sqrt(
                torch.mul(images_fg_error, images_fg_error) + epsilon)
            loss += torch.mean(losses_image)
        return loss
 
    
if __name__ == '__main__':
    feature_extractor = vgg16_bn_feature_extractor(config=VGG16_BN_CONFIGS['13conv'])
    print(feature_extractor.modules)
    
    dim = DIM(feature_extractor=feature_extractor)
    
    images = torch.rand((2, 4, 320, 320))
    
    alphas = dim(images)
    print(alphas.shape)
    
    alphas_gt = torch.rand((2, 1, 320, 320))
    loss = dim.loss(alphas, alphas_gt)
    print(loss.item())
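For reference, the loss implemented above (the simplified version mentioned at the beginning of this article) is
\mathcal{L} = \mathrm{mean}\left(\sqrt{(\alpha_p - \alpha_g)^2 + \epsilon}\right) + \mathrm{mean}\left(\sqrt{(I\,\alpha_p - I\,\alpha_g)^2 + \epsilon}\right),
where \alpha_p and \alpha_g are the predicted and ground-truth alphas, I is the input image, and the second term is only included when images is passed to loss.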

2. Training
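A minimal sketch of one training step with the DIM model defined above. The random tensors stand in for a real data pipeline, and the optimizer and its hyperparameters are illustrative assumptions, not settings from this article:

import torch

from model import DIM, VGG16_BN_CONFIGS, vgg16_bn_feature_extractor

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the model: VGG16-BN encoder plus the mirrored decoder
feature_extractor = vgg16_bn_feature_extractor(VGG16_BN_CONFIGS['13conv'])
model = DIM(feature_extractor).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # assumed settings

# Stand-ins for a real dataloader: 4-channel inputs (RGB + F_s) and
# ground-truth alpha mattes, both valued in [0, 1].
inputs = torch.rand((2, 4, 320, 320), device=device)
alphas_gt = torch.rand((2, 1, 320, 320), device=device)

model.train()
optimizer.zero_grad()
alphas_pred = model(inputs)
loss = model.loss(alphas_pred, alphas_gt, images=inputs[:, :3, :, :])
loss.backward()
optimizer.step()
print(loss.item())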

Appendix
A quick-start resource for PyTorch: PyTorch Tutorial for Deep Learning Researchers
