Tensorflow生成自己的图片数据集TFrecords(支持多标签label)
Tensorflow生成自己的圖片數據集TFrecords
? ? ? ?尊重原創,轉載請注明出處:https://blog.csdn.net/guyuealian/article/details/80857228
? ? ? ?使用TensorFlow進行網絡訓練時,為了提高讀取數據的效率,一般建議將訓練數據轉換為TFrecords格式。為了方面調用,本博客提供一個可通用,已經封裝好的create_tf_record.py模塊,方便以后調用。
? ? ? 博客Github源碼:https://github.com/PanJinquan/tensorflow-learning-tutorials?->tf_record_demo文件夾(覺得可以,還請給個“Star”哦
目錄
Tensorflow生成自己的圖片數據集TFrecords
1.項目結構
2.生成自己的圖片數據集TFrecords
2.1 生成單個record文件 (單label)
2.2 生成單個record文件 (多label)
2.3 生成分割多個record文件?
3. 直接文件讀取方式
4.數據輸入管道:Pipeline機制
map
prefetch
repeat
完整代碼
5.參考資料:
1.項目結構
項目目錄結構如下所示:
其中train.txt保存圖片的路徑和標簽信息
dog/1.jpg 0 dog/2.jpg 0 dog/3.jpg 0 dog/4.jpg 0 cat/1.jpg 1 cat/2.jpg 1 cat/3.jpg 1 cat/4.jpg 12.生成自己的圖片數據集TFrecords
使用下面create_tf_record.py可以生成自己的圖片數據集TFrecords,完整代碼和解析如下:
2.1 生成單個record文件 (單label)
? ? ?下面是封裝好的py文件,可以直接生成單個record文件 ,當然這里假設只有一個label情況
# -*-coding: utf-8 -*- """@Project: create_tfrecord@File : create_tfrecord.py@Author : panjq@E-mail : pan_jinquan@163.com@Date : 2018-07-27 17:19:54@desc : 將圖片數據保存為單個tfrecord文件 """##########################################################################import tensorflow as tf import numpy as np import os import cv2 import matplotlib.pyplot as plt import random from PIL import Image########################################################################## def _int64_feature(value):return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) # 生成字符串型的屬性 def _bytes_feature(value):return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) # 生成實數型的屬性 def float_list_feature(value):return tf.train.Feature(float_list=tf.train.FloatList(value=value))def get_example_nums(tf_records_filenames):'''統計tf_records圖像的個數(example)個數:param tf_records_filenames: tf_records文件路徑:return:'''nums= 0for record in tf.python_io.tf_record_iterator(tf_records_filenames):nums += 1return numsdef show_image(title,image):'''顯示圖片:param title: 圖像標題:param image: 圖像的數據:return:'''# plt.figure("show_image")# print(image.dtype)plt.imshow(image)plt.axis('on') # 關掉坐標軸為 offplt.title(title) # 圖像題目plt.show()def load_labels_file(filename,labels_num=1,shuffle=False):'''載圖txt文件,文件中每行為一個圖片信息,且以空格隔開:圖像路徑 標簽1 標簽2,如:test_image/1.jpg 0 2:param filename::param labels_num :labels個數:param shuffle :是否打亂順序:return:images type->list:return:labels type->list'''images=[]labels=[]with open(filename) as f:lines_list=f.readlines()if shuffle:random.shuffle(lines_list)for lines in lines_list:line=lines.rstrip().split(' ')label=[]for i in range(labels_num):label.append(int(line[i+1]))images.append(line[0])labels.append(label)return images,labelsdef read_image(filename, resize_height, resize_width,normalization=False):'''讀取圖片數據,默認返回的是uint8,[0,255]:param filename::param resize_height::param resize_width::param normalization:是否歸一化到[0.,1.0]:return: 返回的圖片數據'''bgr_image = cv2.imread(filename)if len(bgr_image.shape)==2:#若是灰度圖則轉為三通道print("Warning:gray image",filename)bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)#將BGR轉為RGB# show_image(filename,rgb_image)# rgb_image=Image.open(filename)if resize_height>0 and resize_width>0:rgb_image=cv2.resize(rgb_image,(resize_width,resize_height))rgb_image=np.asanyarray(rgb_image)if normalization:# 不能寫成:rgb_image=rgb_image/255rgb_image=rgb_image/255.0# show_image("src resize image",image)return rgb_imagedef get_batch_images(images,labels,batch_size,labels_nums,one_hot=False,shuffle=False,num_threads=1):''':param images:圖像:param labels:標簽:param batch_size::param labels_nums:標簽個數:param one_hot:是否將labels轉為one_hot的形式:param shuffle:是否打亂順序,一般train時shuffle=True,驗證時shuffle=False:return:返回batch的images和labels'''min_after_dequeue = 200capacity = min_after_dequeue + 3 * batch_size # 保證capacity必須大于min_after_dequeue參數值if shuffle:images_batch, labels_batch = tf.train.shuffle_batch([images,labels],batch_size=batch_size,capacity=capacity,min_after_dequeue=min_after_dequeue,num_threads=num_threads)else:images_batch, labels_batch = tf.train.batch([images,labels],batch_size=batch_size,capacity=capacity,num_threads=num_threads)if one_hot:labels_batch = tf.one_hot(labels_batch, labels_nums, 1, 0)return images_batch,labels_batchdef read_records(filename,resize_height, resize_width,type=None):'''解析record文件:源文件的圖像數據是RGB,uint8,[0,255],一般作為訓練數據時,需要歸一化到[0,1]:param filename::param resize_height::param resize_width::param type:選擇圖像數據的返回類型None:默認將uint8-[0,255]轉為float32-[0,255]normalization:歸一化float32-[0,1]standardization:標準化float32-[0,1],再減均值中心化:return:'''# 創建文件隊列,不限讀取的數量filename_queue = tf.train.string_input_producer([filename])# create a reader from file queuereader = tf.TFRecordReader()# reader從文件隊列中讀入一個序列化的樣本_, serialized_example = reader.read(filename_queue)# get feature from serialized example# 解析符號化的樣本features = tf.parse_single_example(serialized_example,features={'image_raw': tf.FixedLenFeature([], tf.string),'height': tf.FixedLenFeature([], tf.int64),'width': tf.FixedLenFeature([], tf.int64),'depth': tf.FixedLenFeature([], tf.int64),'label': tf.FixedLenFeature([], tf.int64)})tf_image = tf.decode_raw(features['image_raw'], tf.uint8)#獲得圖像原始的數據tf_height = features['height']tf_width = features['width']tf_depth = features['depth']tf_label = tf.cast(features['label'], tf.int32)# PS:恢復原始圖像數據,reshape的大小必須與保存之前的圖像shape一致,否則出錯# tf_image=tf.reshape(tf_image, [-1]) # 轉換為行向量tf_image=tf.reshape(tf_image, [resize_height, resize_width, 3]) # 設置圖像的維度# 恢復數據后,才可以對圖像進行resize_images:輸入uint->輸出float32# tf_image=tf.image.resize_images(tf_image,[224, 224])# [3]數據類型處理# 存儲的圖像類型為uint8,tensorflow訓練時數據必須是tf.float32if type is None:tf_image = tf.cast(tf_image, tf.float32)elif type == 'normalization': # [1]若需要歸一化請使用:# 僅當輸入數據是uint8,才會歸一化[0,255]# tf_image = tf.cast(tf_image, dtype=tf.uint8)# tf_image = tf.image.convert_image_dtype(tf_image, tf.float32)tf_image = tf.cast(tf_image, tf.float32) * (1. / 255.0) # 歸一化elif type == 'standardization': # 標準化# tf_image = tf.cast(tf_image, dtype=tf.uint8)# tf_image = tf.image.per_image_standardization(tf_image) # 標準化(減均值除方差)# 若需要歸一化,且中心化,假設均值為0.5,請使用:tf_image = tf.cast(tf_image, tf.float32) * (1. / 255) - 0.5 # 中心化# 這里僅僅返回圖像和標簽# return tf_image, tf_height,tf_width,tf_depth,tf_labelreturn tf_image,tf_labeldef create_records(image_dir,file, output_record_dir, resize_height, resize_width,shuffle,log=5):'''實現將圖像原始數據,label,長,寬等信息保存為record文件注意:讀取的圖像數據默認是uint8,再轉為tf的字符串型BytesList保存,解析請需要根據需要轉換類型:param image_dir:原始圖像的目錄:param file:輸入保存圖片信息的txt文件(image_dir+file構成圖片的路徑):param output_record_dir:保存record文件的路徑:param resize_height::param resize_width:PS:當resize_height或者resize_width=0是,不執行resize:param shuffle:是否打亂順序:param log:log信息打印間隔'''# 加載文件,僅獲取一個labelimages_list, labels_list=load_labels_file(file,1,shuffle)writer = tf.python_io.TFRecordWriter(output_record_dir)for i, [image_name, labels] in enumerate(zip(images_list, labels_list)):image_path=os.path.join(image_dir,images_list[i])if not os.path.exists(image_path):print('Err:no image',image_path)continueimage = read_image(image_path, resize_height, resize_width)image_raw = image.tostring()if i%log==0 or i==len(images_list)-1:print('------------processing:%d-th------------' % (i))print('current image_path=%s' % (image_path),'shape:{}'.format(image.shape),'labels:{}'.format(labels))# 這里僅保存一個label,多label適當增加"'label': _int64_feature(label)"項label=labels[0]example = tf.train.Example(features=tf.train.Features(feature={'image_raw': _bytes_feature(image_raw),'height': _int64_feature(image.shape[0]),'width': _int64_feature(image.shape[1]),'depth': _int64_feature(image.shape[2]),'label': _int64_feature(label)}))writer.write(example.SerializeToString())writer.close()def disp_records(record_file,resize_height, resize_width,show_nums=4):'''解析record文件,并顯示show_nums張圖片,主要用于驗證生成record文件是否成功:param tfrecord_file: record文件路徑:return:'''# 讀取record函數tf_image, tf_label = read_records(record_file,resize_height,resize_width,type='normalization')# 顯示前4個圖片init_op = tf.initialize_all_variables()with tf.Session() as sess:sess.run(init_op)coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)for i in range(show_nums):image,label = sess.run([tf_image,tf_label]) # 在會話中取出image和label# image = tf_image.eval()# 直接從record解析的image是一個向量,需要reshape顯示# image = image.reshape([height,width,depth])print('shape:{},tpye:{},labels:{}'.format(image.shape,image.dtype,label))# pilimg = Image.fromarray(np.asarray(image_eval_reshape))# pilimg.show()show_image("image:%d"%(label),image)coord.request_stop()coord.join(threads)def batch_test(record_file,resize_height, resize_width):''':param record_file: record文件路徑:param resize_height::param resize_width::return::PS:image_batch, label_batch一般作為網絡的輸入'''# 讀取record函數tf_image,tf_label = read_records(record_file,resize_height,resize_width,type='normalization')image_batch, label_batch= get_batch_images(tf_image,tf_label,batch_size=4,labels_nums=5,one_hot=False,shuffle=False)init = tf.global_variables_initializer()with tf.Session() as sess: # 開始一個會話sess.run(init)coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(coord=coord)for i in range(4):# 在會話中取出images和labelsimages, labels = sess.run([image_batch, label_batch])# 這里僅顯示每個batch里第一張圖片show_image("image", images[0, :, :, :])print('shape:{},tpye:{},labels:{}'.format(images.shape,images.dtype,labels))# 停止所有線程coord.request_stop()coord.join(threads)if __name__ == '__main__':# 參數設置resize_height = 224 # 指定存儲圖片高度resize_width = 224 # 指定存儲圖片寬度shuffle=Truelog=5# 產生train.record文件image_dir='dataset/train'train_labels = 'dataset/train.txt' # 圖片路徑train_record_output = 'dataset/record/train.tfrecords'create_records(image_dir,train_labels, train_record_output, resize_height, resize_width,shuffle,log)train_nums=get_example_nums(train_record_output)print("save train example nums={}".format(train_nums))# 產生val.record文件image_dir='dataset/val'val_labels = 'dataset/val.txt' # 圖片路徑val_record_output = 'dataset/record/val.tfrecords'create_records(image_dir,val_labels, val_record_output, resize_height, resize_width,shuffle,log)val_nums=get_example_nums(val_record_output)print("save val example nums={}".format(val_nums))# 測試顯示函數# disp_records(train_record_output,resize_height, resize_width)batch_test(train_record_output,resize_height, resize_width)2.2 生成單個record文件 (多label)
? ? 對于多label的情況,你可以在單label的基礎上增加多個“label': tf.FixedLenFeature([], tf.int64)“,但每次label個數不一樣時,都需要修改,挺麻煩的。這里提供一個方法:label數據也可以像圖像數據那樣,轉為string類型來保存:labels_raw = np.asanyarray(labels,dtype=np.float32).tostring() ,讀取也跟圖像數據一樣:tf_label = tf.decode_raw(features['labels'],tf.float32) ,這樣,不管多少個label,我們都可以保存為record文件了:
? ?多label的TXT文件:
0.jpg 0.33 0.55 1.jpg 0.42 0.73 2.jpg 0.16 0.75 3.jpg 0.78 0.66 4.jpg 0.46 0.59 5.jpg 0.46 0.09 6.jpg 0.89 0.93 7.jpg 0.42 0.82 8.jpg 0.39 0.76 9.jpg 0.46 0.40 # -*-coding: utf-8 -*- """@Project: create_tfrecord@File : create_tf_record_multi_label.py@Author : panjq@E-mail : pan_jinquan@163.com@Date : 2018-07-27 17:19:54@desc : 將圖片數據,多label,保存為單個tfrecord文件 """##########################################################################import tensorflow as tf import numpy as np import os import cv2 import matplotlib.pyplot as plt import random from PIL import Image########################################################################## def _int64_feature(value):return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))def _float_feature(value):return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))# 生成字符串型的屬性 def _bytes_feature(value):return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) # 生成實數型的屬性 def float_list_feature(value):return tf.train.Feature(float_list=tf.train.FloatList(value=value))def get_example_nums(tf_records_filenames):'''統計tf_records圖像的個數(example)個數:param tf_records_filenames: tf_records文件路徑:return:'''nums= 0for record in tf.python_io.tf_record_iterator(tf_records_filenames):nums += 1return numsdef show_image(title,image):'''顯示圖片:param title: 圖像標題:param image: 圖像的數據:return:'''# plt.figure("show_image")# print(image.dtype)plt.imshow(image)plt.axis('on') # 關掉坐標軸為 offplt.title(title) # 圖像題目plt.show()def load_labels_file(filename,labels_num=1,shuffle=False):'''載圖txt文件,文件中每行為一個圖片信息,且以空格隔開:圖像路徑 標簽1 標簽2,如:test_image/1.jpg 0 2:param filename::param labels_num :labels個數:param shuffle :是否打亂順序:return:images type->list:return:labels type->list'''images=[]labels=[]with open(filename) as f:lines_list=f.readlines()if shuffle:random.shuffle(lines_list)for lines in lines_list:line=lines.rstrip().split(' ')label=[]for i in range(labels_num):label.append(float(line[i+1]))images.append(line[0])labels.append(label)return images,labelsdef read_image(filename, resize_height, resize_width,normalization=False):'''讀取圖片數據,默認返回的是uint8,[0,255]:param filename::param resize_height::param resize_width::param normalization:是否歸一化到[0.,1.0]:return: 返回的圖片數據'''bgr_image = cv2.imread(filename)if len(bgr_image.shape)==2:#若是灰度圖則轉為三通道print("Warning:gray image",filename)bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)#將BGR轉為RGB# show_image(filename,rgb_image)# rgb_image=Image.open(filename)if resize_height>0 and resize_width>0:rgb_image=cv2.resize(rgb_image,(resize_width,resize_height))rgb_image=np.asanyarray(rgb_image)if normalization:# 不能寫成:rgb_image=rgb_image/255rgb_image=rgb_image/255.0# show_image("src resize image",image)return rgb_imagedef get_batch_images(images,labels,batch_size,labels_nums,one_hot=False,shuffle=False,num_threads=1):''':param images:圖像:param labels:標簽:param batch_size::param labels_nums:標簽個數:param one_hot:是否將labels轉為one_hot的形式:param shuffle:是否打亂順序,一般train時shuffle=True,驗證時shuffle=False:return:返回batch的images和labels'''min_after_dequeue = 200capacity = min_after_dequeue + 3 * batch_size # 保證capacity必須大于min_after_dequeue參數值if shuffle:images_batch, labels_batch = tf.train.shuffle_batch([images,labels],batch_size=batch_size,capacity=capacity,min_after_dequeue=min_after_dequeue,num_threads=num_threads)else:images_batch, labels_batch = tf.train.batch([images,labels],batch_size=batch_size,capacity=capacity,num_threads=num_threads)if one_hot:labels_batch = tf.one_hot(labels_batch, labels_nums, 1, 0)return images_batch,labels_batchdef read_records(filename,resize_height, resize_width,type=None):'''解析record文件:源文件的圖像數據是RGB,uint8,[0,255],一般作為訓練數據時,需要歸一化到[0,1]:param filename::param resize_height::param resize_width::param type:選擇圖像數據的返回類型None:默認將uint8-[0,255]轉為float32-[0,255]normalization:歸一化float32-[0,1]standardization:歸一化float32-[0,1],再減均值中心化:return:'''# 創建文件隊列,不限讀取的數量filename_queue = tf.train.string_input_producer([filename])# create a reader from file queuereader = tf.TFRecordReader()# reader從文件隊列中讀入一個序列化的樣本_, serialized_example = reader.read(filename_queue)# get feature from serialized example# 解析符號化的樣本features = tf.parse_single_example(serialized_example,features={'image_raw': tf.FixedLenFeature([], tf.string),'height': tf.FixedLenFeature([], tf.int64),'width': tf.FixedLenFeature([], tf.int64),'depth': tf.FixedLenFeature([], tf.int64),'labels': tf.FixedLenFeature([], tf.string)})tf_image = tf.decode_raw(features['image_raw'], tf.uint8)#獲得圖像原始的數據tf_height = features['height']tf_width = features['width']tf_depth = features['depth']# tf_label = tf.cast(features['labels'], tf.float32)tf_label = tf.decode_raw(features['labels'],tf.float32)# PS:恢復原始圖像數據,reshape的大小必須與保存之前的圖像shape一致,否則出錯# tf_image=tf.reshape(tf_image, [-1]) # 轉換為行向量tf_image=tf.reshape(tf_image, [resize_height, resize_width, 3]) # 設置圖像的維度tf_label=tf.reshape(tf_label, [2]) # 設置圖像的維度# 恢復數據后,才可以對圖像進行resize_images:輸入uint->輸出float32# tf_image=tf.image.resize_images(tf_image,[224, 224])# [3]數據類型處理# 存儲的圖像類型為uint8,tensorflow訓練時數據必須是tf.float32if type is None:tf_image = tf.cast(tf_image, tf.float32)elif type == 'normalization': # [1]若需要歸一化請使用:# 僅當輸入數據是uint8,才會歸一化[0,255]# tf_image = tf.cast(tf_image, dtype=tf.uint8)# tf_image = tf.image.convert_image_dtype(tf_image, tf.float32)tf_image = tf.cast(tf_image, tf.float32) * (1. / 255.0) # 歸一化elif type == 'standardization': # 標準化# tf_image = tf.cast(tf_image, dtype=tf.uint8)# tf_image = tf.image.per_image_standardization(tf_image) # 標準化(減均值除方差)# 若需要歸一化,且中心化,假設均值為0.5,請使用:tf_image = tf.cast(tf_image, tf.float32) * (1. / 255) - 0.5 # 中心化# 這里僅僅返回圖像和標簽# return tf_image, tf_height,tf_width,tf_depth,tf_labelreturn tf_image,tf_labeldef create_records(image_dir,file, output_record_dir, resize_height, resize_width,shuffle,log=5):'''實現將圖像原始數據,label,長,寬等信息保存為record文件注意:讀取的圖像數據默認是uint8,再轉為tf的字符串型BytesList保存,解析請需要根據需要轉換類型:param image_dir:原始圖像的目錄:param file:輸入保存圖片信息的txt文件(image_dir+file構成圖片的路徑):param output_record_dir:保存record文件的路徑:param resize_height::param resize_width:PS:當resize_height或者resize_width=0是,不執行resize:param shuffle:是否打亂順序:param log:log信息打印間隔'''# 加載文件,僅獲取一個labellabels_num=2images_list, labels_list=load_labels_file(file,labels_num,shuffle)writer = tf.python_io.TFRecordWriter(output_record_dir)for i, [image_name, labels] in enumerate(zip(images_list, labels_list)):image_path=os.path.join(image_dir,images_list[i])if not os.path.exists(image_path):print('Err:no image',image_path)continueimage = read_image(image_path, resize_height, resize_width)image_raw = image.tostring()if i%log==0 or i==len(images_list)-1:print('------------processing:%d-th------------' % (i))print('current image_path=%s' % (image_path),'shape:{}'.format(image.shape),'labels:{}'.format(labels))# 這里僅保存一個label,多label適當增加"'label': _int64_feature(label)"項# label=labels[0]# labels_raw="0.12,0,15"labels_raw = np.asanyarray(labels,dtype=np.float32).tostring()example = tf.train.Example(features=tf.train.Features(feature={'image_raw': _bytes_feature(image_raw),'height': _int64_feature(image.shape[0]),'width': _int64_feature(image.shape[1]),'depth': _int64_feature(image.shape[2]),'labels': _bytes_feature(labels_raw),}))writer.write(example.SerializeToString())writer.close()def disp_records(record_file,resize_height, resize_width,show_nums=4):'''解析record文件,并顯示show_nums張圖片,主要用于驗證生成record文件是否成功:param tfrecord_file: record文件路徑:return:'''# 讀取record函數tf_image, tf_label = read_records(record_file,resize_height,resize_width,type='normalization')# 顯示前4個圖片init_op = tf.initialize_all_variables()with tf.Session() as sess:sess.run(init_op)coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)for i in range(show_nums):image,label = sess.run([tf_image,tf_label]) # 在會話中取出image和label# image = tf_image.eval()# 直接從record解析的image是一個向量,需要reshape顯示# image = image.reshape([height,width,depth])print('shape:{},tpye:{},labels:{}'.format(image.shape,image.dtype,label))# pilimg = Image.fromarray(np.asarray(image_eval_reshape))# pilimg.show()show_image("image:{}".format(label),image)coord.request_stop()coord.join(threads)def batch_test(record_file,resize_height, resize_width):''':param record_file: record文件路徑:param resize_height::param resize_width::return::PS:image_batch, label_batch一般作為網絡的輸入'''# 讀取record函數tf_image,tf_label = read_records(record_file,resize_height,resize_width,type='normalization')image_batch, label_batch= get_batch_images(tf_image,tf_label,batch_size=4,labels_nums=2,one_hot=False,shuffle=True)init = tf.global_variables_initializer()with tf.Session() as sess: # 開始一個會話sess.run(init)coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(coord=coord)for i in range(4):# 在會話中取出images和labelsimages, labels = sess.run([image_batch, label_batch])# 這里僅顯示每個batch里第一張圖片show_image("image", images[0, :, :, :])print('shape:{},tpye:{},labels:{}'.format(images.shape,images.dtype,labels))# 停止所有線程coord.request_stop()coord.join(threads)if __name__ == '__main__':# 參數設置resize_height = 224 # 指定存儲圖片高度resize_width = 224 # 指定存儲圖片寬度shuffle=Truelog=1000# 產生train.record文件image_dir='dataset_regression/images'train_labels = 'dataset_regression/train.txt' # 圖片路徑train_record_output = 'dataset_regression/record/train.tfrecords'create_records(image_dir,train_labels, train_record_output, resize_height, resize_width,shuffle,log)train_nums=get_example_nums(train_record_output)print("save train example nums={}".format(train_nums))# 測試顯示函數# disp_records(train_record_output,resize_height, resize_width)# 產生val.record文件image_dir='dataset_regression/images'val_labels = 'dataset_regression/val.txt' # 圖片路徑val_record_output = 'dataset_regression/record/val.tfrecords'create_records(image_dir,val_labels, val_record_output, resize_height, resize_width,shuffle,log)val_nums=get_example_nums(val_record_output)print("save val example nums={}".format(val_nums))## # 測試顯示函數# # disp_records(train_record_output,resize_height, resize_width)# batch_test(val_record_output,resize_height, resize_width)2.3 生成分割多個record文件?
? ? ? 上述該代碼只保存為單個record文件,當圖片數據很多時候,會導致單個record文件超級巨大的情況,解決方法就是,將數據分成多個record文件保存,讀取時,只需要將多個record文件的路徑列表交給“tf.train.string_input_producer”,完整代碼如下:
# -*-coding: utf-8 -*- """@Project: tf_record_demo@File : tf_record_batchSize.py@Author : panjq@E-mail : pan_jinquan@163.com@Date : 2018-07-27 17:19:54@desc : 將圖片數據保存為多個record文件 """##########################################################################import tensorflow as tf import numpy as np import os import cv2 import math import matplotlib.pyplot as plt import random from PIL import Image########################################################################## def _int64_feature(value):return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) # 生成字符串型的屬性 def _bytes_feature(value):return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) # 生成實數型的屬性 def float_list_feature(value):return tf.train.Feature(float_list=tf.train.FloatList(value=value))def show_image(title,image):'''顯示圖片:param title: 圖像標題:param image: 圖像的數據:return:'''# plt.figure("show_image")# print(image.dtype)plt.imshow(image)plt.axis('on') # 關掉坐標軸為 offplt.title(title) # 圖像題目plt.show()def load_labels_file(filename,labels_num=1):'''載圖txt文件,文件中每行為一個圖片信息,且以空格隔開:圖像路徑 標簽1 標簽2,如:test_image/1.jpg 0 2:param filename::param labels_num :labels個數:return:images type->list:return:labels type->list'''images=[]labels=[]with open(filename) as f:for lines in f.readlines():line=lines.rstrip().split(' ')label=[]for i in range(labels_num):label.append(int(line[i+1]))images.append(line[0])labels.append(label)return images,labelsdef read_image(filename, resize_height, resize_width):'''讀取圖片數據,默認返回的是uint8,[0,255]:param filename::param resize_height::param resize_width::return: 返回的圖片數據是uint8,[0,255]'''bgr_image = cv2.imread(filename)if len(bgr_image.shape)==2:#若是灰度圖則轉為三通道print("Warning:gray image",filename)bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)#將BGR轉為RGB# show_image(filename,rgb_image)# rgb_image=Image.open(filename)if resize_height>0 and resize_width>0:rgb_image=cv2.resize(rgb_image,(resize_width,resize_height))rgb_image=np.asanyarray(rgb_image)# show_image("src resize image",image)return rgb_imagedef create_records(image_dir,file, record_txt_path, batchSize,resize_height, resize_width):'''實現將圖像原始數據,label,長,寬等信息保存為record文件注意:讀取的圖像數據默認是uint8,再轉為tf的字符串型BytesList保存,解析請需要根據需要轉換類型:param image_dir:原始圖像的目錄:param file:輸入保存圖片信息的txt文件(image_dir+file構成圖片的路徑):param output_record_txt_dir:保存record文件的路徑:param batchSize: 每batchSize個圖片保存一個*.tfrecords,避免單個文件過大:param resize_height::param resize_width:PS:當resize_height或者resize_width=0是,不執行resize'''if os.path.exists(record_txt_path):os.remove(record_txt_path)setname, ext = record_txt_path.split('.')# 加載文件,僅獲取一個labelimages_list, labels_list=load_labels_file(file,1)sample_num = len(images_list)# 打亂樣本的數據# random.shuffle(labels_list)batchNum = int(math.ceil(1.0 * sample_num / batchSize))for i in range(batchNum):start = i * batchSizeend = min((i + 1) * batchSize, sample_num)batch_images = images_list[start:end]batch_labels = labels_list[start:end]# 逐個保存*.tfrecords文件filename = setname + '{0}.tfrecords'.format(i)print('save:%s' % (filename))writer = tf.python_io.TFRecordWriter(filename)for i, [image_name, labels] in enumerate(zip(batch_images, batch_labels)):image_path=os.path.join(image_dir,batch_images[i])if not os.path.exists(image_path):print('Err:no image',image_path)continueimage = read_image(image_path, resize_height, resize_width)image_raw = image.tostring()print('image_path=%s,shape:( %d, %d, %d)' % (image_path,image.shape[0], image.shape[1], image.shape[2]),'labels:',labels)# 這里僅保存一個label,多label適當增加"'label': _int64_feature(label)"項label=labels[0]example = tf.train.Example(features=tf.train.Features(feature={'image_raw': _bytes_feature(image_raw),'height': _int64_feature(image.shape[0]),'width': _int64_feature(image.shape[1]),'depth': _int64_feature(image.shape[2]),'label': _int64_feature(label)}))writer.write(example.SerializeToString())writer.close()# 用txt保存*.tfrecords文件列表# record_list='{}.txt'.format(setname)with open(record_txt_path, 'a') as f:f.write(filename + '\n')def read_records(filename,resize_height, resize_width):'''解析record文件:param filename:保存*.tfrecords文件的txt文件路徑:return:'''# 讀取txt中所有*.tfrecords文件with open(filename, 'r') as f:lines = f.readlines()files_list=[]for line in lines:files_list.append(line.rstrip())# 創建文件隊列,不限讀取的數量filename_queue = tf.train.string_input_producer(files_list,shuffle=False)# create a reader from file queuereader = tf.TFRecordReader()# reader從文件隊列中讀入一個序列化的樣本_, serialized_example = reader.read(filename_queue)# get feature from serialized example# 解析符號化的樣本features = tf.parse_single_example(serialized_example,features={'image_raw': tf.FixedLenFeature([], tf.string),'height': tf.FixedLenFeature([], tf.int64),'width': tf.FixedLenFeature([], tf.int64),'depth': tf.FixedLenFeature([], tf.int64),'label': tf.FixedLenFeature([], tf.int64)})tf_image = tf.decode_raw(features['image_raw'], tf.uint8)#獲得圖像原始的數據tf_height = features['height']tf_width = features['width']tf_depth = features['depth']tf_label = tf.cast(features['label'], tf.int32)# tf_image=tf.reshape(tf_image, [-1]) # 轉換為行向量tf_image=tf.reshape(tf_image, [resize_height, resize_width, 3]) # 設置圖像的維度# 存儲的圖像類型為uint8,這里需要將類型轉為tf.float32# tf_image = tf.cast(tf_image, tf.float32)# [1]若需要歸一化請使用:tf_image = tf.image.convert_image_dtype(tf_image, tf.float32)# 歸一化# tf_image = tf.cast(tf_image, tf.float32) * (1. / 255) # 歸一化# [2]若需要歸一化,且中心化,假設均值為0.5,請使用:# tf_image = tf.cast(tf_image, tf.float32) * (1. / 255) - 0.5 #中心化return tf_image, tf_height,tf_width,tf_depth,tf_labeldef disp_records(record_file,resize_height, resize_width,show_nums=4):'''解析record文件,并顯示show_nums張圖片,主要用于驗證生成record文件是否成功:param tfrecord_file: record文件路徑:param resize_height::param resize_width::param show_nums: 默認顯示前四張照片:return:'''tf_image, tf_height, tf_width, tf_depth, tf_label = read_records(record_file,resize_height, resize_width) # 讀取函數# 顯示前show_nums個圖片init_op = tf.initialize_all_variables()with tf.Session() as sess:sess.run(init_op)coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)for i in range(show_nums):image,height,width,depth,label = sess.run([tf_image,tf_height,tf_width,tf_depth,tf_label]) # 在會話中取出image和label# image = tf_image.eval()# 直接從record解析的image是一個向量,需要reshape顯示# image = image.reshape([height,width,depth])print('shape:',image.shape,'label:',label)# pilimg = Image.fromarray(np.asarray(image_eval_reshape))# pilimg.show()show_image("image:%d"%(label),image)coord.request_stop()coord.join(threads)def batch_test(record_file,resize_height, resize_width):''':param record_file: record文件路徑:param resize_height::param resize_width::return::PS:image_batch, label_batch一般作為網絡的輸入'''tf_image,tf_height,tf_width,tf_depth,tf_label = read_records(record_file,resize_height, resize_width) # 讀取函數# 使用shuffle_batch可以隨機打亂輸入:# shuffle_batch用法:https://blog.csdn.net/u013555719/article/details/77679964min_after_dequeue = 100#該值越大,數據越亂,必須小于capacitybatch_size = 4# capacity = (min_after_dequeue + (num_threads + a small safety margin?batchsize)capacity = min_after_dequeue + 3 * batch_size#容量:一個整數,隊列中的最大的元素數image_batch, label_batch = tf.train.shuffle_batch([tf_image, tf_label],batch_size=batch_size,capacity=capacity,min_after_dequeue=min_after_dequeue)init = tf.global_variables_initializer()with tf.Session() as sess: # 開始一個會話sess.run(init)coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(coord=coord)for i in range(4):# 在會話中取出images和labelsimages, labels = sess.run([image_batch, label_batch])# 這里僅顯示每個batch里第一張圖片show_image("image", images[0, :, :, :])print(images.shape, labels)# 停止所有線程coord.request_stop()coord.join(threads)if __name__ == '__main__':# 參數設置image_dir='dataset/train'train_file = 'dataset/train.txt' # 圖片路徑output_record_txt = 'dataset/record/record.txt'#指定保存record的文件列表resize_height = 224 # 指定存儲圖片高度resize_width = 224 # 指定存儲圖片寬度batchSize=8000 #batchSize一般設置為8000,即每batchSize張照片保存為一個record文件# 產生record文件create_records(image_dir=image_dir,file=train_file,record_txt_path=output_record_txt,batchSize=batchSize,resize_height=resize_height,resize_width=resize_width)# 測試顯示函數disp_records(output_record_txt,resize_height, resize_width)# batch_test(output_record_txt,resize_height, resize_width)3. 直接文件讀取方式
? ? 上面的都是將數據轉存為tfrecord文件,訓練時再讀取,如果不想轉為record文件,想直接讀取圖像文件進行訓練,可以使用下面的方法:
? ? filename.txt
0.jpg 0
 1.jpg 0
 2.jpg 0
 3.jpg 0
 4.jpg 0
 5.jpg 1
 6.jpg 1
 7.jpg 1
 8.jpg 1
 9.jpg 1
