深度学习

2020机器学习GAN(F)

2020-02-27  本文已影响0人  zidea
naruto_vs_sasuke.png

训练

training.jpeg

好现在定义好判别器模型和生成模型,而且定义好优化器这里梯度下降使用 Adam 给以较小学习率。在开始训练前我们介绍一下 tensorflow 2.0 的求导函数,感觉点 mxnet 的求导函数,这样我们就可以手动对函数进行求导,用法也很像mxnent。

x = tf.constant(3.0)
with tf.GradientTape() as g:
    g.watch(x)
    y = x * x
dy_dx = g.gradient(y, x)
print(dy_dx)
tf.Tensor(6.0, shape=(), dtype=float32)
  1. 先给判别器真实图片,判别器会对这些真实图片进行打分,输出 real_out
  2. 然后给生成器输入噪声变量来随机生产一张图片
  3. 接下里将生成器用噪声生成图片输入给判别器进行判别,给出 fake_out
  4. 有了这些判别器对真实图片和生成图片进行打分结果,fake_out 和 real_out 我们就可以用这些结果来评判生成器模型和判别器模型
  5. 通过梯度下降不断优化参数来提高判别器和生成器的精度,减少损失函数
sakuke_discriminator.png
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE,noise_dim])
    # tape 胶带,磁带意思     
    with tf.GradientTape() as generator_tape, tf.GradientTape() as discriminator_tape:
        real_out = discriminator(images,training=True)
        gen_images = generator(noise,training=True)
        fake_out = discriminator(gen_images,training=True)
        
        gen_loss = generator_loss(fake_out)
        disc_loss = discriminator_loss(real_out,fake_out)
    generator_gradient = generator_tape.gradient(gen_loss,generator.trainable_variables)
    discriminator_gradient = discriminator_tape.gradient(disc_loss,discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(generator_gradient,generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradient,discriminator.trainable_variables))

可视化生成图片

我们将随机变量传入到生成器(generator)来生成图片,这里显示图片我们用到

def show_images(generator,test_noise):
    gen_images = generator(test_noise,training=False)
    fig = plt.figure(figsize=(8,8))
    for row in range(gen_images.shape[0]):
        plt.subplot(4,4,(row + 1))
        plt.imshow((gen_images[row,:,:,0] + 1)/2,cmap='gray')
        plt.axis('off')
    plt.show()
    

开始训练

def train(dataset,epochs):
    print(epochs)
    for epoch in range(epochs):
        for image_batch in dataset:
            train_step(image_batch)
            print(".",end="")
        show_images(generator,seed)
dataset
train(dataset,100)
<BatchDataset shapes: (None, 28, 28, 1), types: tf.float32>

我们简单选择几张,大家可以手动coding 体验一下


output_37_1.png
output_37_19.png
上一篇 下一篇

猜你喜欢

热点阅读