PyTorch Deep Learning in Practice 19 - Semantic Segmentation (U-Net)
2023-04-01
薛东弗斯
[学习笔记20:图像语义分割 - pbc的成长之路 - 博客园](https://www.cnblogs.com/miraclepbc/p/14385632.html)
[U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597)
In the original paper, the conv layers are all 3x3 with 'valid' padding (no padding), so feature maps shrink at every conv.
The network has 4 downsampling stages and 4 upsampling stages.
For any model, the two most critical design points are: 1) the architecture; 2) the loss function.
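Because the paper's convolutions are unpadded, every 3x3 conv shrinks the feature map, which is why the paper's output (388x388) is smaller than its input (572x572). A minimal sketch of the size arithmetic (the implementation below uses padding=1 instead, which preserves size):

```python
# Output size of a conv layer: out = (in + 2*padding - kernel) // stride + 1
def conv_out(size, kernel=3, padding=0, stride=1):
    return (size + 2 * padding - kernel) // stride + 1

print(conv_out(572))             # 570: 'valid' padding shrinks each dim by kernel-1
print(conv_out(572, padding=1))  # 572: padding=1 preserves the size, as used below
```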
![](https://img.haomeiwen.com/i3968643/1f9c8fbe80c39163.png)
![](https://img.haomeiwen.com/i3968643/33fb32781cf95d5a.png)
![](https://img.haomeiwen.com/i3968643/12e4ab4c16ae9232.png)
![](https://img.haomeiwen.com/i3968643/825d162166c05f3c.png)
![](https://img.haomeiwen.com/i3968643/3f6d689e484bcb46.png)
![](https://img.haomeiwen.com/i3968643/7910a3fd45b63554.png)
![](https://img.haomeiwen.com/i3968643/a66b0168059c46c3.png)
![](https://img.haomeiwen.com/i3968643/c6b9b659f99f957a.png)
![](https://img.haomeiwen.com/i3968643/f1d22fa6f5704fa1.png)
![](https://img.haomeiwen.com/i3968643/05c73e254c72d6cb.png)
![](https://img.haomeiwen.com/i3968643/86d01d872e1b8f80.png)
![](https://img.haomeiwen.com/i3968643/0c8e366f4c39ac85.png)
![](https://img.haomeiwen.com/i3968643/4a7ddd4b6021fbc0.png)
Semantic segmentation: label which class every pixel belongs to; different classes are distinguished by different colors.
![](https://img.haomeiwen.com/i3968643/bc067785d7fe613e.png)
![](https://img.haomeiwen.com/i3968643/94a24980b826b1d7.png)
![](https://img.haomeiwen.com/i3968643/684f49a826cdf84d.png)
![](https://img.haomeiwen.com/i3968643/627dd4b57e72b9ad.png)
![](https://img.haomeiwen.com/i3968643/98d72e9f91b4d2fa.png)
![](https://img.haomeiwen.com/i3968643/d15e504bafa3081d.png)
![](https://img.haomeiwen.com/i3968643/c6377cbfb2873bd5.png)
![](https://img.haomeiwen.com/i3968643/90c54ae56506eae2.png)
Chinese University of Hong Kong portrait dataset (CSDN blog by Frist2018)
[Deep Automatic Portrait Matting](https://www.cse.cuhk.edu.hk/~leojia/projects/automatting/papers/deepmatting.pdf)
Image semantic segmentation here: each pixel is either foreground or background, i.e. a per-pixel binary classification problem.
![](https://img.haomeiwen.com/i3968643/edbba7d84d602044.png)
For n foreground classes, the output has n+1 channels (the extra one for background).
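A minimal shape sketch (illustrative values only, not the tutorial's code) of how `nn.CrossEntropyLoss` consumes such an output: logits of shape `[N, n+1, H, W]` against an integer target of shape `[N, H, W]`.

```python
import torch
import torch.nn as nn

n = 1                                            # here: 1 foreground class (portrait) + background
logits = torch.randn(8, n + 1, 256, 256)         # [batch, n+1, H, W]
target = torch.randint(0, n + 1, (8, 256, 256))  # [batch, H, W], integer class ids
loss = nn.CrossEntropyLoss()(logits, target)     # per-pixel cross entropy
print(loss.item())
```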
![](https://img.haomeiwen.com/i3968643/7fd0a92a71cc83a0.png)
![](https://img.haomeiwen.com/i3968643/37767bedb9b16981.png)
![](https://img.haomeiwen.com/i3968643/a9e0dd10741f8e9b.png)
![](https://img.haomeiwen.com/i3968643/4572bc6e6acae016.png)
![](https://img.haomeiwen.com/i3968643/685578727c654d1e.png)
![](https://img.haomeiwen.com/i3968643/ba59902222f68655.png)
![](https://img.haomeiwen.com/i3968643/b8e6aaff23cc23ad.png)
![](https://img.haomeiwen.com/i3968643/fa9be1c548bc87c8.png)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import torchvision
from torchvision import transforms
import os
import glob
from PIL import Image
BATCH_SIZE = 8
# Plot an original image
# pil_img = Image.open('./data/hk/training/00001.png')
# np_img = np.array(pil_img)
# plt.imshow(np_img)
# plt.show()
# Plot the corresponding matte (segmentation annotation)
# pil_img = Image.open('./data/hk/training/00001_matte.png')
# np_img = np.array(pil_img)
# plt.imshow(np_img)
# plt.show()
# np_img.max(), np_img.min() # (255, 0)
# np_img.shape # (800, 600)
# np.unique(np_img) # array([ 0, ..., 255]) pixel values lie between 0 and 255; this is not yet a binary 0/1 mask
# Plot the matte binarized to 0/1
# pil_img = Image.open('./data/hk/training/00001_matte.png')
# np_img = np.array(pil_img)
# np_img[np_img>0]=1
# plt.imshow(np_img)
# plt.show()
# np.unique(np_img) # array([0, 1], dtype=uint8) now only 0 and 1 remain; this thresholding discards some of the soft matte detail
all_pics = glob.glob('./data/hk/training/*.png')
# all_pics[:5]
# ['./data/hk/training\\00001.png',
# './data/hk/training\\00001_matte.png',
# './data/hk/training\\00002.png',
# './data/hk/training\\00002_matte.png',
# './data/hk/training\\00003.png']
images = [p for p in all_pics if 'matte' not in p]
# len(images) # 1700
annotations = [p for p in all_pics if 'matte' in p]
# len(annotations) # 1700
np.random.seed(2021)
index = np.random.permutation(len(images))
images = np.array(images)[index]
anno = np.array(annotations)[index]
all_test_pics = glob.glob('./data/hk/testing/*.png')
test_images = [p for p in all_test_pics if 'matte' not in p]
test_anno = [p for p in all_test_pics if 'matte' in p]
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])
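One caveat worth flagging: `transforms.Resize` defaults to bilinear interpolation, which introduces intermediate values along mask edges; the `> 0` thresholding in the dataset below absorbs this. A common alternative (an assumption on my part, not what this walkthrough does; requires torchvision >= 0.9) is nearest-neighbor resizing for the annotation only:

```python
from torchvision import transforms

# Hypothetical mask-only transform: nearest-neighbor avoids introducing
# new intermediate pixel values at mask edges during resizing.
anno_transform = transforms.Compose([
    transforms.Resize((256, 256),
                      interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])
```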
class Portrait_dataset(data.Dataset):
    def __init__(self, img_paths, anno_paths):  # takes the image paths and the matching matte (annotation) paths
        self.imgs = img_paths
        self.annos = anno_paths

    def __getitem__(self, index):  # indexing
        img = self.imgs[index]
        anno = self.annos[index]
        pil_img = Image.open(img)
        img_tensor = transform(pil_img)  # convert the original image to a tensor via transform
        pil_anno = Image.open(anno)
        anno_tensor = transform(pil_anno)
        # after transform the matte is 1x256x256 (one channel); squeeze drops the channel dim
        anno_tensor = torch.squeeze(anno_tensor).type(torch.long)
        anno_tensor[anno_tensor > 0] = 1  # binarize: any positive value becomes 1
        return img_tensor, anno_tensor

    def __len__(self):
        return len(self.imgs)
train_dataset = Portrait_dataset(images, anno)
test_dataset = Portrait_dataset(test_images, test_anno)
train_dl = data.DataLoader(train_dataset,
                           batch_size=BATCH_SIZE,
                           shuffle=True,  # shuffle only the training data
                           )
test_dl = data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
)
imgs_batch, annos_batch = next(iter(train_dl))  # fetch one batch
# imgs_batch.shape  # torch.Size([8, 3, 256, 256]): batch=8, 3 channels, 256x256
# annos_batch.shape # torch.Size([8, 256, 256]): batch=8, 256x256; the size-1 channel dim was removed by torch.squeeze
img = imgs_batch[0].permute(1, 2, 0).numpy()  # plot the first image; permute moves the channel dim last
anno = annos_batch[0].numpy()  # the annotation has no channel dim, so no permute is needed
plt.subplot(1, 2, 1)  # 1 row, 2 columns, first plot
plt.imshow(img)
plt.subplot(1, 2, 2)  # 1 row, 2 columns, second plot
plt.imshow(anno)
class Downsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Downsample, self).__init__()
        self.conv_relu = nn.Sequential(  # two conv layers
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=3, padding=1),  # with a 3x3 kernel, padding=1 preserves the spatial size
            nn.ReLU(inplace=True),  # the paper uses no padding; padding here keeps sizes aligned for the later concat
            nn.Conv2d(out_channels, out_channels,
                      kernel_size=3, padding=1),  # an unpadded conv would shrink H and W each by k-1 (k = kernel size)
            nn.ReLU(inplace=True)
        )
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x, is_pool=True):
        if is_pool:
            x = self.pool(x)
        x = self.conv_relu(x)
        return x
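A quick shape check of the block above (a sketch; the shapes follow from the padding=1 convs and the 2x2 max pooling):

```python
import torch

block = Downsample(3, 64)
x = torch.randn(1, 3, 256, 256)
print(block(x, is_pool=False).shape)  # torch.Size([1, 64, 256, 256]) - convs preserve size
print(block(x).shape)                 # torch.Size([1, 64, 128, 128]) - pooling halves H and W
```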
class Upsample(nn.Module):  # the upsampling block performs three ops: conv + conv + transposed-conv upsampling
    def __init__(self, channels):  # each upsampling halves the channel count
        super(Upsample, self).__init__()
        self.conv_relu = nn.Sequential(
            nn.Conv2d(2 * channels, channels,  # input 2*channels (after concat), output channels
                      kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels,  # second conv keeps the channel count
                      kernel_size=3, padding=1),
            nn.ReLU(inplace=True)  # inplace=True reuses the input buffer; otherwise a new buffer is allocated for the intermediate result
        )
        self.upconv_relu = nn.Sequential(
            nn.ConvTranspose2d(channels,  # transposed convolution (deconvolution)
                               channels // 2,  # upsampling halves the output channel count
                               kernel_size=3,
                               stride=2,
                               padding=1,  # in ConvTranspose2d, padding relates to input positions; its meaning differs from Conv2d padding
                               output_padding=1),  # pads the outer edge of the output, analogous to Conv2d's padding
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv_relu(x)
        x = self.upconv_relu(x)
        return x
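Why these `ConvTranspose2d` settings exactly double the spatial size: PyTorch computes `out = (in - 1) * stride - 2 * padding + kernel_size + output_padding`. A small check (sketch):

```python
import torch

in_size = 32
out_size = (in_size - 1) * 2 - 2 * 1 + 3 + 1   # = 64, i.e. exactly 2 * in_size
print(out_size)

up = Upsample(512)                             # expects 2*512 = 1024 input channels
print(up(torch.randn(1, 1024, 32, 32)).shape)  # torch.Size([1, 256, 64, 64])
```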
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.down1 = Downsample(3, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.down4 = Downsample(256, 512)
        self.down5 = Downsample(512, 1024)
        self.up = nn.Sequential(
            nn.ConvTranspose2d(1024,
                               512,
                               kernel_size=3,
                               stride=2,  # doubles the spatial size
                               padding=1,
                               output_padding=1),
            nn.ReLU(inplace=True)
        )
        self.up1 = Upsample(512)
        self.up2 = Upsample(256)
        self.up3 = Upsample(128)
        self.conv_2 = Downsample(128, 64)  # two conv layers (Downsample reused without pooling)
        self.last = nn.Conv2d(64, 2, kernel_size=1)  # output layer: 64 channels in, 2 classes out

    def forward(self, x):
        x1 = self.down1(x, is_pool=False)  # first of 5 downsampling blocks; no pooling here
        x2 = self.down2(x1)  # x2 downsamples x1
        x3 = self.down3(x2)  # x3 downsamples x2
        x4 = self.down4(x3)  # x4 downsamples x3
        x5 = self.down5(x4)  # x5 downsamples x4
        x5 = self.up(x5)
        # Reusing intermediate features via skip connections improves segmentation quality
        x5 = torch.cat([x4, x5], dim=1)  # 32x32x1024: concat x4 with x5 along the channel dim
        x5 = self.up1(x5)                # 64x64x256
        x5 = torch.cat([x3, x5], dim=1)  # 64x64x512: concat x3 with x5 along the channel dim
        x5 = self.up2(x5)                # 128x128x128
        x5 = torch.cat([x2, x5], dim=1)  # 128x128x256: concat x2 with x5 along the channel dim
        x5 = self.up3(x5)                # 256x256x64
        x5 = torch.cat([x1, x5], dim=1)  # 256x256x128: concat x1 with x5 along the channel dim
        x5 = self.conv_2(x5, is_pool=False)  # 256x256x64
        x5 = self.last(x5)               # 256x256x2, the output layer
        return x5
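A forward-pass sanity check of the assembled network (sketch): the output should carry 2 logits per pixel at the input resolution.

```python
import torch

net = Net()
out = net(torch.randn(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 2, 256, 256])
```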
model = Net()
if torch.cuda.is_available():
model.to('cuda')
loss_fn = nn.CrossEntropyLoss()
from torch.optim import lr_scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
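With `step_size=7` and `gamma=0.1`, `StepLR` multiplies the learning rate by 0.1 every 7 epochs; the schedule is simply:

```python
# lr(epoch) = 0.001 * 0.1 ** (epoch // 7)
for e in [0, 6, 7, 13, 14, 21]:
    print(e, 0.001 * 0.1 ** (e // 7))  # 1e-3 for epochs 0-6, 1e-4 for 7-13, ...
```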
def fit(epoch, model, trainloader, testloader):
    correct = 0
    total = 0
    running_loss = 0
    model.train()
    for x, y in trainloader:
        if torch.cuda.is_available():
            x, y = x.to('cuda'), y.to('cuda')
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            y_pred = torch.argmax(y_pred, dim=1)   # per-pixel class prediction
            correct += (y_pred == y).sum().item()  # counts correctly classified pixels
            total += y.size(0)
        running_loss += loss.item()
    exp_lr_scheduler.step()
    epoch_loss = running_loss / len(trainloader.dataset)
    epoch_acc = correct / (total * 256 * 256)  # correct pixels / total pixels (256x256 per image)
    test_correct = 0
    test_total = 0
    test_running_loss = 0
    model.eval()
    with torch.no_grad():
        for x, y in testloader:
            if torch.cuda.is_available():
                x, y = x.to('cuda'), y.to('cuda')
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            y_pred = torch.argmax(y_pred, dim=1)
            test_correct += (y_pred == y).sum().item()
            test_total += y.size(0)
            test_running_loss += loss.item()
    epoch_test_loss = test_running_loss / len(testloader.dataset)
    epoch_test_acc = test_correct / (test_total * 256 * 256)
    print('epoch: ', epoch,
          'loss: ', round(epoch_loss, 3),
          'accuracy:', round(epoch_acc, 3),
          'test_loss: ', round(epoch_test_loss, 3),
          'test_accuracy:', round(epoch_test_acc, 3)
          )
    return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc
epochs = 30
train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch,
                                                                 model,
                                                                 train_dl,
                                                                 test_dl)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)
# Save the model weights
PATH = 'unet_model.pth'
torch.save(model.state_dict(), PATH)

# Test the model: rebuild the network and load the saved weights
my_model = Net()
my_model.load_state_dict(torch.load(PATH))
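If training happened on a GPU but the checkpoint is later loaded on a CPU-only machine, `torch.load` needs a `map_location` argument (a deployment assumption, not part of the original run):

```python
# Hypothetical CPU-only loading of the same checkpoint:
my_model.load_state_dict(torch.load(PATH, map_location='cpu'))
```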
num = 3  # visualize 3 test images
image, mask = next(iter(test_dl))
pred_mask = my_model(image)
plt.figure(figsize=(10, 10))
for i in range(num):
    plt.subplot(num, 3, i*3+1)  # row i, column 1: the original image (the column stride is 3, the number of columns)
    plt.imshow(image[i].permute(1, 2, 0).cpu().numpy())
    plt.subplot(num, 3, i*3+2)  # row i, column 2: the ground-truth mask
    plt.imshow(mask[i].cpu().numpy())
    plt.subplot(num, 3, i*3+3)  # row i, column 3: the predicted mask
    plt.imshow(torch.argmax(pred_mask[i].permute(1, 2, 0), axis=-1).detach().numpy())  # detach from the graph before converting to numpy
# Repeat the check on the training set
image, mask = next(iter(train_dl))
pred_mask = my_model(image)
plt.figure(figsize=(10, 10))
for i in range(num):
    plt.subplot(num, 3, i*3+1)
    plt.imshow(image[i].permute(1, 2, 0).cpu().numpy())
    plt.subplot(num, 3, i*3+2)
    plt.imshow(mask[i].cpu().numpy())
    plt.subplot(num, 3, i*3+3)
    plt.imshow(torch.argmax(pred_mask[i].permute(1, 2, 0), axis=-1).detach().numpy())
model  # evaluating the model in a notebook cell prints the full architecture:
Net(
  (down1): Downsample(
    (conv_relu): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
    )
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (down2): Downsample(
    (conv_relu): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
    )
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (down3): Downsample(
    (conv_relu): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
    )
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (down4): Downsample(
    (conv_relu): Sequential(
      (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
    )
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (down5): Downsample(
    (conv_relu): Sequential(
      (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
    )
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (up): Sequential(
    (0): ConvTranspose2d(1024, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (1): ReLU(inplace=True)
  )
  (up1): Upsample(
    (conv_relu): Sequential(
      (0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
    )
    (upconv_relu): Sequential(
      (0): ConvTranspose2d(512, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
      (1): ReLU(inplace=True)
    )
  )
  (up2): Upsample(
    (conv_relu): Sequential(
      (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
    )
    (upconv_relu): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
      (1): ReLU(inplace=True)
    )
  )
  (up3): Upsample(
    (conv_relu): Sequential(
      (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
    )
    (upconv_relu): Sequential(
      (0): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
      (1): ReLU(inplace=True)
    )
  )
  (conv_2): Downsample(
    (conv_relu): Sequential(
      (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
    )
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (last): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)