TFRecords转化和读取
標(biāo)準(zhǔn)TensorFlow格式
TensorFlow的訓(xùn)練過(guò)程其實(shí)就是大量的數(shù)據(jù)在網(wǎng)絡(luò)中不斷流動(dòng)的過(guò)程,而數(shù)據(jù)的來(lái)源在官方文檔[^1](API r1.2)中介紹了三種方式,分別是:
- Feeding。通過(guò)Python直接注入數(shù)據(jù)。
- Reading from files。從文件讀取數(shù)據(jù),本文中的TFRecord屬于此類(lèi)方式。
- Preloaded data。將數(shù)據(jù)以constant或者variable的方式直接存儲(chǔ)在運(yùn)算圖中。
當(dāng)數(shù)據(jù)量較大時(shí),官方推薦采用標(biāo)準(zhǔn)TensorFlow格式[^2](Standard TensorFlow format)來(lái)存儲(chǔ)訓(xùn)練與驗(yàn)證數(shù)據(jù),該格式的后綴名為tfrecord。官方介紹如下:
A TFRecords file represents a sequence of (binary) strings. The format is not random access, so it is suitable for streaming large amounts of data but not suitable if fast sharding or other non-sequential access is desired.
從介紹不難看出,TFRecord文件適用于大量數(shù)據(jù)的順序讀取。而這正好是神經(jīng)網(wǎng)絡(luò)在訓(xùn)練過(guò)程中發(fā)生的事情。
如何使用TFRecord文件
對(duì)于TFRecord文件的使用,官方給出了兩份示例代碼,分別展示了如何生成與讀取該格式的文件。
生成TFRecord文件
第一份代碼convert_to_records.py?[^3]將MNIST里的圖像數(shù)據(jù)轉(zhuǎn)換為了TFRecord格式 。仔細(xì)研讀代碼,可以發(fā)現(xiàn)TFRecord文件中的圖像數(shù)據(jù)存儲(chǔ)在Feature下的image_raw里。image_raw來(lái)自于data_set.images,而后者又來(lái)自mnist.read_data_sets()。因此images的真身藏在mnist.py這個(gè)文件里。
mnist.py并不難找,在Pycharm里按下ctrl后單擊鼠標(biāo)左鍵即可打開(kāi)源代碼。
繼續(xù)追蹤,可以在mnist里發(fā)現(xiàn)圖像來(lái)自extract_images()函數(shù)。該函數(shù)的說(shuō)明里清晰的寫(xiě)明:
Extract the images into a 4D uint8 numpy array [index, y, x, depth].Args:f: A file object that can be passed into a gzip reader.Returns:data: A 4D uint8 numpy array [index, y, x, depth].Raises:ValueError: If the bytestream does not start with 2051.很明顯,返回值變量名為data,是一個(gè)4D Numpy矩陣,存儲(chǔ)值為uint8類(lèi)型,即圖像像素的灰度值(MNIST全部為灰度圖像)。四個(gè)維度分別代表了:圖像的個(gè)數(shù),每個(gè)圖像行數(shù),每個(gè)圖像列數(shù),每個(gè)圖像通道數(shù)。
在獲得這個(gè)存儲(chǔ)著像素灰度值的Numpy矩陣后,使用numpy的tostring()函數(shù)將其轉(zhuǎn)換為Python bytes格式[^4],再使用tf.train.BytesList()函數(shù)封裝為tf.train.BytesList類(lèi),名字為image_raw。最后使用tf.train.Example()將image_raw和其它屬性一遍打包,并調(diào)用tf.python_io.TFRecordWriter將其寫(xiě)入到文件中。
至此,TFRecord文件生成完畢。
可見(jiàn),將自定義圖像轉(zhuǎn)換為T(mén)FRecord的過(guò)程本質(zhì)上是將大量圖像的像素灰度值轉(zhuǎn)換為Python bytes,并與其它Feature組合在一起,最終拼接成一個(gè)文件的過(guò)程。
需要注意的是其它Feature的類(lèi)型不一定必須是BytesList,還可以是Int64List或者FloatList。
讀取TFRecord文件
第二份代碼fully_connected_reader.py?[1]展示了如何從TFRecord文件中讀取數(shù)據(jù)。
讀取數(shù)據(jù)的函數(shù)名為input()。函數(shù)內(nèi)部首先通過(guò)tf.train.string_input_producer()函數(shù)讀取TFRecord文件,并返回一個(gè)queue;然后使用read_and_decode()讀取一份數(shù)據(jù),函數(shù)內(nèi)部用tf.decode_raw()解析出圖像的灰度值,用tf.cast()解析出label的值。之后通過(guò)tf.train.shuffle_batch()的方法生成一批用來(lái)訓(xùn)練的數(shù)據(jù)。并最終返回可供訓(xùn)練的images和labels,并送入inference部分進(jìn)行計(jì)算。
在這個(gè)過(guò)程中,有以下幾點(diǎn)需要留意:
其中第2點(diǎn)的原理我暫時(shí)沒(méi)有弄懂。從代碼上看read_and_decode()返回的是單個(gè)數(shù)據(jù),shuffle_batch接收到的也是單個(gè)數(shù)據(jù),不知道是如何生成批量數(shù)據(jù)的,猜測(cè)與queue有關(guān)系。
所以,讀取TFRecord文件的本質(zhì),就是通過(guò)隊(duì)列的方式依次將數(shù)據(jù)解碼,并按需要進(jìn)行數(shù)據(jù)隨機(jī)化、圖像隨機(jī)化的過(guò)程。
參考
Github: fully_connected_reader.py???
轉(zhuǎn)載于:https://www.cnblogs.com/jyxbk/p/7895313.html
總結(jié)
以上是生活随笔為你收集整理的TFRecords转化和读取的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 软件开发中IT用语-日文和英文对照版
- 下一篇: 《JavaScript高级程序设计》阅读