4.數據輸入管道:Pipeline機制
? ? TensorFlow引入了tf.data.Dataset模塊,使其數據讀入的操作變得更為方便,而支持多線程(進程)的操作,也在效率上獲得了一定程度的提高。使用tf.data.Dataset模塊的pipline機制,可實現CPU多線程處理輸入的數據,如讀取圖片和圖片的一些的預處理,這樣GPU可以專注于訓練過程,而CPU去準備數據。
? ? 參考資料:
https://blog.csdn.net/u014061630/article/details/80776975
(五星推薦)TensorFlow全新的數據讀取方式:Dataset API入門教程:http://baijiahao.baidu.com/s?id=1583657817436843385&wfr=spider&for=pc
? ? Dataset支持一類特殊的操作:Transformation。一個Dataset通過Transformation變成一個新的Dataset。通常我們可以通過Transformation完成數據變換,打亂,組成batch,生成epoch等一系列操作。常用的Transformation有:map、batch、shuffle和repeat。
下面就分別進行介紹。
map
? ? 使用?tf.data.Dataset.map,我們可以很方便地對數據集中的各個元素進行預處理。因為輸入元素之間時獨立的,所以可以在多個 CPU 核心上并行地進行預處理。map?變換提供了一個?num_parallel_calls參數去指定并行的級別。
dataset = dataset.map(map_func=parse_fn, num_parallel_calls=FLAGS.num_parallel_calls)prefetch
? ? tf.data.Dataset.prefetch 提供了 software pipelining 機制。該函數解耦了 數據產生的時間 和 數據消耗的時間。具體來說,該函數有一個后臺線程和一個內部緩存區,在數據被請求前,就從 dataset 中預加載一些數據(進一步提高性能)。prefech(n) 一般作為最后一個 transformation,其中 n 為 batch_size。?prefetch 的使用方法如下:
dataset = dataset.batch(batch_size=FLAGS.batch_size) dataset = dataset.prefetch(buffer_size=FLAGS.prefetch_buffer_size) # last transformation return datasetrepeat
? ? repeat的功能就是將整個序列重復多次,主要用來處理機器學習中的epoch,假設原先的數據是一個epoch,使用repeat(5)就可以將之變成5個epoch:
? ? 如果直接調用repeat()的話,生成的序列就會無限重復下去,沒有結束,因此也不會拋出tf.errors.OutOfRangeError異常
完整代碼
# -*-coding: utf-8 -*- """@Project: fine tuning@File : pipeline.py@Author : panjq@E-mail : pan_jinquan@163.com@Date : 2018-11-17 20:18:54 """ import tensorflow as tf import numpy as np import glob import matplotlib.pyplot as pltwidth=0 height=0 def show_image(title, image):'''顯示圖片:param title: 圖像標題:param image: 圖像的數據:return:'''# plt.figure("show_image")# print(image.dtype)plt.imshow(image)plt.axis('on') # 關掉坐標軸為 offplt.title(title) # 圖像題目plt.show()def tf_read_image(filename, label):image_string = tf.read_file(filename)image_decoded = tf.image.decode_jpeg(image_string, channels=3)image = tf.cast(image_decoded, tf.float32)if width>0 and height>0:image = tf.image.resize_images(image, [height, width])image = tf.cast(image, tf.float32) * (1. / 255.0) # 歸一化return image, labeldef input_fun(files_list, labels_list, batch_size, shuffle=True):''':param files_list::param labels_list::param batch_size::param shuffle::return:'''# 構建數據集dataset = tf.data.Dataset.from_tensor_slices((files_list, labels_list))if shuffle:dataset = dataset.shuffle(100)dataset = dataset.repeat() # 空為無限循環dataset = dataset.map(tf_read_image, num_parallel_calls=4) # num_parallel_calls一般設置為cpu內核數量dataset = dataset.batch(batch_size)dataset = dataset.prefetch(2) # software pipelining 機制return datasetif __name__ == '__main__':data_dir = 'dataset/image/*.jpg'# labels_list = tf.constant([0,1,2,3,4])# labels_list = [1, 2, 3, 4, 5]files_list = glob.glob(data_dir)labels_list = np.arange(len(files_list))num_sample = len(files_list)batch_size = 1dataset = input_fun(files_list, labels_list, batch_size=batch_size, shuffle=False)# 需滿足:max_iterate*batch_size <=num_sample*num_epoch,否則越界max_iterate = 3with tf.Session() as sess:iterator = dataset.make_initializable_iterator()init_op = iterator.make_initializer(dataset)sess.run(init_op)iterator = iterator.get_next()for i in range(max_iterate):images, labels = sess.run(iterator)show_image("image", images[0, :, :, :])print('shape:{},tpye:{},labels:{}'.format(images.shape, images.dtype, labels))?
5.參考資料:
[1]https://blog.csdn.net/happyhorizion/article/details/77894055? (五星推薦)
[2]https://blog.csdn.net/ywx1832990/article/details/78462582
[3]https://blog.csdn.net/csuzhaoqinghui/article/details/51377941
?
?
總結
以上是生活随笔為你收集整理的Tensorflow生成自己的图片数据集TFrecords(支持多标签label)的全部內容,希望文章能夠幫你解決所遇到的問題。
 
                            
                        - 上一篇: 解决Ubuntu17.04不能安装git
- 下一篇: Python常用的模块的使用技巧
