解决tensorflow在训练的时候权重是nan问题
搭建普通的卷積CNN網(wǎng)絡(luò)。
nan表示的是無窮或者是非數(shù)值,比如說你在tensorflow中使用一個(gè)數(shù)除以0,那么得到的結(jié)果就是nan。
在一個(gè)matrix中,如果其中的值都為nan很有可能是因?yàn)椴捎玫腸ost function不合理導(dǎo)致的。
?
當(dāng)使用tensorflow構(gòu)建一個(gè)最簡單的神經(jīng)網(wǎng)絡(luò)的時(shí)候,按照tensorflow官方給出的教程:
https://www.tensorflow.org/get_started/mnist/beginners
http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_beginners.html ?(中文教程)
?
具體的含義就不解釋了。大概分為三個(gè)部分:1,導(dǎo)入數(shù)據(jù)集;2,搭建模型,并且定義cost function(也叫l(wèi)oss function);3,訓(xùn)練。
對(duì)于過程1,我們采用的不是mnist數(shù)據(jù)集,而是自己定義了一個(gè)數(shù)據(jù)集,其中
對(duì)于過程2,我們使用最簡單的CNN網(wǎng)絡(luò),然后定義cost function的方式是:
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
對(duì)于過程3,我們也采用教程中的例子去訓(xùn)練。
?
但是在初始化W后就立刻查看W參數(shù)的結(jié)果,得到的結(jié)果都是nan,以下是輸出W權(quán)重后的結(jié)果:
這個(gè)現(xiàn)象是由于cost function引起的:
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
上面的語句中的y_是數(shù)據(jù)集的label。我們做的是顯著性檢測,就是數(shù)據(jù)集的ground truth。
并且這個(gè)label或者ground truth一定要是one hot類型的變量。
那什么是one hot類型的變量呢?
舉一個(gè)例子:比如一個(gè)5個(gè)類的數(shù)據(jù)集,用0,1,2,3,4來表示5個(gè)類的標(biāo)簽,因此label=0,1,2,3,4。這時(shí)候有的人會(huì)把y_=0,1,2,3,4。直接輸入到cost function——-tf.reduce_sum(y_*tf.log(y))中,那么這樣會(huì)導(dǎo)致W參數(shù)初始化都是nan。
解決辦法就是我們把label=0,1,2,3,4變?yōu)閛ne hot變量,改變后的結(jié)果是:label=[1,0,0,0,0],[0,1,0,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,0,0,1],這樣再輸入到tf.reduce_sum(y_*tf.log(y))中,就是正確的了,如下圖,我們采用的解決辦法是第二種,具體參考下文。
?
那么本文提供兩種方法來解決這個(gè)問題:
1,將y_從原來的類別數(shù)字變?yōu)閛ne hot變量,使用
labels = tf.reshape(labels, [batch_size, 1]) indices = tf.reshape(tf.range(0, batch_size, 1), [batch_size, 1]) labels = tf.sparse_to_dense(tf.concat(values=[indices, labels], axis=1),[batch_size, num_classes], 1.0, 0.0) 將label轉(zhuǎn)為one hot(batch_size是你每次抓取的訓(xùn)練集的個(gè)數(shù)) 2,換一個(gè)cost function,原來的cost function = -tf.reduce_sum(y_*tf.log(y)) 使用的是交叉熵函數(shù),現(xiàn)在我們換成二次代價(jià)函數(shù) cost function = tf.reduce_sum(tf.square(tf.substract(y_,y)))?
轉(zhuǎn)載于:https://www.cnblogs.com/sddai/p/8526108.html
總結(jié)
以上是生活随笔為你收集整理的解决tensorflow在训练的时候权重是nan问题的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 认识零信任安全网络架构
- 下一篇: poj3050 穷竭搜索 挑战程序设计竞