55_pytorch,自定义数据集
1.55.自定義數(shù)據(jù)
1.55.1.數(shù)據(jù)傳遞機(jī)制
我們首先回顧識(shí)別手寫(xiě)數(shù)字的程序:
... Dataset = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=transform, download=True,) dataloader = torch.utils.data.DataLoader(dataset=Dataset, batch_size=64, shuffle=True) ... for epoch in range(EPOCH):for i, (image, label) in enumerate(dataloader):...從上面的程序,我們可以知道,在PyTorch中,數(shù)據(jù)傳遞機(jī)制是這樣的:
1.創(chuàng)建Dataset
2.Dataset傳遞給DataLoader
3.DataLoader迭代產(chǎn)生訓(xùn)練數(shù)據(jù)提供給模型。
總結(jié)這個(gè)數(shù)據(jù)傳遞機(jī)制就是,Dataset負(fù)責(zé)建立索引到樣本的映射,DataLoader負(fù)責(zé)以特定的方式從數(shù)據(jù)集中迭代的產(chǎn)生一個(gè)個(gè)batch的樣本集合。在enumerate過(guò)程中實(shí)際上是dataloader按照其參數(shù)sampler規(guī)定的策略調(diào)用了其dataset的getitem方法(下文中將介紹該方法)。
在上面的識(shí)別手寫(xiě)數(shù)字的例子中,數(shù)據(jù)集是直接下載的,但如果我們自己收集了一些數(shù)據(jù),存在電腦文件夾里,我們?cè)撊绾伟堰@些數(shù)據(jù)變?yōu)榭梢栽赑yTorch框架下進(jìn)行神經(jīng)網(wǎng)絡(luò)訓(xùn)練的數(shù)據(jù)集呢,即如何自定義數(shù)據(jù)集呢?
1.55.1.1.PyTorch中Dataset,DataLoader,Sample的關(guān)系
PyTorch中Dataset,DataLoader,Sampler的關(guān)系可以用下圖概括:
用文字表達(dá)就是:Dataloader中包含Sampler和Dataset,Sampler產(chǎn)生索引,Dataset拿著這個(gè)索引在數(shù)據(jù)集文件夾中找到對(duì)應(yīng)的樣本(每個(gè)樣本對(duì)應(yīng)一個(gè)索引,就像列表中每個(gè)元素對(duì)應(yīng)一個(gè)索引),并給該樣本配置上標(biāo)簽,最后返回(樣本+標(biāo)簽)給調(diào)用方。
在enumerate過(guò)程中,Dataloader按照其參數(shù)BatchSampler規(guī)定的策略調(diào)用其Dataset的getitem方法batchsize次,得到一個(gè)batch,該batch中既包含樣本,也包含相應(yīng)的標(biāo)簽。
1.55.2.自定義數(shù)據(jù)集
torch.utils.data.Dataset 是一個(gè)表示數(shù)據(jù)集的抽象類(lèi)。任何自定義的數(shù)據(jù)集都需要繼承這個(gè)類(lèi)并覆寫(xiě)相關(guān)方法。所謂數(shù)據(jù)集,其實(shí)就是一個(gè)負(fù)責(zé)處理索引(index)到樣本(sample)映射的一個(gè)類(lèi)(class)。Pytorch提供兩種數(shù)據(jù)集: Map式數(shù)據(jù)集 Iterable式數(shù)據(jù)集。這里我們只介紹前者。
一個(gè)Map式的數(shù)據(jù)集必須要重寫(xiě)getitem(self, index)、 len(self) 兩個(gè)內(nèi)建方法,用來(lái)表示從索引到樣本的映射(Map)。這樣一個(gè)數(shù)據(jù)集dataset,舉個(gè)例子,當(dāng)使用dataset[idx]命令時(shí),可以在你的硬盤(pán)中讀取數(shù)據(jù)集中第idx張圖片以及其標(biāo)簽(如果有的話(huà)); len(dataset)則會(huì)返回這個(gè)數(shù)據(jù)集的容量。
自定義數(shù)據(jù)集類(lèi)的范式大致是這樣的:
class CustomDataset(torch.utils.data.Dataset):#需要繼承torch.utils.data.Datasetdef __init__(self):# TODO# 1. Initialize file path or list of file names.passdef __getitem__(self, index):# TODO# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).# 2. Preprocess the data (e.g. torchvision.Transform).# 3. Return a data pair (e.g. image and label).#這里需要注意的是,第一步:read one data,是一個(gè)data pointpassdef __len__(self):# You should change 0 to the total size of your dataset.return 0關(guān)于Dataset API的官網(wǎng)介紹https://pytorch.org/docs/stable/data.html#dataset-types:
Dataset類(lèi)的使用:所有的類(lèi)都應(yīng)該是此類(lèi)的子類(lèi)(也就是說(shuō)應(yīng)該繼承該類(lèi))。所有的子類(lèi)都要重寫(xiě)(override) len(), getitem()。
?__len()__ : 此方法應(yīng)該提供數(shù)據(jù)集的大小(容量)
?__getitem()__ : 此方法應(yīng)該提供支持下標(biāo)索引方式訪問(wèn)數(shù)據(jù)集。
DataLoader類(lèi)的使用如下:
根據(jù)這個(gè)方式,我們舉一個(gè)例子。
1.55.3.實(shí)例1
從kaggle官網(wǎng)下載dogsVScats的數(shù)據(jù)集(百度網(wǎng)盤(pán)下載鏈接見(jiàn)文末),該數(shù)據(jù)集包含test1文件夾和train文件夾,train文件夾中包含12500張貓的圖片和12500張狗的圖片,圖片的文件名中帶序號(hào):
sampleSubmission.csv中的內(nèi)容如下:
我們把其中前10000張貓的圖片和10000張狗的圖片作為訓(xùn)練集,把后面的2500張貓的圖片和2500張狗的圖片作為驗(yàn)證集。貓的label記為0,狗的label記為1。因?yàn)閳D片大小不一,所以,我們需要對(duì)圖像進(jìn)行transform。
# -*- coding: UTF-8 -*-import matplotlib.pyplot as plt import numpy as np import torch from torch.utils.data import Dataset, DataLoader from torchvision import transforms from PIL import Image""" 如果代碼執(zhí)行的時(shí)候出現(xiàn): OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized. OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program. That is dangerous, since it can degrade performance or cause incorrect results. The best thing to do is to ensure that only a single OpenMP runtime is linked into the process, e.g. by avoiding static linking of the OpenMP runtime in any library. As an unsafe, unsupported, undocumented workaround you can set the environment variable KMP_DUPLICATE_LIB_OK=TRUE to allow the program to continue to execute, but that may cause crashes or silently produce incorrect results. For more information, please see http://www.intel.com/software/products/support/.解決辦法是加上: import os os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" """ import os os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"image_transform = transforms.Compose([transforms.Resize(256), # 把圖片resize為256*256transforms.RandomCrop(224), # 隨機(jī)裁剪224*224transforms.RandomHorizontalFlip(), # 水平翻轉(zhuǎn)transforms.ToTensor(), # 將圖像轉(zhuǎn)為T(mén)ensortransforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 標(biāo)準(zhǔn)化 ])# 創(chuàng)建一個(gè)叫做DogVsCatDataset的Dataset,繼承自父類(lèi)torch.utils.data.Dataset class DogVsCatDataset(Dataset):def __init__(self, root_dir, train=True, transform=None):"""Args:root_dir (string): Directory with all the images.transform (callable, optional): Optional transform to be applied on a sample."""self.root_dir = root_dirself.img_path = os.listdir(self.root_dir)if train:# 圖片數(shù)據(jù)中有類(lèi)似:dog.12499.jpg的圖片共12499張。# x.split('.')[1] 就是文件名dog.12473.jpg中的序號(hào)部分,也是圖片的編號(hào)self.img_path = list(filter(lambda x: int(x.split('.')[1]) < 10000, self.img_path)) # 劃分訓(xùn)練集和驗(yàn)證集else:# 序號(hào)大于10000的編號(hào)self.img_path = list(filter(lambda x: int(x.split('.')[1]) >= 10000, self.img_path))self.transform = transformdef __len__(self):return len(self.img_path)def __getitem__(self, idx):image = Image.open(os.path.join(self.root_dir, self.img_path[idx]))label = 0 if self.img_path[idx].split('.')[0] == 'cat' else 1 # label, 貓為0,狗為1if self.transform:image = self.transform(image)label = torch.from_numpy(np.array([label]))return image, label# 來(lái)測(cè)試一下 if __name__ == '__main__':catanddog_dataset = DogVsCatDataset(root_dir='E:/BaiduNetdiskDownload/kaggle/train',train=False,transform=image_transform)# num_workers=4表示用4個(gè)線(xiàn)程讀取數(shù)據(jù)train_loader = DataLoader(catanddog_dataset, batch_size=8, shuffle=True, num_workers=4)# iter()函數(shù)把train_loader變?yōu)榈?#xff0c;然后調(diào)用迭代器的next()方法image, label = iter(train_loader).next()sample = image[0].squeeze()sample = sample.permute((1, 2, 0)).numpy()sample *= [0.229, 0.224, 0.225]sample += [0.485, 0.456, 0.406]sample = np.clip(sample, 0, 1)plt.imshow(sample)plt.show()print('Label is: {}'.format(label[0].numpy()))運(yùn)行結(jié)果:
1.55.4.實(shí)例2
1.55.4.1.收集圖像樣本
以簡(jiǎn)單的貓狗二分類(lèi)為例,可以在網(wǎng)上下載一些貓狗圖片。創(chuàng)建以下目錄:
?data -----------------根目錄
?data/test -----------------測(cè)試集
?data/train -----------------訓(xùn)練集
?data/val ------------------驗(yàn)證集
在test/train/val之下在校分別創(chuàng)建2個(gè)文件夾,dog,cat
cat,dog文件夾下分別存放2類(lèi)圖像:
之后寫(xiě)一個(gè)簡(jiǎn)單的python腳本,生成txt文件,用于指明每個(gè)圖像和標(biāo)簽的對(duì)應(yīng)關(guān)系。
格式:
/cat/1.jpg 0
/dog/1.jpg 1
…
如圖:
至此,樣本集的收集以及簡(jiǎn)單歸類(lèi)完成。
1.55.4.2.實(shí)現(xiàn)
使用到python package
| numpy | 矩陣操作,對(duì)圖像進(jìn)行轉(zhuǎn)置 |
| skimage | 圖像處理,圖像I/O,圖像變換 |
| matplotlib | 圖像的顯示,可視化 |
| os | 一些文件查找操作 |
| torch | pytorch |
| torchvision | pytorch |
1.55.4.3.代碼
# -*- coding: UTF-8 -*-""" 本案例來(lái)自:https://www.jb51.net/article/199360.htm """import numpy as np from skimage import io from skimage import transform import matplotlib.pyplot as plt import os import torch import torchvision from torch.utils.data import Dataset, DataLoader from torchvision.transforms import transforms from torchvision.utils import make_grid""" 第一步: 定義一個(gè)子類(lèi),繼承Dataset類(lèi),重寫(xiě)__len()__,__getitem()__方法。 細(xì)節(jié): 1、數(shù)據(jù)集中一個(gè)一樣的表示:采用字典的形式sample = {'image': image, 'label': label}。 2、圖像的讀取:采用skimage.io進(jìn)行讀取,讀取之后的結(jié)果為numpy.ndarray形式。 3、圖像變換:transform參數(shù) """class MyDataset(Dataset):def __init__(self, root_dir, names_file, transform=None):self.root_dir = root_dirself.names_file = names_fileself.transform = transformself.size = 0self.names_list = []if not os.path.isfile(self.names_file):print(self.names_file + 'does not exist!')file = open(self.names_file)for f in file:self.names_list.append(f)self.size += 1def __len__(self):return self.sizedef __getitem__(self, idx):image_path = self.root_dir + self.names_list[idx].split(' ')[0]if not os.path.isfile(image_path):print(image_path + 'does not exists!')return Noneimage = io.imread(image_path) # use skitimagelabel = int(self.names_list[idx].split(' ')[1])sample = {'image': image, 'label': label}if self.transform:sample = self.transform(sample)return sample""" 第二步 實(shí)例化一個(gè)對(duì)象,并讀取和顯示數(shù)據(jù)集 """ train_dataset = MyDataset(root_dir='./data/train',names_file='./data/train/train.txt',transform=None)plt.figure() for (cnt, i) in enumerate(train_dataset):image = i['image']label = i['label']ax = plt.subplot(4, 4, cnt + 1)ax.axis('off')ax.imshow(image)ax.set_title('label {}'.format(label))plt.pause(0.001)if cnt == 15:break""" 第三步(可選optional) 對(duì)數(shù)據(jù)集進(jìn)行變換:一般收集到的圖像大小尺寸,亮度等存在差異,變換的目的就是使得數(shù)據(jù)歸一化。另一方面,可 以通過(guò)變換進(jìn)行數(shù)據(jù)增加data argument關(guān)于pytorch中的變換transforms,請(qǐng)參考該系列之前的文章。由于數(shù)據(jù)集中樣本采用字典dicts形式表示。 因此不能直接調(diào)用torchvision.transofrms中的方法。 本實(shí)驗(yàn)只進(jìn)行尺寸歸一化Resize, 數(shù)據(jù)類(lèi)型變換ToTensor操作。Resize """# 變換Resize class Resize(object):def __init__(self, output_size: tuple):self.output_size = output_sizedef __call__(self, sample):# 圖像image = sample['image']# 使用skitimage.transform對(duì)圖像進(jìn)行縮放image_new = transform.resize(image, self.output_size)return {'image': image_new, 'label': sample['label']}# ToTensor ## 變換ToTensor class ToTensor(object):def __call__(self, sample):image = sample['image']image_new = np.transpose(image, (2, 0, 1))return {'image': torch.from_numpy(image_new), 'label': sample['label']}""" 第四步:對(duì)整個(gè)數(shù)據(jù)集應(yīng)用變換 細(xì)節(jié):transformers.Compose()將不同的幾個(gè)組合起來(lái)。先進(jìn)行Resize,再進(jìn)行ToTensor """ # 對(duì)原始的訓(xùn)練數(shù)據(jù)集進(jìn)行變換 transformed_trainset = MyDataset(root_dir='./data/train',names_file='./data/train/train.txt',transform=transforms.Compose([Resize((224, 224)),ToTensor()]))""" 第五步:使用DataLoader進(jìn)行包裝 為何要使用DataLoader? 1、深度學(xué)習(xí)的輸入是mini_batch形式 2、樣本加載時(shí)候可能需要隨機(jī)打亂順序,shuffle操作 3、樣本加載需要采用多線(xiàn)程 pytorch提供的DataLoader封裝了上述的功能,這樣使用起來(lái)更方便。 """ # 使用DataLoader可以利用多線(xiàn)程,batch,shuffle等 # 使用DataLoader可以利用多線(xiàn)程,batch,shuffle等 trainset_dataloader = DataLoader(dataset=transformed_trainset,batch_size=4,shuffle=True,num_workers=4)# 可視化 def show_images_batch(sample_batched):images_batch, labels_batch = \sample_batched['image'], sample_batched['label']grid = make_grid(images_batch)plt.imshow(grid.numpy().transpose(1, 2, 0))# sample_batch: Tensor , NxCxHxW plt.figure() for i_batch, sample_batch in enumerate(trainset_dataloader):show_images_batch(sample_batch)plt.axis('off')plt.ioff()plt.show()plt.show() """ 通過(guò)DataLoader包裝之后,樣本以min_batch形式輸出,而且進(jìn)行了隨機(jī)打亂順序。至此,自定義數(shù)據(jù)集的完整流程已經(jīng)實(shí)現(xiàn),test, val集只需要改路徑即可。 """輸出類(lèi)似:
補(bǔ)充:
更簡(jiǎn)單的方法
上述繼承Dataset,重寫(xiě)__len()__,__getitem()是通用的方法,過(guò)程相對(duì)繁瑣。對(duì)于簡(jiǎn)單的分類(lèi)數(shù)據(jù)集,pytorch中提供了更簡(jiǎn)便的方式----ImageFolder。
如果每種類(lèi)別的樣本放在各自的文件夾中,則可以直接使用ImageFolder。仍然以cat, dog二分類(lèi)數(shù)據(jù)集為例:
文件結(jié)構(gòu):
Code
import torch from torch.utils.data import DataLoader from torchvision import transforms, datasets import matplotlib.pyplot as plt import numpy as np# https://pytorch.org/tutorials/beginner/data_loading_tutorial.html# data_transform = transforms.Compose([ # transforms.RandomResizedCrop(224), # transforms.RandomHorizontalFlip(), # transforms.ToTensor(), # transforms.Normalize(mean=[0.485, 0.456, 0.406], # std=[0.229, 0.224, 0.225]) # ])data_transform = transforms.Compose([transforms.Resize((224,224)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),])train_dataset = datasets.ImageFolder(root='./data/train',transform=data_transform) train_dataloader = DataLoader(dataset=train_dataset,batch_size=4,shuffle=True,num_workers=4)def show_batch_images(sample_batch):labels_batch = sample_batch[1]images_batch = sample_batch[0]for i in range(4):label_ = labels_batch[i].item()image_ = np.transpose(images_batch[i], (1, 2, 0))ax = plt.subplot(1, 4, i + 1)ax.imshow(image_)ax.set_title(str(label_))ax.axis('off')plt.pause(0.01)plt.figure() for i_batch, sample_batch in enumerate(train_dataloader):show_batch_images(sample_batch)plt.show()由于 train 目錄下只有2個(gè)文件夾,分別為cat, dog, 因此ImageFolder安裝順序?qū)at使用標(biāo)簽0, dog使用標(biāo)簽1。(輸出類(lèi)似:)
1.55.5.參考文章
https://www.cnblogs.com/picassooo/p/12846617.html
https://www.jb51.net/article/199360.htm
總結(jié)
以上是生活随笔為你收集整理的55_pytorch,自定义数据集的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 当兵五年还能回去上大学吗?
- 下一篇: flink报错:Error: Stati