Generating CIFAR10 Images with a GAN

2020-03-07  dataengineer

Key Points

Code

# load required libraries

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.utils import plot_model
import numpy as np
tf.__version__

'2.0.0'

# load CIFAR10 datasets

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# print the shapes of training and test data

x_train.shape, y_train.shape, x_test.shape, y_test.shape

((50000, 32, 32, 3), (50000, 1), (10000, 32, 32, 3), (10000, 1))
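
The integer labels follow the standard CIFAR10 class order (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck). The labels are not needed for GAN training, but a quick look confirms the data loaded as expected; the check below is purely illustrative.

# CIFAR10 class names in label order (0-9); labels are not used for GAN training
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
print(class_names[y_train[0][0]])  # class of the first training image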

# plot training data

fig = plt.gcf()
fig.set_size_inches(10,10)
for i in range(49):
    plt.subplot(7,7,1+i)
    plt.imshow(x_train[i])
[Figure: 7x7 grid of the first 49 CIFAR10 training images]
# define the standalone discriminator model

def discriminator_model():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(filters = 64, kernel_size = (3,3), padding = 'same', input_shape = (32,32,3)),
        tf.keras.layers.LeakyReLU(alpha = 0.2),
        tf.keras.layers.Conv2D(filters = 128, kernel_size = (3,3), padding = 'same', strides = (2,2)),
        tf.keras.layers.LeakyReLU(alpha = 0.2),
        tf.keras.layers.Conv2D(filters = 128, kernel_size = (3,3), padding = 'same', strides = (2,2)),
        tf.keras.layers.LeakyReLU(alpha = 0.2),
        tf.keras.layers.Conv2D(filters = 256, kernel_size = (3,3), padding = 'same', strides = (2,2)),
        tf.keras.layers.LeakyReLU(alpha = 0.2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dropout(0.4),
        tf.keras.layers.Dense(units = 1, activation = 'sigmoid')
    ])
    model.compile(loss = 'binary_crossentropy', 
              optimizer = tf.keras.optimizers.Adam(lr = 0.0002, beta_1 = 0.5), metrics = ['accuracy'])
    
    return model

Tip: the discriminator has no pooling layers; it downsamples with 2×2 strides instead, which has a similar effect to pooling.
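
As an aside, the minimal sketch below (not part of the original model) shows that a 2×2-strided convolution and a convolution followed by max pooling both halve the 32×32 feature map to 16×16; the strided version simply learns its own downsampling.

# illustrative comparison: strided convolution vs. convolution + pooling
strided = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(64, (3,3), strides = (2,2), padding = 'same', input_shape = (32,32,3))
])
pooled = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(64, (3,3), padding = 'same', input_shape = (32,32,3)),
    tf.keras.layers.MaxPooling2D(pool_size = (2,2))
])
print(strided.output_shape, pooled.output_shape)  # both (None, 16, 16, 64)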

# show the summary and graph of the discriminator model 

model = discriminator_model()

model.summary()
[Figure: discriminator model summary]
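
The model graph can be reproduced with the plot_model utility imported earlier; this sketch assumes pydot and graphviz are installed.

# save the discriminator architecture diagram to a PNG file
plot_model(model, to_file = 'discriminator_model.png', show_shapes = True, show_layer_names = True)
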
# convert unsigned int to float32

x_train = x_train.astype('float32')
x_train = (x_train - 127.5)/127.5

Tip: the generator uses tanh as its output activation, so generated pixel values fall in [-1, 1]; the real images must therefore also be rescaled from [0, 255] to [-1, 1].
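
A quick range check (illustrative only) confirms the scaling matches the tanh output range of the generator.

# verify the pixel range after normalization: approximately -1.0 and 1.0
print(x_train.min(), x_train.max())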

# generate points in latent space as the inputs of the generator

def generate_latent_points(latent_dim,n_samples):
    x_input = np.random.randn(latent_dim * n_samples)
    x_input = x_input.reshape(n_samples,latent_dim)
    return x_input    
# randomly select n real samples

def generate_real_samples(dataset, n_samples):
    # define random instances
    ix = np.random.randint(0, dataset.shape[0], n_samples)
    # retrieve selected images
    x = dataset[ix]
    # generate class label (label = 1)
    y = np.ones((n_samples,1))
    return x,y

# generate n fake samples with class label

def generate_fake_samples(g_model, latent_dim, n_samples):
    # generate points in latent space
    x_input = generate_latent_points(latent_dim, n_samples)
    # predict outputs
    x = g_model.predict(x_input)
    # generate class label (label = 0)
    y = np.zeros((n_samples,1))
    return x,y
# define the standalone generator model

def generator_model(latent_dim):
    n_nodes = 256*4*4
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(units = n_nodes, input_dim = latent_dim),
        tf.keras.layers.LeakyReLU(alpha = 0.2),
        tf.keras.layers.Reshape((4,4,256)),
        # upsample to 8*8
        tf.keras.layers.Conv2DTranspose(filters = 128, kernel_size = (4,4), padding = 'same', strides = (2,2)),
        tf.keras.layers.LeakyReLU(alpha = 0.2),
        # upsample to 16*16
        tf.keras.layers.Conv2DTranspose(filters = 128, kernel_size = (4,4), padding = 'same', strides = (2,2)),
        tf.keras.layers.LeakyReLU(alpha = 0.2),
        # upsample to 32*32
        tf.keras.layers.Conv2DTranspose(filters = 128, kernel_size = (4,4), padding = 'same', strides = (2,2)),
        tf.keras.layers.LeakyReLU(alpha = 0.2),
        # output layer
        tf.keras.layers.Conv2D(filters = 3, kernel_size = (3,3), activation = 'tanh', padding = 'same')      
    ])
    return model
# show the summary and graph of the generator model

model = generator_model(100)

model.summary()
[Figure: generator model summary]
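
Before wiring the two models together, a quick shape check of the helper functions and the untrained generator can catch mismatches early. This sketch is illustrative and not part of the original pipeline; model here is the generator built just above.

# sanity check: shapes produced by the helpers and the untrained generator
points = generate_latent_points(100, 5)
print(points.shape)                              # (5, 100)
x_real, y_real = generate_real_samples(x_train, 5)
print(x_real.shape, y_real.shape)                # (5, 32, 32, 3) (5, 1)
x_fake, y_fake = generate_fake_samples(model, 100, 5)
print(x_fake.shape, x_fake.min(), x_fake.max())  # (5, 32, 32, 3), values inside [-1, 1]
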
# define gan model (only generator model can be updated)

def gan_model(g_model, d_model):
    # freeze discriminator model
    d_model.trainable = False
    
    model = tf.keras.models.Sequential([
        g_model,
        d_model
    ])
    
    model.compile(loss = 'binary_crossentropy', 
              optimizer = tf.keras.optimizers.Adam(lr = 0.0002, beta_1 = 0.5))
    
    return model
# show the summary and graph of the gan model

latent_dim = 100

g_model = generator_model(latent_dim)

d_model = discriminator_model()

gan_model = gan_model(g_model,d_model)

gan_model.summary()
[Figure: GAN model summary]
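
Setting d_model.trainable = False after the discriminator has already been compiled means the standalone d_model still learns when train_on_batch is called on it, while the combined gan_model, compiled afterwards, treats the discriminator as frozen and only updates the generator (this is the Keras behavior the recipe relies on). A quick illustrative check compares the trainable weight counts:

# illustrative check: the combined model only sees the generator's weights as trainable
print(len(g_model.trainable_weights))    # generator weights
print(len(gan_model.trainable_weights))  # same count: the frozen discriminator adds none
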
# show and save the plots of generated images

def save_plot(examples, epoch, n = 7):
    # scale from [-1,1] to [0,1]
    examples = (examples + 1)/2.0
    # make plot
    for i in range(n*n):
        plt.subplot(n,n,i+1)
        plt.imshow(examples[i])
    
    # save plots
    filename = 'generated_plot_e%03d.png' % (epoch+1)
    plt.savefig(filename)

# evaluate discriminator model performance, display generated images, save generator model

def summarize_performance(epoch, g_model, d_model, dataset, latent_dim, n_samples = 150):
    # prepare real samples
    x_real, y_real = generate_real_samples(dataset, n_samples)
    # evaluate discriminator on real samples
    _, acc_real = d_model.evaluate(x_real, y_real, verbose = 0)
    # prepare fake samples
    x_fake, y_fake = generate_fake_samples(g_model, latent_dim, n_samples)
    # evaluate discriminator on fake samples
    _, acc_fake = d_model.evaluate(x_fake, y_fake, verbose = 0)
    # display discriminator performance
    print('>Accuracy real: %.0f%%, fake: %.0f%%' % (acc_real*100, acc_fake*100))
    
    # show and save the plots of generated images
    save_plot(x_fake, epoch)
    
    # save the generator model to file
    filename = 'generator_model_%03d.h5' % (epoch+1)
    g_model.save(filename)
# train gan model

def train_gan(g_model, d_model, gan_model, dataset, latent_dim, n_epochs = 20, n_batch = 128):
    bat_per_epoch = int(dataset.shape[0] / n_batch)
    half_batch = int(n_batch / 2)
    # manually enumerate epochs
    for i in range(n_epochs):
        # enumerate batches
        for j in range(bat_per_epoch):
            # randomly select n real samples
            x_real, y_real = generate_real_samples(dataset, half_batch)
            # update standalone discriminator model
            d_loss1, _ = d_model.train_on_batch(x_real, y_real)
            # generate fake samples
            x_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
            # update standalone discriminator model again
            d_loss2, _ = d_model.train_on_batch(x_fake, y_fake)
            # generate points in latent space as the inputs of generator model
            x_gan = generate_latent_points(latent_dim, n_batch)
            # generate class label for fake samples (label = 1)
            y_gan = np.ones((n_batch,1))
            # update the generator model with discriminator model errors
            g_loss = gan_model.train_on_batch(x_gan, y_gan)
            # display the loss
            print('>%d, %d/%d, d1=%.3f, d2=%.3f g=%.3f' % (i+1, j+1, bat_per_epoch, d_loss1, d_loss2, g_loss))
        
        # evaluate model performance every 5 epochs  
        if (i + 1)%5 == 0:
            summarize_performance(i, g_model, d_model, dataset, latent_dim)

Tip: the goal of a GAN is for the generator to produce images that look real, but image quality cannot be measured with an objective loss metric; it has to be judged by a human. In other words, without inspecting the generated images, there is no way to know when to stop training. For example, the generator may produce high-quality images at the end of one epoch, but if training continues, quality will fluctuate afterwards (because of the adversarial setup, the generator changes after every batch) and may improve or degrade. In practice, therefore, you should periodically evaluate the discriminator's ability to tell real from fake images (its classification accuracy), periodically generate images for visual inspection, and periodically save the generator model.

# train gan model

train_gan(g_model, d_model, gan_model, x_train, latent_dim)

>1, 1/390, d1=0.376, d2=0.280 g=1.740
>1, 2/390, d1=0.351, d2=0.322 g=1.679
>1, 3/390, d1=0.274, d2=0.299 g=1.866
>1, 4/390, d1=0.301, d2=0.272 g=2.027
>1, 5/390, d1=0.257, d2=0.230 g=2.256
>1, 6/390, d1=0.204, d2=0.186 g=2.558
……
……
>20, 387/390, d1=0.724, d2=0.636 g=0.865
>20, 388/390, d1=0.665, d2=0.623 g=0.837
>20, 389/390, d1=0.678, d2=0.717 g=0.867
>20, 390/390, d1=0.718, d2=0.606 g=0.960
>Accuracy real: 51%, fake: 89%

[Figure: 7x7 grid of generated images after 20 epochs of training]
Tip: this example trains for 20 epochs and evaluates model performance every 5 epochs, so over 20 epochs the model is evaluated 4 times, 4 image grids are generated, and 4 generator models are saved. You can then generate images with whichever saved generator performs best.
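
One way to pick the best checkpoint is to feed the same latent points through each saved generator and compare the outputs visually. A minimal sketch, assuming the four checkpoint files written by the 20-epoch run above are in the working directory:

# compare the saved generators on identical latent points
points = generate_latent_points(100, 10)
for e in [5, 10, 15, 20]:
    g = tf.keras.models.load_model('generator_model_%03d.h5' % e)
    imgs = (g.predict(points) + 1)/2.0  # rescale from [-1,1] to [0,1] for display
    fig = plt.figure(figsize = (20,2))
    for i in range(10):
        plt.subplot(1, 10, 1+i)
        plt.imshow(imgs[i])
    plt.suptitle('generator after %d epochs' % e)
    plt.show()
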
# generate images with final generator model

model = tf.keras.models.load_model('generator_model_020.h5') # load model saved after 20 epochs

latent_points = generate_latent_points(100,100) # generate points in latent space

X = model.predict(latent_points) # generate images

X = (X + 1)/2.0 # scale the range from [-1,1] to [0,1]
# plot the images

fig = plt.gcf()
fig.set_size_inches(20,20)
for i in range(100):
    plt.subplot(10,10,1+i)
    plt.imshow(X[i])
[Figure: 10x10 grid of images generated by the final generator]