Unet项目解析(5): 数据封装、数据加载、数据显示
生活随笔
收集整理的這篇文章主要介紹了
Unet项目解析(5): 数据封装、数据加载、数据显示
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
項目GitHub主頁:https://github.com/orobix/retina-unet
參考論文:Retina blood vessel segmentation with a convolution neural network (U-net)
1.數據封裝成HDF5格式
import os
import sys

import h5py
import numpy as np
from PIL import Image


def write_hdf5(arr, outfile):
    """Save numpy array `arr` into `outfile` as an HDF5 dataset keyed "image"."""
    with h5py.File(outfile, "w") as f:
        f.create_dataset("image", data=arr, dtype=arr.dtype)


# Training data locations: images / ground truth / border masks
original_imgs_train = "./DRIVE/training/images/"
groundTruth_imgs_train = "./DRIVE/training/1st_manual/"
borderMasks_imgs_train = "./DRIVE/training/mask/"
# Test data locations: images / ground truth / border masks
original_imgs_test = "./DRIVE/test/images/"
groundTruth_imgs_test = "./DRIVE/test/1st_manual/"
borderMasks_imgs_test = "./DRIVE/test/mask/"
# Where the packaged HDF5 datasets are written
dataset_path = "./datasets_training_testing/"

# DRIVE dataset dimensions: 20 images per split, RGB, 584x565 pixels
Nimgs = 20
channels = 3
height = 584
width = 565


def get_datasets(imgs_dir, groundTruth_dir, borderMasks_dir, train_test="null"):
    """Load one DRIVE split (images, ground truth, border masks) into arrays.

    imgs_dir / groundTruth_dir / borderMasks_dir: directories of the split.
    train_test: "train" or "test" — selects the mask filename suffix;
        any other value aborts the program.

    Returns (imgs, groundTruth, border_masks) with shapes
    [Nimgs, channels, height, width], [Nimgs, 1, height, width],
    [Nimgs, 1, height, width] respectively.
    """
    imgs = np.empty((Nimgs, height, width, channels))
    groundTruth = np.empty((Nimgs, height, width))   # binary image, channels=1
    border_masks = np.empty((Nimgs, height, width))  # binary image, channels=1
    # path = current dir, subdirs = child dirs, files = all files in imgs_dir
    for path, subdirs, files in os.walk(imgs_dir):
        for i in range(len(files)):  # len(files) = number of images
            print("original image: " + files[i])
            img = Image.open(imgs_dir + files[i])   # read image into memory
            imgs[i] = np.asarray(img)               # convert to numpy
            # DRIVE naming: first two chars of the image name index the files
            groundTruth_name = files[i][0:2] + "_manual1.gif"
            print("ground truth name: " + groundTruth_name)
            g_truth = Image.open(groundTruth_dir + groundTruth_name)
            groundTruth[i] = np.asarray(g_truth)
            if train_test == "train":
                border_masks_name = files[i][0:2] + "_training_mask.gif"
            elif train_test == "test":
                border_masks_name = files[i][0:2] + "_test_mask.gif"
            else:
                # was a Python-2 `print` statement — SyntaxError under Python 3
                print("please specify if train or test!!")
                sys.exit(1)
            print("border masks name: " + border_masks_name)
            b_mask = Image.open(borderMasks_dir + border_masks_name)
            border_masks[i] = np.asarray(b_mask)
    print("imgs max: " + str(np.max(imgs)))
    print("imgs min: " + str(np.min(imgs)))
    # sanity-check that masks/labels are binary 0/255 images
    assert (np.max(groundTruth) == 255 and np.max(border_masks) == 255)
    assert (np.min(groundTruth) == 0 and np.min(border_masks) == 0)
    # rearrange tensors to [Nimgs, channels, height, width]
    imgs = np.transpose(imgs, (0, 3, 1, 2))
    groundTruth = np.reshape(groundTruth, (Nimgs, 1, height, width))
    border_masks = np.reshape(border_masks, (Nimgs, 1, height, width))
    # verify tensor layout
    assert (imgs.shape == (Nimgs, channels, height, width))
    assert (groundTruth.shape == (Nimgs, 1, height, width))
    assert (border_masks.shape == (Nimgs, 1, height, width))
    return imgs, groundTruth, border_masks


