Tensorflow学习教程------tfrecords数据格式生成与读取
生活随笔
收集整理的這篇文章主要介紹了
Tensorflow学习教程------tfrecords数据格式生成与读取
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
首先是生成tfrecords格式的數據,具體代碼如下:
#coding:utf-8import os import tensorflow as tf from PIL import Imagecwd = os.getcwd() ''' 此處我加載的數據目錄如下: bt -- 14018.jpg14019.jpg14020.jpgnbt -- 1_ddd.jpg1_dsdfs.jpg1_dfd.jpg這里的bt nbt 就是類別,也就是代碼中的classes '''writer = tf.python_io.TFRecordWriter("train.tfrecords") classes = ['bt','nbt'] for index, name in enumerate(classes):class_path = cwd + '/'+ name +'/' #每一類圖片的目錄地址for img_name in os.listdir(class_path):img_path = class_path + img_name #每一張圖片的路徑img = Image.open(img_path)img = img.resize((224,224)) img_raw = img.tobytes() #將圖片轉化為原生bytesexample = tf.train.Example(features = tf.train.Features(feature={'label':tf.train.Feature(int64_list = tf.train.Int64List(value=[index])),'img_raw':tf.train.Feature(bytes_list = tf.train.BytesList(value=[img_raw]))}))print "write" + ' ' + str(img_path) + "to train.tfrecords."writer.write(example.SerializeToString()) #序列化為字符串 writer.close()然后讀取生成的tfrecords數據,并且將tfrecords里面的數據保存成jpg格式的圖片。具體代碼如下:
#coding:utf-8 import os import tensorflow as tf from PIL import Image cwd = '/media/project/tfLearnning/dataread/pic/' def read_and_decode(filename):#根據文件名生成一個隊列filename_queue = tf.train.string_input_producer([filename])reader = tf.TFRecordReader()_, serialized_example = reader.read(filename_queue) #返回文件名和文件 features = tf.parse_single_example(serialized_example,features={'label':tf.FixedLenFeature([],tf.int64),'img_raw':tf.FixedLenFeature([],tf.string),})img = tf.decode_raw(features['img_raw'],tf.uint8)img = tf.reshape(img,[224,224,3])#img = tf.cast(img,tf.float32) * (1./255) - 0.5 # 將圖片變成tensor#對圖片進行歸一化操作將【0,255】之間的像素歸一化到【-0.5,0.5】,標準化處理可以使得不同的特征具有相同的尺度(Scale)。#這樣,在使用梯度下降法學習參數的時候,不同特征對參數的影響程度就一樣了label = tf.cast(features['label'], tf.int32) #將標簽轉化tensorprint imgprint labelreturn img, label#read_and_decode('train.tfrecords') img, label = read_and_decode('train.tfrecords') #print img.shape, label img_batch, label_batch = tf.train.shuffle_batch([img,label],batch_size=10,capacity=2000,min_after_dequeue=1000) #形成一個batch的數據,由于使用shuffle,因此每次取batch的時候#都是隨機取的,可以使樣本盡可能被充分地訓練,保證min_after值小于capacit值 init = tf.global_variables_initializer()with tf.Session() as sess:sess.run(init)# 創建一個協調器,管理線程coord = tf.train.Coordinator()# 啟動QueueRunner, 此時文件名隊列已經進隊threads = tf.train.start_queue_runners(sess=sess, coord=coord)for i in range(10):example, l = sess.run([img, label]) #從對列中一張一張讀取圖片和標簽#example, l = sess.run([img_batch,label_batch])print(example.shape,l)img1=Image.fromarray(example, 'RGB') #將tensor轉化成圖片格式img1.save(cwd+str(i)+'_'+'Label_'+str(l)+'.jpg')#save image# 通知其他線程關閉 coord.request_stop()# 其他所有線程關閉之后,這一函數才能返回coord.join(threads)?
轉載于:https://www.cnblogs.com/cnugis/p/8393807.html
總結
以上是生活随笔為你收集整理的Tensorflow学习教程------tfrecords数据格式生成与读取的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 5nm Zen 4来了!AMD官宣:锐龙
- 下一篇: 红旗首款MPV HQ9开售:40万买国产