GAN(生成对抗网络) 解释
GAN (生成對(duì)抗網(wǎng)絡(luò))是近幾年深度學(xué)習(xí)中一個(gè)比較熱門(mén)的研究方向,它的變種有上千種。
1.什么是GAN
GAN的英文全稱是Generative Adversarial Network,中文名是生成對(duì)抗網(wǎng)絡(luò)。它由兩個(gè)部分組成,生成器和鑒別器(又稱判別器),它們之間的關(guān)系可以用競(jìng)爭(zhēng)或敵對(duì)關(guān)系來(lái)描述。
我們可以拿捕食者與被捕食者之間的例子來(lái)類似說(shuō)明兩者之間的關(guān)系。在生物進(jìn)化的過(guò)程中,被捕食者會(huì)慢慢演化自己的特征,使自己越來(lái)越不容易被捕食者識(shí)別捕捉到,從而達(dá)到欺騙捕食者的目的;與此同時(shí),捕食者也會(huì)隨著被捕食者的演化來(lái)演化自己對(duì)被捕食者的識(shí)別,使自己越來(lái)越容易識(shí)別捕捉到捕食者。這樣就可以達(dá)到兩者共同進(jìn)化的目的。生成器代表的是被捕食者,鑒別器代表的是捕食者。
2.GAN的原理
GAN的工作原理與上述例子還有略微的不同,GAN是已經(jīng)知道最終鑒別的目標(biāo)是什么,但不知道假目標(biāo)是什么,它會(huì)對(duì)生成器所產(chǎn)生的假目標(biāo)做懲罰并對(duì)真目標(biāo)進(jìn)行獎(jiǎng)勵(lì),這樣鑒別器就知道了不好的假目標(biāo)與好的真目標(biāo)具體是什么。生成器則是希望通過(guò)進(jìn)化,產(chǎn)生比上一次更好的假目標(biāo),使鑒別器對(duì)自己的懲罰更小。以上是一個(gè)循環(huán),在下一個(gè)循環(huán)中鑒別器通過(guò)學(xué)習(xí)上一個(gè)循環(huán)進(jìn)化出的假目標(biāo)和真目標(biāo),再次進(jìn)化對(duì)假目標(biāo)的懲罰,同時(shí)生成器再次進(jìn)化,直到與真目標(biāo)一致,結(jié)束進(jìn)化。
GAN簡(jiǎn)單代碼實(shí)現(xiàn)
#是一個(gè)卷積神經(jīng)網(wǎng)絡(luò),變量名是D,其中一層構(gòu)造方式如下。 w = tf.get_variable('w', [4, 4, c_dim, num_filter], initializer=tf.truncated_normal_initializer(stddev=stddev)) dconv = tf.nn.conv2d(ddata, w, strides=[1, 2, 2, 1], padding='SAME') biases = tf.get_variable('biases', [num_filter], initializer=tf.constant_initializer(0.0)) bias = tf.nn.bias_add(dconv, biases) dconv1 = tf.maximum(bias, leak*bias)#是一個(gè)逆卷積神經(jīng)網(wǎng)絡(luò),變量名是G,其中一層構(gòu)造方式如下。 w = tf.get_variable('w', [4, 4, num_filter, num_filter*2], initializer=tf.random_normal_initializer(stddev=stddev)) deconv = tf.nn.conv2d_transpose(gconv2, w, output_shape=[batch_size, s2, s2, num_filter], strides=[1, 2, 2, 1]) biases = tf.get_variable('biases', [num_filter], initializer=tf.constant_initializer(0.0)) bias = tf.nn.bias_add(deconv, biases) deconv1 = tf.nn.relu(bias, name=scope.name)#的網(wǎng)絡(luò)輸入為一個(gè)維服從-1~1均勻分布的隨機(jī)變量,這里取的是100. batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]).astype(np.float32) #的網(wǎng)絡(luò)輸入是一個(gè)batch的64*64的圖片, #既可以是手寫(xiě)體數(shù)據(jù)也可以是的一個(gè)batch的輸出。#這個(gè)過(guò)程可以參考上圖的a狀態(tài),判別曲線處于不夠穩(wěn)定的狀態(tài), #兩個(gè)網(wǎng)絡(luò)都還沒(méi)訓(xùn)練好。#訓(xùn)練判別網(wǎng)絡(luò) #判別網(wǎng)絡(luò)的損失函數(shù)由兩部分組成,一部分是真實(shí)數(shù)據(jù)判別為1的損失,一部分是的輸出self.G#判別為0的損失,需要優(yōu)化的損失函數(shù)定義如下。self.G = self.generator(self.z) self.D, self.D_logits = self.discriminator(self.images) self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True) self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits, tf.ones_like(self.D))) self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits_, tf.zeros_like(self.D_))) self.d_loss = self.d_loss_real + self.d_loss_fake#然后將一個(gè)batch的真實(shí)數(shù)據(jù)batch_images,和隨機(jī)變量batch_z當(dāng)做輸入,執(zhí)行session更新的參數(shù)。 ##### update discriminator on real d_optim = tf.train.AdamOptimizer(FLAGS.learning_rate, beta1=FLAGS.beta1).minimize(d_loss, var_list=d_vars) ... out1 = sess.run([d_optim], feed_dict={real_images: batch_images, noise_images: batch_z})#這一步可以對(duì)比圖b,判別曲線漸漸趨于平穩(wěn)。 #訓(xùn)練生成網(wǎng)絡(luò) #生成網(wǎng)絡(luò)并沒(méi)有一個(gè)獨(dú)立的目標(biāo)函數(shù),它更新網(wǎng)絡(luò)的梯度來(lái)源是判別網(wǎng)絡(luò)對(duì)偽造圖片求的梯度, #并且是在設(shè)定偽造圖片的label是1的情況下,保持判別網(wǎng)絡(luò)不變, #那么判別網(wǎng)絡(luò)對(duì)偽造圖片的梯度就是向著真實(shí)圖片變化的方向。self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits_, tf.ones_like(self.D_))) #然后用同樣的隨機(jī)變量batch_z當(dāng)做輸入更新g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) .minimize(self.g_loss, var_list=self.g_vars) out2 = sess.run([g_optim], feed_dict={noise_images:batch_z})參考資料:
link1
link2
總結(jié)
以上是生活随笔為你收集整理的GAN(生成对抗网络) 解释的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 《数据库系统概念》学习笔记——恢复系统
- 下一篇: mysql数据库角色的使用