导入环境
import numpy as np |
数据准备
运行程序下面代码,mnist数据集会自动下载mnist = input_data.read_data_sets('data')
获得输入数据
def get_inputs(noise_dim, image_height, image_width, image_depth): |
生成器
def get_generator(noise_img, output_dim, is_train=True, alpha=0.01): |
判别器
def get_discriminator(inputs_img, reuse=False, alpha=0.01): |
目标函数
def get_loss(inputs_real, inputs_noise, image_depth, smooth=0.1): |
优化器
def get_optimizer(g_loss, d_loss, learning_rate=0.001): |
显示图片
def plot_images(samples): |
开始训练
# 定义参数 |