ResNeXt 之 输入数据预处理代码详解
生活随笔
收集整理的這篇文章主要介紹了
ResNeXt 之 输入数据预处理代码详解
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
了解深度學習的同學都知道,在寫代碼的時候,主要的兩個部分就是網絡搭建和數據預處理工作,那么這就需要我們不斷地積累才能更好地使用,還不了解ILSVRC2012數據集形式的要先了解其形式,基本形式,就是 類別文件夾(n0xxxxx)->圖片名 ,想要詳細了解的,給大家推薦兩篇博客,一定要結合來看。
鏈接:https://blog.csdn.net/u012024357/article/details/90679222
? ? ? ? ??https://blog.csdn.net/tjuyanming/article/details/91354244
大家了解了數據集格式后,接下來我會給大家介紹ResNeXt的數據預處理工作是怎么進行的,我在代碼部分的關鍵部分都做了詳細的注釋,大家一定要看代碼。
from torchvision import transforms, datasets import os import torch from PIL import Image import scipy.io as scioIMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']def ImageNetData(args): # data_transform, pay attention that the input of Normalize() is Tensor and the input of RandomResizedCrop() or RandomHorizontalFlip() is PIL Imagedata_transforms = {'train': transforms.Compose([transforms.Resize(256),transforms.RandomCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),}image_datasets = {}#image_datasets['train'] = datasets.ImageFolder(os.path.join(args.data_dir, 'ILSVRC2012_img_train'), data_transforms['train'])#參數解釋: 訓練集圖片路徑,文件夾與類別名的映射文件,設置對圖片進行的處理image_datasets['train'] = ImageNetTrainDataSet(os.path.join(args.data_dir, 'ILSVRC2012_img_train'),os.path.join(args.data_dir, 'ILSVRC2012_devkit_t12', 'data', 'meta.mat'),data_transforms['train'])#參數解釋: 驗證集圖片路徑,圖片與類別的映射文件, 設置對圖片進行的處理image_datasets['val'] = ImageNetValDataSet(os.path.join(args.data_dir, 'ILSVRC2012_img_val'),os.path.join(args.data_dir, 'ILSVRC2012_devkit_t12', 'data','ILSVRC2012_validation_ground_truth.txt'),data_transforms['val'])# wrap your data and label into Tensordataloders = {x: torch.utils.data.DataLoader(image_datasets[x],batch_size=args.batch_size,shuffle=True,num_workers=args.num_workers) for x in ['train', 'val']}dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']} #返回一個字典!return dataloders, dataset_sizesclass ImageNetTrainDataSet(torch.utils.data.Dataset):def __init__(self, root_dir, img_label, data_transforms):label_array = scio.loadmat(img_label)['synsets']#讀取映射文件中的synsets部分label_dic = {}for i in range(1000):label_dic[label_array[i][0][1][0]] = i#label_array[i][0][1][0]:圖像文件夾編號(相當于讀入1000個文件夾),和對應的類別,因為共1000個類別self.img_path = os.listdir(root_dir)#遍歷訓練集的文件夾(類別)數self.data_transforms = data_transformsself.label_dic = label_dic #文件夾和對應的類別組成的字典self.root_dir = root_dirself.imgs = self._make_dataset()#這里要用self.label_dictdef __len__(self):return len(self.imgs)def __getitem__(self, item): #Python的魔法方法__getitem__ 可以讓對象實現迭代功能data, label = self.imgs[item]img = Image.open(data).convert('RGB')if self.data_transforms is not None:try:img = self.data_transforms(img)except:print("Cannot transform image: {}".format(self.img_path[item]))return img, labeldef _make_dataset(self):class_to_idx = self.label_dic# 文件夾和類別所對應的的類別images = []dir = os.path.expanduser(self.root_dir)for target in sorted(os.listdir(dir)):#target是每一類圖像文件夾的名稱d = os.path.join(dir, target)if not os.path.isdir(d):continuefor root, _, fnames in sorted(os.walk(d)):#fnames 是 該類別文件夾下的所有圖片for fname in sorted(fnames):if self._is_image_file(fname):path = os.path.join(root, fname)#每一張圖片的路徑item = (path, class_to_idx[target])#每一張圖片的路徑和它所對應的類別images.append(item)#加入imagesreturn imagesdef _is_image_file(self, filename):"""Checks if a file is an image.Args:filename (string): path to a fileReturns:bool: True if the filename ends with a known image extension"""filename_lower = filename.lower()return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS)class ImageNetValDataSet(torch.utils.data.Dataset):def __init__(self, img_path, img_label, data_transforms):self.data_transforms = data_transformsimg_names = os.listdir(img_path)#獲取驗證集中所有圖片的名稱組成img_names(list類型)img_names.sort()#對list類型的數據進行排序self.img_path = [os.path.join(img_path, img_name) for img_name in img_names]with open(img_label,"r") as input_file:lines = input_file.readlines()self.img_label = [(int(line)-1) for line in lines] #獲取label,[1,val_lengths]def __len__(self):return len(self.img_path)def __getitem__(self, item): #Python的魔法方法__getitem__ 可以讓對象實現迭代功能img = Image.open(self.img_path[item]).convert('RGB')label = self.img_label[item]if self.data_transforms is not None:try:img = self.data_transforms(img)except:print("Cannot transform image: {}".format(self.img_path[item]))return img, label #返回一個tuple數據類型。這里大家需要了解的是python中的 __getitem__方法的用法。
另外大家疑惑最多的應該是這部分:
label_array = scio.loadmat(img_label)['synsets']#讀取映射文件中的synsets部分,這里保存的最重要的信息就是類別和文件夾的對應關系label_dic = {}for i in range(1000):label_dic[label_array[i][0][1][0]] = i#label_array[i][0][1][0]:圖像文件夾編號(相當于讀入1000個文件夾),和對應的類別,因為共1000個類別?其實,這是由.mat文件中的數據類型所決定的,因為 scio.loadmat(img_label) 讀出來的是字典型數據,因此我們需要得到 'synsets' 所對應的內容。為了方便大家理解們這里大家們可以將label_array打印出來,查看他的屬性(劇透:尺寸是[1860,1]),對應的代碼:
import scipy.io as sciopath = './/ImageNet//ILSVRC2012_devkit_t12//data//meta.mat'result = scio.loadmat(path)print(type(result))for i in range(2000):print(i)print(result['synsets'][i][0][1][0])#是為了獲取圖片文件夾編號,典型的對數據做切片到這里,應該就沒有問題了。
總結
以上是生活随笔為你收集整理的ResNeXt 之 输入数据预处理代码详解的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: TypeError之: unsuppor
- 下一篇: RuntimeError 之 : CUD