TensorFlow 2.0 Keras DCGAN: MNIST Handwritten Digit Generation
Preface
This is a simple DCGAN example; after about 100 training epochs it produces fairly good output. To help newcomers get started faster and spare them some detours, the complete code is organized below in an object-oriented style.
The code is kept as clearly structured as possible; I believe it can be followed without much extra explanation, and the comments in the code are already fairly detailed. A brief outline of the class comes first, followed by the full listing.
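For orientation, here is a sketch of the class structure used in the full listing below (an outline only, not meant to run on its own):

class GAN:
    def build_generator(self): ...        # latent vector (latent_dim,) -> 28x28x1 image
    def build_discriminator(self): ...    # 28x28x1 image -> real/fake probability
    def build_model(self): ...            # generator stacked on a frozen discriminator
    def load_data(self): ...              # MNIST training images scaled to [0, 1]
    def train(self): ...                  # alternate discriminator / generator updates
    def generate_sample_images(self, epoch): ...  # save a 5x5 grid of samples per epoch
    def pred(self): ...                   # reload the saved model and preview one digit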
Code
import numpy as np
from matplotlib import pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
class GAN():
    def __init__(self):  # global configuration
        self.img_shape = (28, 28, 1)
        self.save_path = r'C:\Users\Administrator\Desktop\photo\DCGAN.h5'
        self.img_path = r'C:\Users\Administrator\Desktop\photo'
        self.batch_size = 100
        self.latent_dim = 100
        self.sample_interval = 1
        self.epoch = 100
        # build the sub-models and the combined GAN model
        self.generator_model = self.build_generator()
        self.discriminator_model = self.build_discriminator()
        self.model = self.build_model()
    def build_generator(self):  # generator: latent vector -> 28x28x1 image
        input = keras.Input(shape=(self.latent_dim,))
        x = layers.Dense(256 * 7 * 7)(input)
        x = layers.BatchNormalization(momentum=0.8)(x)
        x = layers.Activation(activation='relu')(x)
        x = layers.Reshape((7, 7, 256))(x)
        x = layers.UpSampling2D((2, 2))(x)
        x = layers.Conv2D(128, (3, 3), padding='same')(x)
        x = layers.Conv2D(128, (3, 3), padding='same')(x)
        x = layers.BatchNormalization(momentum=0.8)(x)
        x = layers.Activation(activation='relu')(x)
        x = layers.UpSampling2D((2, 2))(x)
        x = layers.Conv2D(64, (3, 3), padding='same')(x)
        x = layers.Conv2D(64, (3, 3), padding='same')(x)
        x = layers.BatchNormalization(momentum=0.8)(x)
        x = layers.Activation(activation='relu')(x)
        x = layers.Conv2D(32, (3, 3), padding='same')(x)
        x = layers.Conv2D(32, (3, 3), padding='same')(x)
        x = layers.BatchNormalization(momentum=0.8)(x)
        x = layers.Activation(activation='relu')(x)
        output = layers.Conv2D(1, (3, 3), padding='same', activation='sigmoid')(x)
        model = keras.Model(inputs=input, outputs=output, name='generator')
        model.summary()
        return model
    def build_discriminator(self):  # discriminator: 28x28x1 image -> real/fake probability
        input = keras.Input(shape=self.img_shape)
        x = layers.Conv2D(32, (3, 3), padding='same')(input)
        x = layers.Conv2D(32, (3, 3), padding='same')(x)
        x = layers.LeakyReLU(alpha=0.2)(x)
        x = layers.Dropout(0.25)(x)
        x = layers.MaxPooling2D(2, 2)(x)
        x = layers.Conv2D(64, (3, 3), padding='same')(x)
        x = layers.Conv2D(64, (3, 3), padding='same')(x)
        x = layers.LeakyReLU(alpha=0.2)(x)
        x = layers.Dropout(0.25)(x)
        x = layers.MaxPooling2D(2, 2)(x)
        x = layers.Conv2D(128, (3, 3), padding='same')(x)
        x = layers.Conv2D(128, (3, 3), padding='same')(x)
        x = layers.LeakyReLU(alpha=0.2)(x)
        x = layers.Dropout(0.25)(x)
        x = layers.MaxPooling2D(2, 2)(x)
        x = layers.Conv2D(256, (3, 3), padding='same')(x)
        x = layers.Conv2D(256, (3, 3), padding='same')(x)
        x = layers.LeakyReLU(alpha=0.2)(x)
        x = layers.Dropout(0.25)(x)
        x = layers.Flatten()(x)
        output = layers.Dense(1, activation='sigmoid')(x)
        model = keras.Model(inputs=input, outputs=output, name='discriminator')
        model.summary()
        return model
    def build_model(self):  # combined GAN model: generator feeding a frozen discriminator
        self.discriminator_model.compile(loss='binary_crossentropy',
                                         optimizer=keras.optimizers.Adam(0.0001, 0.00001),
                                         metrics=['accuracy'])
        self.discriminator_model.trainable = False  # freeze the discriminator inside the combined model
        inputs = keras.Input(shape=(self.latent_dim,))
        img = self.generator_model(inputs)
        outputs = self.discriminator_model(img)
        model = keras.Model(inputs=inputs, outputs=outputs)
        model.summary()
        model.compile(optimizer=keras.optimizers.Adam(0.0001, 0.00001),
                      loss='binary_crossentropy')
        return model
    def load_data(self):
        (train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
        train_images = train_images.astype('float32') / 255.0  # scale to [0, 1] to match the sigmoid output
        train_images = np.expand_dims(train_images, axis=3)
        print('img_number:', train_images.shape)
        return train_images
    def train(self):
        train_images = self.load_data()  # load the training data
        # labels: 1 for real images, 0 for generated images
        valid = np.ones((self.batch_size, 1))
        fake = np.zeros((self.batch_size, 1))
        step = int(train_images.shape[0] / self.batch_size)  # steps per epoch
        print('step:', step)
        for epoch in range(self.epoch):
            train_images = (tf.random.shuffle(train_images)).numpy()  # reshuffle once per epoch
            if epoch % self.sample_interval == 0:
                self.generate_sample_images(epoch)
            for i in range(step):
                idx = np.arange(i * self.batch_size, i * self.batch_size + self.batch_size, 1)  # indices for this batch
                imgs = train_images[idx]  # the real images for this batch
                noise = np.random.normal(0, 1, (self.batch_size, self.latent_dim))  # standard Gaussian noise
                gan_imgs = self.generator_model.predict(noise)  # generate fake images from the noise
                # ---------------------------------------------- train the discriminator
                discriminator_loss_real = self.discriminator_model.train_on_batch(imgs, valid)      # real images -> label 1
                discriminator_loss_fake = self.discriminator_model.train_on_batch(gan_imgs, fake)   # generated images -> label 0
                discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)
                # ----------------------------------------------- train the generator
                noise = np.random.normal(0, 1, (self.batch_size, self.latent_dim))
                generator_loss = self.model.train_on_batch(noise, valid)
                print("epoch:%d step:%d [discriminator_loss: %f, acc: %.2f%%] [generator_loss: %f]" % (
                    epoch, i, discriminator_loss[0], 100 * discriminator_loss[1], generator_loss))
            self.model.save(self.save_path)  # save the model once per epoch
    def generate_sample_images(self, epoch):
        row, col = 5, 5  # grid size for the sample sheet
        noise = np.random.normal(0, 1, (row * col, self.latent_dim))  # sample noise vectors
        gan_imgs = self.generator_model.predict(noise)
        fig, axs = plt.subplots(row, col)  # 5x5 grid of subplots
        idx = 0
        for i in range(row):
            for j in range(col):
                axs[i, j].imshow(gan_imgs[idx, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                idx += 1
        fig.savefig(self.img_path + "/%d.png" % epoch)
        plt.close()  # close the figure to free memory
    def pred(self):
        model = keras.models.load_model(self.save_path)
        model.summary()
        noise = np.random.normal(0, 1, (1, self.latent_dim))
        # layers[1] of the combined model is the generator sub-model
        generator = keras.Model(inputs=model.layers[1].input, outputs=model.layers[1].output)
        generator.summary()
        img = np.squeeze(generator.predict(noise))
        plt.imshow(img, cmap='gray')
        plt.show()
        print(img.shape)
if __name__ == '__main__':
    gan = GAN()
    gan.train()
Prediction results
As the samples show, the images generated by the network are by now hard to tell apart from real handwritten digits.
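To reproduce this check after training, the pred() method defined above can be called to reload the saved model and display a single generated digit. A minimal usage sketch, assuming train() has already been run so that DCGAN.h5 exists at save_path:

gan = GAN()
gan.pred()  # loads DCGAN.h5, samples one noise vector and shows the generated digit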