员外带你读论文:SeqGAN论文分享
本次要分享和總結(jié)的論文為:,其論文鏈接SeqGAN,源自,參考的實(shí)現(xiàn)代碼鏈接代碼實(shí)現(xiàn)。
本篇論文結(jié)合了?和?的知識(shí),整篇論文讀下來(lái)難度較大,在這里就淺薄的談下自己的見(jiàn)解。
好了,老規(guī)矩,帶著代碼分析論文。
動(dòng)機(jī)
我們知道?網(wǎng)絡(luò)在計(jì)算機(jī)視覺(jué)上得到了很好的應(yīng)用,然而很可惜的是,其在自然語(yǔ)言處理上并不,最初的??僅僅定義在實(shí)數(shù)領(lǐng)域,?通過(guò)訓(xùn)練出的生成器來(lái)產(chǎn)生合成數(shù)據(jù),然后在合成數(shù)據(jù)上運(yùn)行判別器,判別器的輸出梯度將會(huì)告訴你,如何通過(guò)略微改變合成數(shù)據(jù)而使其更加現(xiàn)實(shí)。一般來(lái)說(shuō)只有在數(shù)據(jù)連續(xù)的情況下,你才可以略微改變合成的數(shù)據(jù),而如果數(shù)據(jù)是離散的,則不能簡(jiǎn)單的通過(guò)改變合成數(shù)據(jù)。例如,如果你輸出了一張圖片,其像素值是,那么接下來(lái)你可以將這個(gè)值改為。如果輸出了一個(gè)單詞“penguin”,那么接下來(lái)就不能將其改變?yōu)椤皃enguin + .001”,因?yàn)闆](méi)有“penguin +.001”這個(gè)單詞。因?yàn)樗械淖匀徽Z(yǔ)言處理()的基礎(chǔ)都是離散值,如“單詞”、“字母”或者“音節(jié)”,??中應(yīng)用??是非常困難的。
?只能衡量一個(gè)完整的句子的好壞程度,對(duì)于部分生成的句子,很難預(yù)測(cè)其后面的部分,無(wú)法很好的對(duì)其打分。
在傳統(tǒng)的?模型中,我們通常用?來(lái)訓(xùn)練模型,但是這個(gè)訓(xùn)練方式也存在一個(gè)嚴(yán)重的問(wèn)題,也就是論文中所說(shuō)的?,在模型訓(xùn)練階段,我們用?作為?,但是在真正的預(yù)測(cè)階段時(shí),我們只能從上一步產(chǎn)生的分布中以某種方式抽樣某一個(gè)?作為下一步的,也就是這個(gè)階段的?的分布可能是不一樣的。
論文的大體思路
這里寫(xiě)圖片描述我們先看左圖:現(xiàn)在有一批,生成器生成一批假數(shù)據(jù),我們利用?的方式來(lái)?生成器,也就是讓生成器不斷擬合?的分布。這個(gè)過(guò)程經(jīng)過(guò)幾個(gè)回合;然后把訓(xùn)練好的生成器生成的數(shù)據(jù)作為,?作為?來(lái)?判別器。這樣就?出了生成器和判別器。
再看右圖:先了解下強(qiáng)化學(xué)習(xí)的四個(gè)重要概率:,?為現(xiàn)在已經(jīng)生成的?,?是下一個(gè)即將生成的?,??為?的生成器,?為?的判別器所回傳的信息。
帶著代碼仔細(xì)分析各個(gè)部分
在實(shí)現(xiàn)代碼中,生成器是一個(gè)?神經(jīng)網(wǎng)絡(luò),判別器是一個(gè)?網(wǎng)絡(luò),其?是由一個(gè)?生成的。
pretrain
由上面的分析可知,在?生成器時(shí),只是利用?的方法來(lái)訓(xùn)練,不需要考慮。
我們先看看代碼中是如何?生成器的:
for epoch in xrange(PRE_EPOCH_NUM):## gen_data_loader存儲(chǔ)真實(shí)數(shù)據(jù)。loss = pre_train_epoch(sess, generator, gen_data_loader)##該操作為利用MLE訓(xùn)練生成器if epoch % 5 == 0:┆ generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)## 利用生成器生成一批假數(shù)據(jù)┆ likelihood_data_loader.create_batches(eval_file)##將假數(shù)據(jù)存進(jìn)likelihood_data_loader┆ test_loss = target_loss(sess, target_lstm, likelihood_data_loader)##測(cè)試下當(dāng)前生成的假數(shù)據(jù)與target_lstm生成的真實(shí)數(shù)據(jù)的loss┆ print 'pre-train epoch ', epoch, 'test_loss ', test_loss┆ buffer = 'epoch:\t'+ str(epoch) + '\tnll:\t' + str(test_loss) + '\n'┆ log.write(buffer)以上過(guò)程循環(huán)?個(gè)循環(huán),不斷的?生成器。
再來(lái)看看怎么?判別器的:
for _ in range(50):generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)##上面的生成器生成一批負(fù)樣本dis_data_loader.load_train_data(positive_file, negative_file)for _ in range(3):┆ dis_data_loader.reset_pointer()┆ for it in xrange(dis_data_loader.num_batch):┆ ┆ x_batch, y_batch = dis_data_loader.next_batch()##獲取一個(gè)batch數(shù)據(jù),二分類(lèi)┆ ┆ feed = {┆ ┆ ┆ discriminator.input_x: x_batch,┆ ┆ ┆ discriminator.input_y: y_batch,┆ ┆ ┆ discriminator.dropout_keep_prob: dis_dropout_keep_prob┆ ┆ }## 判別器是一個(gè)二分類(lèi)的CNN網(wǎng)絡(luò),利用cross-entropy作為損失函數(shù)。┆ ┆ _ = sess.run(discriminator.train_op, feed)##訓(xùn)練判別器生成器的目標(biāo)函數(shù)
由強(qiáng)化學(xué)習(xí)的相關(guān)知識(shí),我們可知其目標(biāo)就是,也就是生成器生成了一句完整的句子后,我們希望盡可能的使其所有?的?之和盡可能的大,也就是如何公式:
如何理解上式呢?其中?可理解為一個(gè)完整句子的?之和,?表示初始狀態(tài),?表示生成器的參數(shù)。后面的求和過(guò)程表示,每生成一個(gè),我們都會(huì)計(jì)算其生成該?的概率與其對(duì)應(yīng)的?值,那么兩者相乘即表示生成該?的期望?值。求和后即為該整句的期望?值。生成器的目標(biāo)就是不斷的。至于?為啥表示成,因?yàn)?是由后面判別器?決定的。
reward 求法
由上面的?,如何求每一步的?呢?也就是求。論文中提到?只能對(duì)一個(gè)完整的句子進(jìn)行打分,而不能對(duì)生成的不完整句子打分,因此引入強(qiáng)化學(xué)習(xí)的方式:
蒙特卡洛樹(shù)搜索方法:
在?的階段時(shí),我們希望生成器生成的?在當(dāng)前生成概率分布中對(duì)應(yīng)的概率與該?的?乘積和越大越好(上面的意義),那么該?的?如何計(jì)算得到呢?要知道只有是一個(gè)完整的句子,其判別器才能對(duì)其進(jìn)行打分。例如我們?cè)谏傻?步的時(shí),后面的?是未知的,我們只能讓生成器繼續(xù)向后面生成,直到生成一個(gè)完整的句子。然后在喂給判別器打分,為了讓這個(gè)打分更有說(shuō)服力,我們讓這個(gè)過(guò)程重復(fù)?次,然后取平均的,可以用如下圖示展示這個(gè)過(guò)程:
這里寫(xiě)圖片描述簡(jiǎn)單來(lái)說(shuō),就是生成器每生成一個(gè)?都會(huì)有相應(yīng)的一個(gè),而這個(gè)?后面的?都是未知的,只能按照生成器來(lái)補(bǔ)全,形成一個(gè)完整的句子,這個(gè)過(guò)程進(jìn)行?次,會(huì)生成?個(gè)不同的完整句子(因?yàn)樯善鞯碾S機(jī)性,不可能出現(xiàn)相同的句子)。然后將這個(gè)句子放到判別器中得到?個(gè)不同的打分結(jié)果,對(duì)這?個(gè)打分結(jié)果取平均作為該?最終的,論文中講到當(dāng)生成第?個(gè)?時(shí):
上式左邊表示?出來(lái)的??個(gè)不同的完整句子。
綜上所述:
那么代碼中是如何實(shí)現(xiàn)這一步的呢?
def get_reward(self, sess, input_x, rollout_num, discriminator): """ input_x: 需要打分的序列 rollout_num: 即sample的次數(shù),即上面的N discriminator: 判別器 """rewards = []for i in range(rollout_num):┆ # given_num between 1 to sequence_length - 1 for a part completed sentence會(huì)遍歷一整句中的每個(gè)token,給其打分┆ for given_num in range(1, self.sequence_length ):┆ ┆ feed = {self.x: input_x, self.given_num: given_num}##生成一批樣本,前give_num的token由input_x提供,give_num后的token由生成器補(bǔ)上。由此生成一批完整的句子。┆ ┆ samples = sess.run(self.gen_x, feed)┆ ┆ feed = {discriminator.input_x: samples, discriminator.dropout_keep_prob: 1.0}┆ ┆ ypred_for_auc = sess.run(discriminator.ypred_for_auc, feed)##喂給判別器,給每個(gè)句子打分,作為reward┆ ┆ ypred = np.array([item[1] for item in ypred_for_auc])┆ ┆ if i == 0:┆ ┆ ┆ rewards.append(ypred)┆ ┆ else:┆ ┆ ┆ rewards[given_num - 1] += ypred## 在rollout_num循環(huán)中,相同位置的reward相加。┆ # the last token reward┆ feed = {discriminator.input_x: input_x, discriminator.dropout_keep_prob: 1.0}##如果give_num已經(jīng)是最后一個(gè)token了,則喂給判別器的樣本就全是input_x。┆ ypred_for_auc = sess.run(discriminator.ypred_for_auc, feed)##注意下:ypred_for_auc只是softmax_logits,二分類(lèi)的,第一個(gè)數(shù)為該樣本為假樣本概率,第二個(gè)數(shù)為其為真樣本概率┆ ypred = np.array([item[1] for item in ypred_for_auc])##我們拿其真樣本概率作為reward┆ if i == 0:┆ ┆ rewards.append(ypred)┆ else:┆ ┆ # completed sentence reward┆ ┆ rewards[self.sequence_length - 1] += ypredrewards = np.transpose(np.array(rewards)) / (1.0 * rollout_num) # batch_size x seq_length##取平均值。return rewards通過(guò)上面的分析,我們可知,每個(gè)?的?都是由判別器得到的。那么這個(gè)打分的過(guò)程是怎么做的呢?我們來(lái)看看判別器的實(shí)現(xiàn)代碼:
不行,代碼太多了,簡(jiǎn)單解說(shuō):生成器就是一個(gè)?網(wǎng)絡(luò),我們會(huì)將?(二分類(lèi)的)??成一個(gè)四維的張量,然后通過(guò)各種的卷積,?操作得到一個(gè)結(jié)果,然后再經(jīng)過(guò)一個(gè)線性操作最終得到只有二維的張量,再做一個(gè)?操作得到?作為?值等等,直接看下面精華代碼:
with tf.name_scope("output"):##num_classes為2,真樣本還是假樣本W(wǎng) = tf.Variable(tf.truncated_normal([num_filters_total, num_classes], stddev=0.1), name="W")b = tf.Variable(tf.constant(0.1, shape=[num_classes]), name="b")l2_loss += tf.nn.l2_loss(W)l2_loss += tf.nn.l2_loss(b)self.scores = tf.nn.xw_plus_b(self.h_drop, W, b, name="scores")self.ypred_for_auc = tf.nn.softmax(self.scores)##rewardself.predictions = tf.argmax(self.scores, 1, name="predictions")# CalculateMean cross-entropy loss with tf.name_scope("loss"):losses = tf.nn.softmax_cross_entropy_with_logits(logits=self.scores, labels=self.input_y)self.loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss***其實(shí)就是將該整句被判別器判別為真樣本的概率作為該??的?***,嗯,就是這么簡(jiǎn)單。判別為真樣本的概率越大,則在當(dāng)前步選擇該?越正向。
值得一提的是:這里?的方式并不唯一,論文中的對(duì)比實(shí)驗(yàn)就用了?指標(biāo)作為?來(lái)指導(dǎo)生成器的訓(xùn)練。
好了,怎么求?已經(jīng)搞清楚了。
接下來(lái)再看看判別器是如何訓(xùn)練的?以及他的目標(biāo)函數(shù)。
判別器的目標(biāo)函數(shù)
簡(jiǎn)短解說(shuō):判別器是一個(gè)?網(wǎng)絡(luò),我們喂給判別器的樣本是一個(gè)二分類(lèi)的樣本,即有生成器生成的一批假樣本,也有一批真樣本,然后直接做個(gè)二分類(lèi),損失函數(shù)就是一個(gè)?:
實(shí)現(xiàn)代碼上面已寫(xiě)到。
policy Gradient
這一步有大量的數(shù)學(xué)公式需要推導(dǎo)。
我們由上面的分析,知道生成器的目標(biāo)函數(shù)為:
我們?cè)賮?lái)看看上式是如何得到的:
可以這么理解:上式在?為生成器時(shí),狀態(tài)為已經(jīng)生成?到?的?情況下,當(dāng)前第?步選擇?的?值。那么:
這樣,一個(gè)完整句子的期望?就可以表示成
那么如何?呢?利用?方法,需要對(duì)?求導(dǎo)。
具體求導(dǎo)過(guò)程就不贅述了,如有興趣請(qǐng)看論文。
最后求出的結(jié)果為:
這里寫(xiě)圖片描述這里寫(xiě)圖片描述利用梯度上升法來(lái)更新生成器參數(shù):
那么目標(biāo)函數(shù)的優(yōu)化過(guò)程在代碼中如何實(shí)現(xiàn)呢?
self.g_loss = -tf.reduce_sum(## self.x 為生成器生成一個(gè)序列,我們需要找到這個(gè)序列中每個(gè)token在生成器分布中的概率,然后與對(duì)應(yīng)的reward相乘。求和取負(fù)作為要優(yōu)化的losstf.reduce_sum(┆ tf.one_hot(tf.to_int32(tf.reshape(self.x, [-1])), self.num_emb, 1.0, 0.0) * tf.log(┆ ┆ tf.clip_by_value(tf.reshape(self.g_predictions, [-1, self.num_emb]), 1e-20, 1.0)┆ ), 1) * tf.reshape(self.rewards, [-1]) ) g_opt = self.g_optimizer(self.learning_rate)self.g_grad, _ = tf.clip_by_global_norm(tf.gradients(self.g_loss, self.g_params), self.grad_clip)##更新生成器參數(shù) self.g_updates = g_opt.apply_gradients(zip(self.g_grad, self.g_params))在?可以自動(dòng)的反向求導(dǎo),所以許多細(xì)節(jié)不需要在代碼中顯示。
整體算法流程
這里寫(xiě)圖片描述利用?方法?生成器、判別器,這部分上面已經(jīng)講過(guò)。下面稍微詳細(xì)講下、。
G_step
利用生成器生成一批假樣本。注意生成器每一步都是生成一個(gè)在?上的分布,我們以某種方式抽樣一個(gè)?作為本步生成的。
在?作為目標(biāo)函數(shù)時(shí),我們需得到?中當(dāng)前步的?在當(dāng)前生成分布中的概率,在SeqGAN 中考慮的是當(dāng)前步得到的??在生成分布中的概率以及該?的?,我們利用蒙特卡洛樹(shù)搜索法得到每個(gè)??的?。
利用?更新生成器的參數(shù)。
實(shí)現(xiàn)代碼:
for total_batch in range(TOTAL_BATCH):# Train the generator for one stepfor it in range(1):┆ samples = generator.generate(sess)##生成器生成一批序列## 獲得序列中每個(gè)token 的reward┆ rewards = rollout.get_reward(sess, samples, 16, discriminator)## 將序列與其對(duì)應(yīng)的reward 喂給生成器,以policy gradient更新生成器┆ feed = {generator.x: samples, generator.rewards: rewards}┆ _ = sess.run(generator.g_updates, feed_dict=feed)以上就是訓(xùn)練生成器的過(guò)程, 在這個(gè)階段,判別器不發(fā)生改變只是對(duì)當(dāng)前的生成情況做出反饋,也就是?。
D_step
利用上面已經(jīng)訓(xùn)完的生成器生成一批樣本作為假樣本,加上已有的一批真樣本,作為訓(xùn)練數(shù)據(jù),來(lái)訓(xùn)練一個(gè)二分類(lèi)的判別器。
實(shí)現(xiàn)代碼:
for _ in range(5):generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)dis_data_loader.load_train_data(positive_file, negative_file)for _ in range(3):┆ dis_data_loader.reset_pointer()┆ for it in xrange(dis_data_loader.num_batch):┆ ┆ x_batch, y_batch = dis_data_loader.next_batch()┆ ┆ feed = { ┆ ┆ ┆ discriminator.input_x: x_batch,┆ ┆ ┆ discriminator.input_y: y_batch,┆ ┆ ┆ discriminator.dropout_keep_prob: dis_dropout_keep_prob┆ ┆ } ┆ ┆ _ = sess.run(discriminator.train_op, feed)個(gè)人總結(jié)與疑點(diǎn)
如果用傳統(tǒng)的?網(wǎng)絡(luò)來(lái)做?的任務(wù),那么生成器生成的序列需要喂給判別器,然后利用判別器來(lái)反向的糾正生成器,這個(gè)時(shí)候梯度的微調(diào)不再適用在離散的數(shù)據(jù)上,并且梯度在回傳時(shí)可能會(huì)有一些困難。如下圖:
生成器是以某種方式采樣生成一批數(shù)據(jù)傳給判別器的,這樣判別器反向?qū)⑻荻然貍鹘o生成器時(shí)貌似不太好辦?
而在?中,生成器每生成一個(gè)?時(shí),都會(huì)計(jì)算該?在生成分別中的概率,并且利用上一次訓(xùn)完的判別器計(jì)算出相應(yīng)的,這個(gè)?可以理解為生成該?的權(quán)重?經(jīng)過(guò)幾輪的訓(xùn)練后,?越大的?越正向,越容易生成(這其實(shí)是強(qiáng)化學(xué)習(xí)的思想)。這里面就不需要判別器反向傳梯度給生成器了。這就避免了梯度的微調(diào)導(dǎo)致不適用在離散的樣本上?
不知道上面我個(gè)人的理解是否正確?如有想法歡迎留言討論。
開(kāi)頭就說(shuō)了傳統(tǒng)的?方法存在?問(wèn)題,在本篇論文中,生成器只是在?時(shí)用了?來(lái)做?,而在真正訓(xùn)練生成器的時(shí)候,生成器并沒(méi)有用到,只是在判別器 中用到了?來(lái)訓(xùn)練生成器 ,訓(xùn)練好的生成器能得到更好、更準(zhǔn)確的。那么無(wú)論在,還是在?階段,其生成器的輸入時(shí)一致的,不存在在?時(shí)用?作為?,而在預(yù)測(cè)的時(shí)候用。
備注:公眾號(hào)菜單包含了整理了一本AI小抄,非常適合在通勤路上用學(xué)習(xí)。
往期精彩回顧那些年做的學(xué)術(shù)公益-你不是一個(gè)人在戰(zhàn)斗適合初學(xué)者入門(mén)人工智能的路線及資料下載機(jī)器學(xué)習(xí)在線手冊(cè)深度學(xué)習(xí)在線手冊(cè)備注:加入本站微信群或者qq群,請(qǐng)回復(fù)“加群”加入知識(shí)星球(4500+用戶,ID:92416895),請(qǐng)回復(fù)“知識(shí)星球”喜歡文章,點(diǎn)個(gè)在看
與50位技術(shù)專(zhuān)家面對(duì)面20年技術(shù)見(jiàn)證,附贈(zèng)技術(shù)全景圖總結(jié)
以上是生活随笔為你收集整理的员外带你读论文:SeqGAN论文分享的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 爱可可推荐!关于竞赛思路,方法和代码实践
- 下一篇: 机器学习大佬的进阶之路!一位北大硕士毕业