深度学习机器学习与数据挖掘

2020机器学习GAN(3)

2020-02-26  本文已影响0人  zidea

代码讲解

naruto_vs_sasuke.jpg

今天 tensorflow2.0 实现一个简单全连接实现 GAN。我们先回忆在做机器学习一般流程,然后我们根据流程步骤按填空。

  1. 加载数据集
  2. 定义模型
  3. 定义损失函数
  4. 定义优化函数
  5. 训练模型
  6. 预测
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt

import numpy as np
import glob
import os

%matplotlib inline

这里使用

tf.__version__
'2.0.0'

加载数据集

考虑我的小笔记本性能,我只能在比较简单数据集上进行训练模型,这里使用即使 mnist 手写数字数据,这么经典我就不多说了。

(train_data,train_labels),(_,_) = tf.keras.datasets.mnist.load_data()
print(train_data.shape)
print(train_labels.shape)
(60000, 28, 28)
(60000,)
plt.imshow(train_data[0])
<matplotlib.image.AxesImage at 0x14322cd50>
output_7_1.png

数据处理

# 因为计算机对于浮点数据进行计算比较舒服
train_data = train_data.reshape(train_data.shape[0],28,28,1).astype('float32')
# 通过对数据进行缩放来实现归一化
train_data = (train_data - 127.5)/127.5
# plt.imshow(train_data[0])

批次也就是每一个训练迭代输入 256 张图片来完成一次训练

BATCH_SIZE = 256
BUFFER_SIZE = 60000

dataset 是 tensorflow 新的数据 api,很好用,也是tensorflow 2.0 让人眼前一亮新功能

dataset = tf.data.Dataset.from_tensor_slices(train_data)
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
dataset
<BatchDataset shapes: (None, 28, 28, 1), types: tf.float32>

定义生成器(generator)

naruto_generator.jpeg

生成器就是要骗过判别器,生成图片让判别器无法识别出是生成器伪造的,而误认为是真实图片。我们这里用了两层全连接将 100 维度向量转换 784 维,然后通过 Reshape 转换图片格式(28,28,1)

def build_generator():
    model = tf.keras.Sequential()
    # 随机向量使用 100 维度向量     
    model.add(layers.Dense(256,input_shape=(100,),use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    
    model.add(layers.Dense(512,use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    
    model.add(layers.Dense(28*28*1,use_bias=False,activation='tanh'))
    model.add(layers.BatchNormalization())
    
    model.add(layers.Reshape((28,28,1)))
    
    return model
    
    

定义判别模型

<img src="images/sasuke_discriminator.png" width="50%"/>


sakuke_discriminator.png

判别器模型主要目的就是从真实图片辨别出生成器图片,也可以看为教练,不断督促判别器做的更好,开始我们用 Flatten 将图片展平,最终输出 1 表示 0 到 1。

def build_discriminator():
    model = tf.keras.Sequential()
    # 首先我们将图片进行 flatten
    model.add(layers.Flatten())
    
    model.add(layers.Dense(512,use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    
    model.add(layers.Dense(256,use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    # 规范到 0 - 1 ,小于 0.5 我们,     
    model.add(layers.Dense(1))
    
    return model
    

超参数

EPOCHS = 100
noise_dim = 100
learning_rate = 1e-4
num_exp_to_generate = 16
seed = tf.random.normal([num_exp_to_generate,noise_dim])
# 生成图片是否为真实图片
# 因为没有激活所以将 from_logits True
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

定损失函数

在判别器损失函数,当输入真实数据集图片时,我们给出 1 ,而输入为生成器的生成图时给出 0

def discriminator_loss(real_out,fake_out):
    real_loss = cross_entropy(tf.ones_like(real_out),real_out)
    fake_loss = cross_entropy(tf.zeros_like(fake_out),fake_out)
    
    return real_loss + fake_loss
在生成器损失函数
def generator_loss(fake_out):
    return cross_entropy(tf.ones_like(fake_out),fake_out)

优化器

evaluation.jpeg

优化器对于判别器和生成器都使用 Adam 。

generator_optimizer = tf.keras.optimizers.Adam(learning_rate)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate)
discriminator = build_discriminator()
generator = build_generator()
上一篇 下一篇

猜你喜欢

热点阅读