改进版的 Fashion-MNIST DCGAN

2019-08-16  本文已影响0人  圣_狒司机

在损失函数（loss）的写法上做了改进，代码更简洁。

import tensorflow as tf
from tensorflow.keras.datasets.fashion_mnist import load_data
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,Reshape,Conv2DTranspose,Conv2D,MaxPool2D,Flatten,BatchNormalization
import numpy as np
import matplotlib.pyplot as plt

# Load Fashion-MNIST; only the training split is used below.
(train_x,train_y),(test_x,test_y) = load_data()
# Keep a small subset (the first 150 images) so each round trains quickly.
# Scale uint8 pixels [0, 255] to [-1, 1] so real images share the value range
# of the generator's tanh output; with [0, 1] the discriminator could separate
# real from fake simply by detecting negative pixel values.
train_x = train_x[:150] / 127.5 - 1.0
# Real samples and their "real" label (1.0), kept as parallel tuples of
# per-image arrays / scalar labels; the training loop joins them with fakes.
x_real = tuple(train_x)
y_real = tuple(np.ones(train_x.shape[0]))

# Generator: maps a 10-dim noise vector to a single-channel 28x28 image whose
# pixels lie in [-1, 1] (tanh). Spatial size grows 4x4 -> 7x7 -> 14x14 -> 28x28.
g = Sequential()
# Project the noise vector onto a 4x4x128 feature map.
g.add(Dense(4*4*128, input_shape=(10,)))
g.add(Reshape((4, 4, 128)))
# 4x4 kernel with "valid" padding: 4x4 -> 7x7.
g.add(Conv2DTranspose(64, (4, 4), padding="valid", activation="relu"))
g.add(BatchNormalization())
# Stride-2 upsampling: 7x7 -> 14x14.
g.add(Conv2DTranspose(32, (2, 2), strides=(2, 2), padding="same", activation="relu"))
g.add(BatchNormalization())
# Stride-2 upsampling to one channel: 14x14 -> 28x28.
g.add(Conv2DTranspose(1, (2, 2), strides=(2, 2), padding="same", activation="tanh"))
# Drop the channel axis so outputs are (batch, 28, 28), matching the
# discriminator's input shape.
g.add(Reshape((28, 28)))

# Discriminator: takes a (batch, 28, 28) image and outputs one sigmoid
# probability that the image is real.
d = Sequential()
# Add the channel axis Conv2D expects: (28, 28) -> (28, 28, 1).
d.add(Reshape((28, 28, 1), input_shape=(28, 28)))
d.add(Conv2D(32, (2, 2), padding="same", activation="relu"))
d.add(MaxPool2D((2, 2)))                                        # 28x28 -> 14x14
d.add(Conv2D(64, (2, 2), padding="same", activation="relu"))
d.add(MaxPool2D((2, 2)))                                        # 14x14 -> 7x7
d.add(Conv2D(64, (2, 2), padding="valid", activation="relu"))   # 7x7 -> 6x6
d.add(MaxPool2D((2, 2)))                                        # 6x6 -> 3x3
d.add(Flatten())
d.add(Dense(1, activation="sigmoid"))

# Stacked noise -> g -> d model; fitting it trains the generator against d.
gan = Sequential([g, d])
# d is compiled on its own for the real-vs-fake classification step.
d.compile(optimizer="adam", loss="binary_crossentropy", metrics=['accuracy'])

# Generator-step optimizer, created once so Adam's moment estimates survive
# the per-round re-compile of `gan` (passing the string "adam" to compile
# would construct a brand-new optimizer every round).
g_optimizer = tf.keras.optimizers.Adam()
# Generate as many fake samples per round as there are real ones.
n_real = len(x_real)

for i in range(50):
    print(f"===============判别器第{i+1}轮训练================")
    d.trainable = True
    # Noise drawn from [0, 1); keyword bounds keep minval < maxval explicit.
    noise = tf.random.uniform((n_real, 10), minval=0.0, maxval=1.0)
    x_fake = g(noise).numpy()
    # Real samples are labelled 1 and generated samples 0; labels are (n, 1)
    # to match d's single sigmoid output.
    x = np.concatenate([np.asarray(x_real), x_fake]).astype("float32")
    y = np.concatenate(
        [np.asarray(y_real).reshape(-1, 1), np.zeros((n_real, 1))]
    ).astype("float32")
    # The shuffle buffer spans the whole real-then-fake set; a smaller buffer
    # would leave the early batches dominated by real images.
    dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(len(x)).batch(20)
    d.fit(dataset, epochs=2)

    print(f"===============生成器第{i+1}轮训练================")
    # Freeze d inside `gan`; changes to `trainable` take effect at compile time.
    d.trainable = False
    gan.compile(optimizer=g_optimizer, loss="binary_crossentropy")
    x = tf.random.uniform((100, 10), minval=0.0, maxval=1.0)
    # Non-saturating generator objective: train g so that d labels its output
    # "real" (1). Using 1 - d(g(x)) as the target would push samples that
    # already fool d back toward "fake".
    y = np.ones((100, 1), dtype="float32")
    gan.fit(x, y, epochs=50)

# Draw one noise vector from [0, 1) (keyword bounds keep minval < maxval) and
# render the generator's output for it.
img = g(tf.random.uniform((1, 10), minval=0.0, maxval=1.0))[0]
# The generated image is single-channel, so show it in grayscale rather than
# matplotlib's default false-color map; imshow autoscales the value range.
plt.imshow(img, cmap="gray")
plt.show()

训练约 20 个来回后，就能看到可以辨认的效果。
可以调节加载的数据量（代码中取训练集前 150 张），当然数据越多训练越慢！

上一篇下一篇

猜你喜欢

热点阅读