Generating Handwritten Digits with WGAN
This article walks through generating handwritten digits with a WGAN (Wasserstein GAN) trained on MNIST using TensorFlow, shared here as a reference.
import sys;
sys.path.append("/home/hxj/anaconda3/lib/python3.6/site-packages")
import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow.examples.tutorials.mnist import input_data
# to speed things up, download the MNIST data beforehand and load it from the local directory
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
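Note that the `tensorflow.examples.tutorials.mnist` helper only ships with TensorFlow 1.x. If it is missing from your installation, the sketch below (an assumption about your environment, not part of the original script) builds an equivalent source of flattened, 0-1 scaled batches from `tf.keras.datasets.mnist`; its `next_batch(64)` can stand in for `mnist.train.next_batch(64)[0]` used later in the script.

# Hypothetical substitute for input_data.read_data_sets, assuming tf.keras is available.
# Yields flattened 28*28 images scaled to [0, 1], like mnist.train.next_batch(64)[0].
(train_images, _), _ = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(-1, 784).astype("float32") / 255.0

def next_batch(batch_size=64):
    idx = np.random.randint(0, train_images.shape[0], size=batch_size)
    return train_images[idx]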
print("##################")def combine(image):assert len(image) == 64rows = []for i in range(8):cols = []for j in range(8):index = i * 8 + jimg = image[index].reshape(28, 28)cols.append(img)row = np.concatenate(tuple(cols), axis=0)rows.append(row)new_image = np.concatenate(tuple(rows), axis=1)return new_image.astype("uint8")def dense(inputs, shape, name, bn=False, act_fun=None):W = tf.get_variable(name + ".w", initializer=tf.random_normal(shape=shape))b = tf.get_variable(name + ".b", initializer=(tf.zeros((1, shape[-1])) + 0.1))y = tf.add(tf.matmul(inputs, W), b)def batch_normalization(inputs, out_size, name, axes=0):mean, var = tf.nn.moments(inputs, axes=[axes])scale = tf.get_variable(name=name + ".scale", initializer=tf.ones([out_size]))offset = tf.get_variable(name=name + ".shift", initializer=tf.zeros([out_size]))epsilon = 0.001return tf.nn.batch_normalization(inputs, mean, var, offset, scale, epsilon, name=name + ".bn")if bn:y = batch_normalization(y, shape[1], name=name + ".bn")if act_fun:y = act_fun(y)return ydef D(inputs, name, reuse=False):with tf.variable_scope(name, reuse=reuse):l1 = dense(inputs, [784, 512], name="relu1", act_fun=tf.nn.relu)l2 = dense(l1, [512, 512], name="relu2", act_fun=tf.nn.relu)l3 = dense(l2, [512, 512], name="relu3", act_fun=tf.nn.relu)y = dense(l3, [512, 1], name="output")return ydef G(inputs, name, reuse=False):with tf.variable_scope(name, reuse=reuse):l1 = dense(inputs, [100, 512], name="relu1", act_fun=tf.nn.relu)l2 = dense(l1, [512, 512], name="relu2", act_fun=tf.nn.relu)l3 = dense(l2, [512, 512], name="relu3", act_fun=tf.nn.relu)y = dense(l3, [512, 784], name="output", bn=True, act_fun=tf.nn.sigmoid)return yz = tf.placeholder(tf.float32, [None, 100], name="noise") # 100
x = tf.placeholder(tf.float32, [None, 784], name="image") # 28*28
real_out = D(x, "D")
gen = G(z, "G")
fake_out = D(gen, "D", reuse=True)

vars = tf.trainable_variables()
D_PARAMS = [var for var in vars if var.name.startswith("D")]
G_PARAMS = [var for var in vars if var.name.startswith("G")]

d_clip = [tf.assign(var, tf.clip_by_value(var, -0.01, 0.01)) for var in D_PARAMS]
d_clip = tf.group(*d_clip)  # clip the critic's weights to [-0.01, 0.01]
wd = tf.reduce_mean(real_out) - tf.reduce_mean(fake_out)
d_loss = tf.reduce_mean(fake_out) - tf.reduce_mean(real_out)
g_loss = tf.reduce_mean(-fake_out)

d_opt = tf.train.RMSPropOptimizer(1e-3).minimize(
    d_loss,
    global_step=tf.Variable(0),
    var_list=D_PARAMS
)
g_opt = tf.train.RMSPropOptimizer(1e-3).minimize(
    g_loss,
    global_step=tf.Variable(0),
    var_list=G_PARAMS
)
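For reference, `wd`, `d_loss`, and `g_loss` implement the usual WGAN objective (the notation below is the standard one from the WGAN paper, not from the original post). With critic f_w and generator G, the minibatch estimate of the Wasserstein distance and the two losses are:

\[
\widehat{W} = \mathbb{E}_{x \sim P_r}[f_w(x)] - \mathbb{E}_{z \sim p(z)}[f_w(G(z))], \qquad
L_D = -\widehat{W}, \qquad
L_G = -\,\mathbb{E}_{z \sim p(z)}[f_w(G(z))]
\]

The weight clipping applied by `d_clip` keeps f_w approximately 1-Lipschitz, which is what justifies reading `wd` as an estimate of the Wasserstein distance between real and generated images.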
is_restore = False
# is_restore = True  # set to True when resuming training; leave False for the first run (no model to load)
sess = tf.Session()
sess.run(tf.global_variables_initializer())

if is_restore:
    saver = tf.train.Saver()
    # restore the saved variables
    saver.restore(sess, "my_net/GAN_net.ckpt")
    print("Model restore...")

CRITICAL_NUM = 5
for step in range(100 * 1000):
    # train the critic extra hard at the start and every 500 steps
    if step < 25 or step % 500 == 0:
        critical_num = 100
    else:
        critical_num = CRITICAL_NUM
    for ep in range(critical_num):
        noise = np.random.normal(size=(64, 100))
        batch_xs = mnist.train.next_batch(64)[0]
        _, d_loss_v, _ = sess.run([d_opt, d_loss, d_clip], feed_dict={
            x: batch_xs,
            z: noise
        })
    for ep in range(1):
        noise = np.random.normal(size=(64, 100))
        _, g_loss_v = sess.run([g_opt, g_loss], feed_dict={
            z: noise
        })
    print("Step:%d D-loss:%.4f G-loss:%.4f" % (step + 1, d_loss_v, g_loss_v))
    if step % 1000 == 999:
        # every 1000 steps: report the Wasserstein estimate, save a sample grid and the model
        batch_xs = mnist.train.next_batch(64)[0]
        # batch_xs = pre(batch_xs)
        noise = np.random.normal(size=(64, 100))
        mpl_v = sess.run(wd, feed_dict={
            x: batch_xs,
            z: noise
        })
        print("################## Step %d WD:%.4f ###############" % (step + 1, mpl_v))
        generate = sess.run(gen, feed_dict={z: noise})
        generate *= 255
        generate = np.clip(generate, 0, 255)
        image = combine(generate)
        Image.fromarray(image).save("image/Step_%d.jpg" % (step + 1))
        saver = tf.train.Saver()
        save_path = saver.save(sess, "my_net/GAN_net.ckpt")
        print("Model save in %s" % save_path)
sess.close()
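Once a checkpoint exists under my_net/, new sample grids can be drawn without rerunning the training loop. A minimal sketch, assuming the graph-building code above has already been executed in the same Python session (it reuses the script's own gen, z, combine, and checkpoint path):

# Restore the trained weights and write one 8x8 grid of generated digits.
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "my_net/GAN_net.ckpt")
    noise = np.random.normal(size=(64, 100))
    samples = sess.run(gen, feed_dict={z: noise})
    grid = combine(np.clip(samples * 255, 0, 255))
    Image.fromarray(grid).save("image/sample.jpg")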
Experimental Results

Generated 8x8 sample grids saved by the script at different points in training (images from the original post):
Sample grid after 1,000 training steps
Sample grid after 9,000 training steps
Sample grid after 15,000 training steps
Sample grid after 25,000 training steps
Sample grid after 33,000 training steps
Sample grid after 42,000 training steps
Sample grid after 50,000 training steps
Reposted from: https://www.cnblogs.com/hxjbc/p/8260541.html