Keras GAN 同时储存模型和优化器信息

2020-12-28  本文已影响0人  寂风如雪

Github 上有一个很好的开源项目,介绍见这里,包含了用 Keras 实现的 GAN 及其变种。但在研究的过程中发现,难以同时存储模型和优化器状态,如果使用 model.save 则发现再次 load 时优化器信息丢失,而 Keras 的 model 又不能 pickle,也尝试了 pickle 优化器,但似乎会出现一些奇怪的问题,比如 Adam 某些层上的动量不变化等,于是查看了 Keras 源码,并实现了解决方法,分享如下:

import h5py
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.python.keras.saving import hdf5_format

def get_optimizer_weights(filepath):
    """Read the saved optimizer state out of a Keras HDF5 model file.

    Returns the *open* ``h5py.File`` handle together with the optimizer
    weights. The caller is responsible for closing the file, and should
    do so only after the weights have been consumed (e.g. passed to
    ``optimizer.set_weights``) — presumably the values are read lazily
    from the underlying HDF5 datasets.
    """
    h5file = h5py.File(filepath, mode='r')
    optimizer_weights = hdf5_format.load_optimizer_weights_from_hdf5_group(h5file)
    return h5file, optimizer_weights

class GAN():
    """A vanilla GAN trained on MNIST whose full training state — the
    models *and* their optimizers' slot variables — can be saved to and
    restored from HDF5 files.  See ``save()`` / ``load()`` for the
    optimizer-state round-trip this article demonstrates.
    """

    def __init__(self):
        # --------------------------------- #
        #   28 rows x 28 cols — the MNIST image shape
        # --------------------------------- #
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        # 28,28,1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100
        # Two separate Adam optimizers (lr=0.0002, beta_1=0.5), so the
        # discriminator and the combined (generator-training) model each
        # keep their own moment/slot variables.
        optimizer_d = Adam(0.0002, 0.5)
        optimizer_c = Adam(0.0002, 0.5)

        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer_d,
            metrics=['accuracy'])

        self.generator = self.build_generator()
        gan_input = Input(shape=(self.latent_dim,))
        img = self.generator(gan_input)
        # Freeze the discriminator while the generator is being trained
        self.discriminator.trainable = False
        # Score the generated (fake) images with the frozen discriminator
        validity = self.discriminator(img)
        self.combined = Model(gan_input, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer_c)


    def build_generator(self):
        # --------------------------------- #
        #   Generator: maps a latent noise vector to a 28x28x1 image
        # --------------------------------- #
        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        # tanh output matches the [-1, 1] normalization used in train()
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        # ----------------------------------- #
        #   Discriminator: scores an input image as real (1) or fake (0)
        # ----------------------------------- #
        model = Sequential()
        # Flatten the input image into a vector
        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        # Real/fake probability
        model.add(Dense(1, activation='sigmoid'))

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):
        """Alternate discriminator / generator updates on MNIST.

        Each "epoch" here is a single batch step; every
        ``sample_interval`` steps a grid of generated samples is written
        to ``images/`` by ``sample_images``.
        """
        # Load the training images (labels are unused)
        (X_train, _), (_, _) = mnist.load_data()

        # Normalize from [0, 255] to [-1, 1] (matches the tanh output)
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)

        # Ground-truth labels: 1 for real, 0 for fake
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # --------------------------- #
            #   Pick batch_size random real images and
            #   train the discriminator
            # --------------------------- #
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            gen_imgs = self.generator.predict(noise)

            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # --------------------------- #
            #  Train the generator (labels "valid": the generator
            #  wants the frozen discriminator to output 1)
            # --------------------------- #
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            g_loss = self.combined.train_on_batch(noise, valid)
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        """Save a 5x5 grid of generated digits to images/<epoch>.png."""
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale from [-1, 1] back to [0, 1] for display
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/%d.png" % epoch)
        plt.close()
    
    def save(self, path):
        """Save all three models (with optimizer state) to <path>_*.h5."""
        discriminator_path = path + '_discriminator.h5'
        generator_path = path + '_generator.h5'
        combined_path = path + '_combined.h5'

        # Temporarily restore trainable=True so the standalone
        # discriminator is saved in its trainable configuration
        # (it was frozen above for the combined model).
        self.discriminator.trainable = True
        self.discriminator.save(discriminator_path)
        self.discriminator.trainable = False
        self.generator.save(generator_path)
        self.combined.save(combined_path)
        
    def load(self, path):
        """Restore models and optimizer state saved by ``save()``.

        The combined model cannot simply be re-loaded (its optimizer
        state would be lost / its submodels duplicated), so it is rebuilt
        from the loaded generator and discriminator, recompiled, and the
        saved optimizer weights are injected manually.
        """
        discriminator_path = path + '_discriminator.h5'
        generator_path = path + '_generator.h5'
        combined_path = path + '_combined.h5'
        
        # Discriminator: full load, including its own optimizer state.
        self.discriminator = tf.keras.models.load_model(discriminator_path)
        # Generator: weights only — it is trained through self.combined.
        self.generator = tf.keras.models.load_model(generator_path, compile=False)
        # Optimizer weights of the combined model, plus the still-open
        # HDF5 file handle backing them.
        f, c_weights = get_optimizer_weights(combined_path)
        
        self.discriminator.trainable = False
        
        # Rebuild the combined model exactly as in __init__.
        gan_input = Input(shape=(self.latent_dim,))
        img = self.generator(gan_input)
        validity = self.discriminator(img)
        self.combined = Model(gan_input, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
        
        # NOTE(review): _create_all_weights is a private Keras API; it
        # forces the fresh optimizer to materialize its slot variables
        # so set_weights() has matching targets to fill.
        self.combined.optimizer._create_all_weights(self.combined.trainable_variables)
        self.combined.optimizer.set_weights(c_weights)
        
        # Close the HDF5 file only after set_weights — presumably the
        # weight values are read lazily from the file; verify against
        # hdf5_format.load_optimizer_weights_from_hdf5_group.
        f.close()

存储模型:

gan = GAN()
gan.train(epochs=1, batch_size=256, sample_interval=1000)
gan.save('savetest')

读取模型:

gan2 = GAN() 
gan2.load('savetest')

祝大家 造假炼丹,学习愉快。

上一篇下一篇

猜你喜欢

热点阅读