软件

TensorFlow vs PyTorch 7: 创建模型

2021-11-07  本文已影响0人  LabVIEW_Python

《TensorFlow vs PyTorch 6: ETL,Extract, Transform, and Load》一文中,我们将存储在硬盘中的图像文件转化成了适合训练模型的批次张量数据,本文我们将讲述如何创建模型。

PyTorch创建模型。PyTorch支持nn.Module子类和Sequential顺序方式创建神经网络模型,在上文中,拿到的批次张量数据的形状是:

train_batch_data shape: torch.Size([64, 1, 28, 28])

由此构建输入为[64,1,28,28]的卷积神经网络模型,完整代码如下所示:

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision import transforms 
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt 
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" # Solve the OMP: Error #15

train_dataset = datasets.FashionMNIST(root='data',train=True, download=True, transform=ToTensor())
test_dataset = datasets.FashionMNIST(root='data',train=False, download=True)

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)

train_batch_data, train_batch_labels = next(iter(train_dataloader))
print(f"train_batch_data shape: {train_batch_data.size()}")
print(f"train_batch_labels shape: {train_batch_labels.size()}")
print(train_batch_labels[0])

import torch.nn as nn
import torch.nn.functional as F
class MyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1,6,3)     # 28->26
        self.maxpool = nn.MaxPool2d(2,2)  # 26->13
        self.conv2 = nn.Conv2d(6,16,3)    # 13->11
        self.fc1 = nn.Linear(16*11*11,128)
        self.fc2 = nn.Linear(128,10)
    
    def forward(self, x):
        x = self.maxpool(F.relu(self.conv1(x)))
        x = F.relu(self.conv2(x))
        x = torch.flatten(x, 1) # start dim=0是batch size, 所以从1开始flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = self.fc2(x)         # 分类输出不需要激活
        return x 

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MyCNN().to(device)
print(model)

X = torch.rand(64, 1, 28, 28, device=device)
logits = model(X)
print(logits.shape)

运行结果:

train_batch_data shape: torch.Size([64, 1, 28, 28])
train_batch_labels shape: torch.Size([64])
tensor(5)
MyCNN(
(conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
(maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
(fc1): Linear(in_features=1936, out_features=128, bias=True)
(fc2): Linear(in_features=128, out_features=10, bias=True)
)
torch.Size([64, 10])

上述代码每层的使用方法,参见《PyTorch神经网络层拆解》

TensorFlow.Keras创建模型。支持三种方式,循序式、函数式和子类式,其中循序式和函数式在Keras中最好用,也最常见;子类式,个人感觉还不成熟,例如:不能很好的使用summary()方法打印出模型的结构。为了跟PyTorch的子类式比较,本例给出TensorFlow.Keras的子类式模型创建方法。

上文中,拿到的tf.dataset批次张量数据的形状是:

(64, 28, 28, 1), 即NHWC

完整范例代码:

import tensorflow as tf
from tensorflow.keras import layers

inputs = tf.random.normal([64,28,28,1]) #The Conv2D op currently only supports the NHWC tensor format on the CPU

class MyCNN(tf.keras.Model):
    def __init__(self,num_classes=10):
        super().__init__()
        self.conv1 = layers.Conv2D(filters=6, kernel_size=3, activation='relu')     # 28->26
        self.maxpool = layers.MaxPool2D(pool_size=(2,2))                            # 26->13
        self.conv2 = layers.Conv2D(filters=16, kernel_size=3, activation='relu')                       # 13->11
        self.flatten = layers.Flatten()
        self.fc1 = layers.Dense(128,activation='relu')
        self.fc2 = layers.Dense(num_classes)
    
    def call(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.flatten(x)
        x = self.fc1(x)         
        x = self.fc2(x) # 分类输出不需要激活
        return x 

model = MyCNN()
# https://stackoverflow.com/questions/64681232/why-is-it-that-input-shape-does-not-include-the-batch-dimension-when-passed-as
# build() https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#build
model.build(input_shape=[None,28,28,1])
model.summary()
logits = model(inputs)
print(f"logits'shape:{logits.shape}")
运行结果

上述代码每层的使用方法,参见《TensorFlow神经网络层拆解》(https://www.jianshu.com/p/8db8d36e7fc3)

总结

  1. TensorFlow.Keras的子类式实现方式有诸多人诟病,最好还是用函数式API创建模型:https://github.com/tensorflow/tensorflow/issues/25036
上一篇下一篇

猜你喜欢

热点阅读