(第一篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform
前言:在深度學習中,數據的預處理是第一步,pytorch提供了非常規范的處理接口,本文將針對處理過程中的一些問題來進行說明,本文所針對的主要數據是圖像數據集。
本文的案例來源于車道線語義分割,采用的數據集是tusimple數據集,當然先需要將tusimple數據集寫一個簡單的腳本程序轉換成指定的數據格式,如下:
一、基本概述
pytorch輸入數據PipeLine一般遵循一個“三步走”的策略,一般pytorch 的數據加載到模型的操作順序是這樣的:
① 創建一個 Dataset 對象。必須實現__len__()、__getitem__()這兩個方法,這里面會用到transform對數據集進行擴充。
② 創建一個 DataLoader 對象。它是對DataSet對象進行迭代的,一般不需要事先里面的其他方法了。
③ 循環遍歷這個 DataLoader 對象。將img, label加載到模型中進行訓練
注意這三個類均在torch.utils.data 中,這個模塊中定義了下面幾個功能,
from .sampler import Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler, BatchSampler from .distributed import DistributedSampler from .dataset import Dataset, TensorDataset, ConcatDataset, Subset, random_split from .dataloader import DataLoader# 可見,采樣器sanpler,dataset,dataloader都是定義在里面的pytorch數據處理pipeline 三步走的 一般格式如下:
dataset = MyDataset() # 第一步:構造Dataset對象 dataloader = DataLoader(dataset)# 第二步:通過DataLoader來構造迭代對象num_epoches = 100 for epoch in range(num_epoches):# 第三步:逐步迭代數據for img, label in dataloader:# 訓練代碼二、Dataset類詳解
Dataset類是Pytorch中圖像數據集中最為重要的一個類,也是Pytorch中所有數據集加載類中應該繼承的父類。其中Dataset類中的兩個私有成員函數必須被重載,否則將會觸發錯誤提示:
- def __getitem__(self, index):
- def __len__(self):
- def __init__(self): 構造函數一般情況下我們也是要自己定義的,但是不是強制性的。
其中__len__應該返回數據集的大小,而__getitem__應該編寫支持數據集索引的函數,例如通過dataset[i]可以得到數據集中的第i+1個數據。這個Dataset抽象父類的定義如下:
class Dataset(object):def __getitem__(self, index):raise NotImplementedErrordef __len__(self):raise NotImplementedErrordef __add__(self, other):return ConcatDataset([self, other])?
總結:Dataset的子類中除了上面的三個函數以外,當然還可以添加自己的一些處理函數,比如隨機打亂,歸一化等等,但是上面這三個一般情況下是必須要自己實現的。而且這三個函數的功能也有所側重,一般情況下:
(1)__init__(self): 主要是數據的獲取,比如從某個文件中獲取
(2)__len__(self): 整個數據集的長度
(3)__getitem__(self,index): 這個是最重要的,一般情況下它會包含以下幾個業務需要處理,
- 第一,比如如果我們需要在讀取數據的同時對圖像進行增強的話,當然,圖像增強的方法可以使用Pytorch內置的圖像增強方式,也可以使用自定義或者其他的圖像增強庫這個很靈活。
- 第二,在Pytorch中得到的圖像必須是tensor,也就是說我們必須要將數據格式轉化成pytorch的tensor格式才行。
2.1 構造函數__init__()
# coding: utf-8import os import torch from torch.utils.data import Dataset, DataLoader from PIL import Image import cv2 import numpy as npfrom torchvision.transforms import ToTensor from torchvision import datasets, transformsimport randomclass LaneDataSet(Dataset):def __init__(self, dataset, transform):'''param:detaset: 實際上就是tusimple數據集的三個文本文件train.txt、val.txt、test.txt三者的文件路徑transform: 決定是否進行變換,它其實是一個函數或者是幾個函數的組合構造三個列表,存儲每一張圖片的文件路徑 '''self._gt_img_list = []self._gt_label_binary_list = []self._gt_label_instance_list = []self.transform = transformwith open(dataset, 'r') as file: # 打開其實是那個 training下面的那個train.txt 文件for _info in file:info_tmp = _info.strip(' ').split()self._gt_img_list.append(info_tmp[0])self._gt_label_binary_list.append(info_tmp[1])self._gt_label_instance_list.append(info_tmp[2])assert len(self._gt_img_list) == len(self._gt_label_binary_list) == len(self._gt_label_instance_list)self._shuffle()此構造函數主要功能是實現將tusimple的數據集的gt_image、binary_image、instance_image的路徑分別存儲在三個列表中,并且隨機打亂。
這里有一個_shuffle()函數,如下:
def _shuffle(self):# 將gt_image、binary_image、instance_image三者所對應的圖片路徑組合起來,再進行隨機打亂c = list(zip(self._gt_img_list, self._gt_label_binary_list, self._gt_label_instance_list))random.shuffle(c)self._gt_img_list, self._gt_label_binary_list, self._gt_label_instance_list = zip(*c)2.2 必須要實現的__len__()函數
def __len__(self):return len(self._gt_img_list)其實就是返回樣本的數量,
2.3 必須要實現的__getitem__()函數
def __getitem__(self, idx):assert len(self._gt_label_binary_list) == len(self._gt_label_instance_list) \== len(self._gt_img_list)# 讀取所有圖片img = cv2.imread(self._gt_img_list[idx], cv2.IMREAD_COLOR) #真實圖片 (720,1280,3)label_instance_img = cv2.imread(self._gt_label_instance_list[idx], cv2.IMREAD_UNCHANGED) # instance圖片 (720,1280)label_binary_img = cv2.imread(self._gt_label_binary_list[idx], cv2.IMREAD_GRAYSCALE) #binary圖片 (720,1280)# optional transformations,裁剪成(256,512)if self.transform:img = self.transform(img)label_binary_img = self.transform(label_binary_img)label_instance_img = self.transform(label_instance_img)img = img.reshape(img.shape[2], img.shape[0], img.shape[1]) #(3,720,1280) 這里都沒有問題return (img, label_binary_img, label_instance_img)本例沒有在__getitem__實現了使用transform來對樣本數據進行處理,但是還沒有轉化成tensor,返回的是numpy數組。后面在處理也是一樣的。
三、DataLoader類詳解(_DataLoaderIter類)
DataLoader的幾種訪問方式:
(1)dataloader本質是一個可迭代對象,使用iter()訪問,不能使用next()訪問,由于它本身就是一個可迭代對象,可以使用`for?inputs,?labels?in?dataloaders`進行可迭代對象的訪問;
(2)先使用iter對dataloader進行第一步包裝,使用iter(dataloader)返回的是一個迭代器,然后就可以可以使用next訪問了。
先來看一下DataLoader 的定義,如下:
class DataLoader(object):__initialized = Falsedef __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, def __setattr__(self, attr, val):def __iter__(self):def __len__(self):注意:
(1)我們一般不需要再自己去實現DataLoader的方法了,只需要在構造函數中指定相應的參數即可,比如常見的batch_size,shuffle等等參數。所以使用DataLoader十分簡潔方便。既然都是通過指定構造函數的參數實現,這里重點介紹一下構造函數中參數的含義。
(2)DataLoader實際上一個較為高層的封裝類,它的功能都是通過更底層的_DataLoader來完成的,但是_DataLoader類較為低層,這里就不再展開敘述了。DataLoaderIter就是_DataLoaderIter的一個框架, 用來傳給_DataLoaderIter?一堆參數, 并把自己裝進DataLoaderIter?里。
3.1 DataLoader的構造函數參數
class DataLoader(object):Arguments:dataset (Dataset): 是一個DataSet對象,表示需要加載的數據集.batch_size (int, optional): 每一個batch加載多少組樣本,即指定batch_size,默認是 1 shuffle (bool, optional): 布爾值True或者是False ,表示每一個epoch之后是否對樣本進行隨機打亂,默認是False ------------------------------------------------------------------------------------sampler (Sampler, optional): 自定義從數據集中抽取樣本的策略,如果指定這個參數,那么shuffle必須為Falsebatch_sampler (Sampler, optional): 與sampler類似,但是一次只返回一個batch的indices(索引),需要注意的是,一旦指定了這個參數,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥) ------------------------------------------------------------------------------------num_workers (int, optional): 這個參數決定了有幾個進程來處理data loading。0意味著所有的數據都會被load進主進程。(默認為0)collate_fn (callable, optional): 將一個list的sample組成一個mini-batch的函數(這個還不是很懂)pin_memory (bool, optional): 如果設置為True,那么data loader將會在返回它們之前,將tensors拷貝到CUDA中的固定內存(CUDA pinned memory)中. ------------------------------------------------------------------------------------drop_last (bool, optional): 如果設置為True:這個是對最后的未完成的batch來說的,比如你的batch_size設置為64,而一個epoch只有100個樣本,那么訓練的時候后面的36個就被扔掉了,如果為False(默認),那么會繼續正常執行,只是最后的batch_size會小一點。 ------------------------------------------------------------------------------------timeout (numeric, optional): 如果是正數,表明等待從worker進程中收集一個batch等待的時間,若超出設定的時間還沒有收集到,那就不收集這個內容了。這個numeric應總是大于等于0。默認為0worker_init_fn (callable, optional): If not ``None``, this will be called on each worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading. (default: ``None``)注意事項note: By default, each worker will have its PyTorch seed set to``base_seed + worker_id``, where ``base_seed`` is a long generatedby main process using its RNG. However, seeds for other libraiesmay be duplicated upon initializing workers (w.g., NumPy), causingeach worker to return identical random numbers. (See:ref:`dataloader-workers-random-seed` section in FAQ.) You mayuse :func:`torch.initial_seed()` to access the PyTorch seed foreach worker in :attr:`worker_init_fn`, and use it to set otherseeds before data loading.警告warning: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an unpicklable object, e.g., a lambda function.其實用的較多的,就是dataset,batch_size,shuffle這三個參數。
3.2 參數解析之——batch_size和shuffle參數
看下面的簡單應用:
import time import os import sysimport cv2 import numpy as np import matplotlib.pyplot as pltimport torch from torch import cuda from torch.utils.data import DataLoader from torchvision import transforms# 這是自己項目里面的模塊 from data_loader.data_loaders import LaneDataSet from data_loader.transformers import Rescale from lanenet.lanenet import LaneNet from lanenet.Model import ESPNet,compute_loss # 導入ESPNetdef train(train_loader):t=enumerate(iter(train_loader)) # 這里使用iter對dataloader進行了包裝for batch_idx, batch in t:# 注意 ,這三個數據都是 FloatTensorimage_data = batch[0].type(torch.FloatTensor).to(DEVICE) # (8,3,256,512) binary_label = batch[1].type(torch.FloatTensor).to(DEVICE) # [8,256,512] ,只有 0,255 這兩個值instance_label = batch[2].type(torch.FloatTensor).to(DEVICE) # (8,256,512) ,只有 0,20,70,120,170 每根車道線的值# 查看每一個batch里面的第一張樣本和所對應的標簽binary_label=binary_label.detach().cpu().numpy()instance_label=instance_label.detach().cpu().numpy()image_data=image_data.detach().cpu().type(torch.IntTensor).numpy()image_data = image_data.reshape(image_data.shape[0],image_data.shape[2], image_data.shape[3], image_data.shape[1]) #(8,256,512,3) plt.figure('image_data')plt.imshow(image_data[0][:,:,::-1]) #(256,512,3)plt.figure('binary_image')plt.imshow(binary_label[0], cmap='gray') #(256,512)plt.figure('instance_image')plt.imshow(instance_label[0], cmap='gray') #(256,512)plt.show()print("--------------------------------------------")def main():train_dataset_file = 'H:/tusimple_dataset/training/train.txt'# 第一步: 構造dataset 對象train_dataset = LaneDataSet(train_dataset_file, transform=transforms.Compose([Rescale((512,256))]))# 第二步: 構造dataloader 對象train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)# 第三步:迭代dataloader,進行訓練train(train_loader) if __name__ == '__main__':main()運行結果如下:
3.3 參數解析之——sampler參數
這個參數其實就是一個“采樣器”,表示從樣本中究竟如何取樣,pytorch中采樣器有如下幾個:
class Sampler(object):class SequentialSampler(Sampler):class RandomSampler(Sampler):class SubsetRandomSampler(Sampler):class WeightedRandomSampler(Sampler):class BatchSampler(Sampler):注意:Sampler類是所有的采樣器的基類,每一個繼承自Sampler的子類都必須實現它的__iter__方法和__len__方法。前者實現如何迭代樣本,后者實現一共有多少個樣本。
其實DataLoader里面在構造函數中就定義了采樣器——如何采樣,__init__中的部分代碼如下所示:
if batch_sampler is None: # 沒有手動傳入batch_sampler參數時if sampler is None: # 沒有手動傳入sampler參數時if shuffle:sampler = RandomSampler(dataset)else:sampler = SequentialSampler(dataset)batch_sampler = BatchSampler(sampler, batch_size, drop_last)self.sampler = sampler self.batch_sampler = batch_sampler self.__initialized = True3.4 參數解析之——collate_fn(這個參數往往是出現錯誤的根源所在)
DataLoader能夠為我們自動生成一個多線程的迭代器,只要傳入幾個參數進行就可以了,第一個參數就是上面定義的數據集,后面幾個參數就是batch size的大小,是否打亂數據,讀取數據的線程數目等等,這樣一來,我們就建立了一個多線程的I/O。
讀到這里,你可能覺得PyTorch真的太方便了,真的是簡單實用,但是在使用的過程中很有可能性就報錯了,而且你也是一步一步按著實現來的,怎么就報錯了呢?
不用著急,下面就來講一下為什么會報錯,以及這一塊pyhon實現的解讀,這樣你就能夠真正知道如何進行自定義的數據讀入。
(1)問題來源
通過上面的實現,可能會遇到各種不同的問題,Dataset非常簡單,一般都不會有錯,只要Dataset實現正確,那么問題的來源只有一個,那就是torch.utils.data.DataLoader中的一個參數collate_fn,這里我們需要找到DataLoader的源碼進行查看這個參數到底是什么。
可以看到collate_fn默認是等于default_collate,那么這個函數的定義如下。
def default_collate(batch):r"""Puts each data field into a tensor with outer dimension batch size"""error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"elem_type = type(batch[0])if isinstance(batch[0], torch.Tensor):out = Noneif _use_shared_memory:# If we're in a background process, concatenate directly into a# shared memory tensor to avoid an extra copynumel = sum([x.numel() for x in batch])storage = batch[0].storage()._new_shared(numel)out = batch[0].new(storage)return torch.stack(batch, 0, out=out)elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \and elem_type.__name__ != 'string_':elem = batch[0]if elem_type.__name__ == 'ndarray':# array of string classes and objectif re.search('[SaUO]', elem.dtype.str) is not None:raise TypeError(error_msg.format(elem.dtype))return torch.stack([torch.from_numpy(b) for b in batch], 0)if elem.shape == (): # scalarspy_type = float if elem.dtype.name.startswith('float') else intreturn numpy_type_map[elem.dtype.name](list(map(py_type, batch)))elif isinstance(batch[0], int_classes):return torch.LongTensor(batch)elif isinstance(batch[0], float):return torch.DoubleTensor(batch)elif isinstance(batch[0], string_classes):return batchelif isinstance(batch[0], container_abcs.Mapping):return {key: default_collate([d[key] for d in batch]) for key in batch[0]}elif isinstance(batch[0], container_abcs.Sequence):transposed = zip(*batch)return [default_collate(samples) for samples in transposed]raise TypeError((error_msg.format(type(batch[0]))))這是他的定義,但是到目前為止,我們似乎并不知道這個函數的作用,它的輸入參數是batch,這是什么意思也不知道,我們找到這個函數的調用部分,看一看究竟給這個函數傳遞進去的是什么。
前面說到了,DataLoader實際上是通過_DataLoaderIter來實現的,進入_DataLoaderIter,找到函數的調用如下:
def __next__(self):if self.num_workers == 0: # same-process loadingindices = next(self.sample_iter) # may raise StopIterationbatch = self.collate_fn([self.dataset[i] for i in indices]) # 在這里調用了collate_fn函數,傳遞的參數是一個列表if self.pin_memory:batch = pin_memory_batch(batch)return batch由上面可以發現,default_collate(batch)中的參數就是這里的? [self.dataset[i] for i in indices] 。從這里看這就是一個list,list中的每個元素就是self.data[i],如果你在往上看,可以看到這個self.data就是我們需要預先定義的Dataset,那么這里self.data[i]就等價于我們在Dataset里面定義的__getitem__這個函數。
所以我們知道了collate_fn這個函數的輸入就是一個list,list的長度是一個batch size,list中的每個元素都是__getitem__得到的結果。
這時我們再去看看collate_fn這個函數,其實可以看到非常簡單,就是通過對一些情況的排除,然后最后輸出結果,比如第一個if,如果我們的輸入是一個tensor,那么最后會將一個batch size的tensor重新stack在一起,比如輸入的tensor是一張圖片,3x30x30,如果batch size是32,那么按第一維stack之后的結果就是32x3x30x30,這里stack和concat有一點區別就是會增加一維。
所以通過上面的源碼解讀我們知道了數據讀入具體是如何操作的,那么我們就能夠實現自定義的數據讀入了,我們需要自己按需要重新定義collate_fn這個函數,下面舉個例子。
(2)collate_fn的案例一
下面我們來舉一個麻煩的例子,比如做文本識別,需要將一張圖片上的字符識別出來,比如下面這些圖片
那么這個問題的輸入就是一張一張的圖片,他的label就是一串字符,但是由于長度是變化的,所以這個問題比較麻煩。
下面我們就來簡單實現一下。
我們有一個train.txt的文件,上面有圖片的名稱和對應的label,首先我們需要定義一個Dataset。
class custom_dset(Dataset):def __init__(self,img_path,txt_path,img_transform=None,loader=default_loader):with open(txt_path, 'r') as f:lines = f.readlines()self.img_list = [os.path.join(img_path, i.split()[0]) for i in lines]self.label_list = [i.split()[1] for i in lines]self.img_transform = img_transformself.loader = loaderdef __getitem__(self, index):img_path = self.img_list[index]label = self.label_list[index]img = img_pathif self.img_transform is not None:img = self.img_transform(img)return img, labeldef __len__(self):return len(self.label_list)這里非常簡單,就是將txt文件打開,然后分別讀取圖片名和label,由于存放圖片的文件夾我并沒有放上去,因為數據太大,所以讀取圖片以及對圖片做一些變換的操作就不進行了。
接著我們自定義一個collate_fn,這里可以使用任何名字,只要在DataLoader里面傳入就可以了。
def collate_fn(batch):batch.sort(key=lambda x: len(x[1]), reverse=True)img, label = zip(*batch)pad_label = []lens = []max_len = len(label[0])for i in range(len(label)):temp_label = [0] * max_lentemp_label[:len(label[i])] = label[i]pad_label.append(temp_label)lens.append(len(label[i]))return img, pad_label, lens(3)collate_fn的案例二
在數據處理中,有時會出現某個樣本無法讀取等問題,比如某張圖片損壞。這時在_?getitem?_函數中將出現異常,此時最好的解決方案即是將出錯的樣本剔除。如果實在是遇到這種情況無法處理,則可以返回None對象,然后在Dataloader中實現自定義的collate_fn,將空對象過濾掉。但要注意,在這種情況下dataloader返回的batch數目會少于batch_size。
?
from torch.utils.data.dataloader import default_collate # 導入這個函數 def collate_fn(batch):'''batch 實際上是一個列表,列表的長度就是一個batch_size,列表的每一個元素形如(data, label),這實際上是定義DataSet的時候,每一個__getitem__得到的元素'''# 過濾為None的數據batch = list(filter(lambda x:x[0] is not None, batch))if len(batch) == 0: return torch.Tensor()return default_collate(batch) # 用默認方式拼接過濾后的batch數據,這里的defaut_collate就是pytorch默認給collate_fn傳遞的函數,需要導入才能使用 # 第一步:定義dataset dataset = NewDogCat(root='data/dogcat_wrong/', transform=transform)# 第二步:定義dataloader,需要注意的是,這里的collate_fn是我自己定義的啊 dataloader = DataLoader(dataset, 2, collate_fn=collate_fn, num_workers=1,shuffle=True)# 第三步:迭代dataloader for batch_datas, batch_labels in dataloader:print(batch_datas.size(),batch_labels.size())總結:什么時候該使用DataLoader的collate_fn這個參數?
? ? ? ? ?當定義DataSet類中的__getitem__函數的時候,由于每次返回的是一組類似于(x,y)的樣本,但是如果在返回的每一組樣本x,y中出現什么錯誤,或者是還需要進一步對x,y進行一些處理的時候,我們就需要再定義一個collate_fn函數來實現這些功能。當然我也可以自己在實現__getitem__的時候就實現這些后處理也是可以的。
? ? ? collate_fn,中單詞collate的含義是:核對,校勘,對照,整理。顧名思義,這就是一個對每一組樣本數據進行一遍“核對和重新整理”,現在可能更好理解一些。
后面有一篇專門講解collate_fn的文章,請參考:
(第二篇)pytorch數據預處理三劍客之——Dataset,DataLoader,Transform
總結
以上是生活随笔為你收集整理的(第一篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform的全部內容,希望文章能夠幫你解決所遇到的問題。
 
                            
                        - 上一篇: ICCV2021|你以为这是一个填色模型
- 下一篇: 行人属性识别二:添加新网络训练和自定义数
