工具癖Pytorch与深度学习个人专题

Pytorch实践(二)——老旧照片恢复器——图片AI自动上色(

2018-07-29  本文已影响16人  dalalaa

图片自动上色的原理很简单,下面我们边做边讲

首先导入必要的工具

import torch as t
from PIL import Image
import numpy as np 
import matplotlib.pyplot as plt 
from skimage import color
from skimage.io import imshow
from tqdm import tqdm_notebook
from torchvision import transforms
%matplotlib inline

加载彩色图片,这里选择了Dota2里面的圣堂刺客TA的一张宣传画,先看看彩色模式

img_rgb = Image.open("H:/COLOURING/TA.jpg").resize((256,256))
img_rgb = np.array(img_rgb)
plt.imshow(img_rgb),img_rgb.shape
彩色TA

然后是黑白模式

img_gray = np.array(Image.open("H:/COLOURING/TA.jpg").convert('L').resize((256,256)))
plt.imshow(img_gray,cmap = 'gray')
黑白TA

在计算机中,彩色图片通常以RGB模式显示(在opencv中是BGR形式),有三个通道,即图片是由三个像素矩阵叠加而成。

而黑白模式的图片只有一个通道,即只有一个像素矩阵。

出RGB模式外,还有很多图片显示模式,比如本文中要用到的lab模式,lab模式中同样有三个通道,第一个通道l是亮度通道,用来表示图片亮度,其效果与黑白图片非常相似。下面是以灰度模式展示的l通道

img_lab = color.rgb2lab(img_rgb/255)
img_lab_l = img_lab[:,:,0]
plt.imshow(img_lab_l,cmap = 'gray')
亮度TA

l通道展示效果与灰度图像无异,另外两个通道a和b是两个色彩通道,下面我们把两个色彩通道单独拎出来看看:

首先是a通道:

img_lab_a = img_lab[:,:,1]
plt.imshow(img_lab_a,cmap = 'gray') # matplotlib没有专门绘制ab通道的cmap,所以这里只是个示意图,真实色彩不是这样的。
a通道TA

很明显,色彩通道里面看不到图像的线条信息,下面再看一下b通道:

img_lab_b = img_lab[:,:,2]
plt.imshow(img_lab_b,cmap = 'gray') # matplotlib没有专门绘制ab通道的cmap,所以这里只是个示意图,真实色彩不是这样的。
b通道TA

b通道里面没了眼影的TA是真的丑~~

上色原理

介绍到这里,自动上色的原理已经很明朗了,就是以亮度层为data,ab层作为target,建立一个从亮度图像到色彩层的映射。

注意,这里的黑白图片指的是亮度层,而不是灰度图片

搭建神经网络

首先需要建立一个神经网络,这个网络的输入是图片的l层,输出是图片的ab层。

class Net(t.nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = t.nn.Sequential(
            t.nn.Conv2d(1,16,3,stride=2,padding=1),
            t.nn.BatchNorm2d(16),
            t.nn.ReLU(),
            t.nn.Upsample(scale_factor=2)
        )
        self.conv2 = t.nn.Sequential(
            t.nn.Conv2d(16,32,3,2,1),
            t.nn.BatchNorm2d(32),
            t.nn.ReLU(),
            t.nn.Upsample(scale_factor=2)
        )
        self.conv3 = t.nn.Sequential(
            t.nn.Conv2d(32,16,3,2,1),
            t.nn.BatchNorm2d(16),
            t.nn.ReLU(),
            t.nn.Upsample(scale_factor=2)
        )
        self.conv4 = t.nn.Sequential(
            t.nn.Conv2d(16,2,3,2,1),
            t.nn.BatchNorm2d(2),
            t.nn.ReLU(),
            t.nn.Upsample(scale_factor=2)
        )
        
    def forward(self,x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        return x

处理数据

img_gray = img_gray[:,:,np.newaxis]
img_lab_l = img_lab_l[:,:,np.newaxis]
img_gray.shape,img_lab_l.shape
((256, 256, 1), (256, 256, 1))
x_train = img_lab_l
y_train = img_lab[:,:,1:3]
y_train /= 128
transform = transforms.Compose([
    transforms.ToTensor(),
])

PIL中image对象是(H,W,C)形状,而Pytorch中的图像tensor是(C,H,W)形状,需要进行转换

x_train,y_train = transform(x_train),transform(y_train)
x_train,y_train = x_train.float(),y_train.float()
x_train,y_train = x_train.view(-1,1,256,256),y_train.view(-1,2,256,256)
x_train.shape,y_train.shape
(torch.Size([1, 1, 256, 256]), torch.Size([1, 2, 256, 256]))

训练模型

net = Net()
net
Net(
  (conv1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Upsample(scale_factor=2, mode=nearest)
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Upsample(scale_factor=2, mode=nearest)
  )
  (conv3): Sequential(
    (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Upsample(scale_factor=2, mode=nearest)
  )
  (conv4): Sequential(
    (0): Conv2d(16, 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Upsample(scale_factor=2, mode=nearest)
  )
)
EPOCHS = 500
LR = 0.01
criterion = t.nn.MSELoss()
optimizer = t.optim.Adam(net.parameters(),lr=LR,weight_decay=0.0)
for epoch in tqdm_notebook(range(EPOCHS)):
    index=0
    if epoch % 100 == 0:
        for param_group in optimizer.param_groups:
            LR = LR * 0.9
            param_group['lr'] = LR
    prediction = net.forward(x_train)
    loss = criterion(prediction,y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
loss

还原并显示图像

net.eval()
prediction = net.forward(x_train)
prediction *= 128
prediction = prediction[0].data.numpy()
x_train = x_train[0].data.numpy()
x_train.shape,prediction.shape
result = np.zeros((256,256,3))
result[:,:,0] = x_train[0]
result[:,:,1] = prediction[0]
result[:,:,2] = prediction[1]
result_rgb = color.lab2rgb(result)
plt.imshow(np.array(result_rgb))
上色TA

500个EPOCHS后,图片已经有点样子了,迭代更多次数之后就能够达到原图的效果了。

需要源代码的可以私信我~

对机器学习感兴趣的朋友可以加群:

机器学习-菜鸡互啄
上一篇 下一篇

猜你喜欢

热点阅读