if not os.path.exists(dataset_path):
    os.makedirs(dataset_path)

# Package the training set
imgs_train, groundTruth_train, border_masks_train = get_datasets(
    original_imgs_train, groundTruth_imgs_train, borderMasks_imgs_train, "train")
print("saving train datasets ... ...")
write_hdf5(imgs_train, dataset_path + "imgs_train.hdf5")
write_hdf5(groundTruth_train, dataset_path + "groundTruth_train.hdf5")
write_hdf5(border_masks_train, dataset_path + "borderMasks_train.hdf5")

# Package the test set
imgs_test, groundTruth_test, border_masks_test = get_datasets(
    original_imgs_test, groundTruth_imgs_test, borderMasks_imgs_test, "test")
print("saving test datasets ... ...")
write_hdf5(imgs_test, dataset_path + "DRIVE_dataset_imgs_test.hdf5")
write_hdf5(groundTruth_test, dataset_path + "DRIVE_dataset_groundTruth_test.hdf5")
write_hdf5(border_masks_test, dataset_path + "DRIVE_dataset_borderMasks_test.hdf5")

# 2. Writing and loading HDF5 files
def write_hdf5(arr,outfile):with h5py.File(outfile,"w") as f:f.create_dataset("image", data=arr, dtype=arr.dtype) def load_hdf5(infile):with h5py.File(infile,"r") as f: # "image"是寫入的時候規定的字段 key-valuereturn f["image"][()] # 調用方法 train_imgs_original = load_hdf5( file_dir )3.圖像灰階轉換
# 將RGB圖像轉換為Gray圖像 def rgb2gray(rgb):assert (len(rgb.shape)==4) #[Nimgs, channels, height, width]assert (rgb.shape[1]==3) #確定是否為RGB圖像bn_imgs = rgb[:,0,:,:]*0.299 + rgb[:,1,:,:]*0.587 + rgb[:,2,:,:]*0.114bn_imgs = np.reshape(bn_imgs,(rgb.shape[0],1,rgb.shape[2],rgb.shape[3])) # 確保張量形式正確return bn_imgs4.利用已知信息進行分組顯示
# 對數據集劃分,進行分組顯示,totimg圖像陣列 def group_images(data,per_row): # data:數據 per_row:每行顯示的圖像個數assert data.shape[0]%per_row==0 # data=[Nimgs, channels, height, width]assert (data.shape[1]==1 or data.shape[1]==3)data = np.transpose(data,(0,2,3,1)) # 用于顯示all_stripe = []for i in range(int(data.shape[0]/per_row)): # data.shape[0]/per_row 行數stripe = data[i*per_row] # 相當于matlab中的 data(i*per_row, :, :, :) 一張圖像for k in range(i*per_row+1, i*per_row+per_row):stripe = np.concatenate((stripe,data[k]),axis=1) # 每per_row張圖像拼成一行all_stripe.append(stripe) # 加入列表totimg = all_stripe[0]for i in range(1,len(all_stripe)):totimg = np.concatenate((totimg,all_stripe[i]),axis=0) # 每行圖像進行拼湊 共len(all_stripe)行return totimg def visualize(data,filename):assert (len(data.shape)==3) #height*width*channelsimg = Noneif data.shape[2]==1: #in case it is black and whitedata = np.reshape(data,(data.shape[0],data.shape[1]))if np.max(data)>1:img = Image.fromarray(data.astype(np.uint8)) #the image is already 0-255else:img = Image.fromarray((data*255).astype(np.uint8)) #the image is between 0-1img.save(filename + '.png') #保存return img總結
以上是生活随笔為你收集整理的Unet项目解析(5): 数据封装、数据加载、数据显示的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 使用临界段实现优化的进程间同步对象-原理
- 下一篇: C++ 中重载 + 操作符的正确方法