查看dataloader的大小_一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
以下內容都是針對Pytorch 1.0-1.1介紹。
很多文章都是從Dataset等對象自下往上進行介紹,但是對于初學者而言,其實這并不好理解,因為有的時候會不自覺地陷入到一些細枝末節中去,而不能把握重點,所以本文將會自上而下地對Pytorch數據讀取方法進行介紹。
及時獲取最優質的CV內容
自上而下理解三者關系
首先我們看一下DataLoader.__next__[1]的源代碼長什么樣,為方便理解我只選取了num_works為0的情況(num_works簡單理解就是能夠并行化地讀取數據)。
class DataLoader(object): ... def __next__(self): if self.num_workers == 0: indices = next(self.sample_iter) # Sampler batch = self.collate_fn([self.dataset[i] for i in indices]) # Dataset if self.pin_memory: batch = _utils.pin_memory.pin_memory_batch(batch) return batch在閱讀上面代碼前,我們可以假設我們的數據是一組圖像,每一張圖像對應一個index,那么如果我們要讀取數據就只需要對應的index即可,即上面代碼中的indices,而選取index的方式有多種,有按順序的,也有亂序的,所以這個工作需要Sampler完成,現在你不需要具體的細節,后面會介紹,你只需要知道DataLoader和Sampler在這里產生關系。
那么Dataset和DataLoader在什么時候產生關系呢?沒錯就是下面一行。我們已經拿到了indices,那么下一步我們只需要根據index對數據進行讀取即可了。
再下面的if語句的作用簡單理解就是,如果pin_memory=True,那么Pytorch會采取一系列操作把數據拷貝到GPU,總之就是為了加速。
綜上可以知道DataLoader,Sampler和Dataset三者關系如下:
在閱讀后文的過程中,你始終需要將上面的關系記在心里,這樣能幫助你更好地理解。
Sampler
參數傳遞
要更加細致地理解Sampler原理,我們需要先閱讀一下DataLoader 的源代碼,如下:
class DataLoader(object): def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)可以看到初始化參數里有兩種sampler:sampler和batch_sampler,都默認為None。前者的作用是生成一系列的index,而batch_sampler則是將sampler生成的indices打包分組,得到一個又一個batch的index。例如下面示例中,BatchSampler將SequentialSampler生成的index按照指定的batch size分組。
>>>in : list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))>>>out: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]Pytorch中已經實現的Sampler有如下幾種:
- SequentialSampler
- RandomSampler
- WeightedSampler
- SubsetRandomSampler
需要注意的是DataLoader的部分初始化參數之間存在互斥關系,這個你可以通過閱讀源碼[2]更深地理解,這里只做總結:
- 如果你自定義了batch_sampler,那么這些參數都必須使用默認值:batch_size, shuffle,sampler,drop_last.
- 如果你自定義了sampler,那么shuffle需要設置為False
- 如果sampler和batch_sampler都為None,那么batch_sampler使用Pytorch已經實現好的BatchSampler,而sampler分兩種情況:
- 若shuffle=True,則sampler=RandomSampler(dataset)
- 若shuffle=False,則sampler=SequentialSampler(dataset)
如何自定義Sampler和BatchSampler?
仔細查看源代碼其實可以發現,所有采樣器其實都繼承自同一個父類,即Sampler,其代碼定義如下:
class Sampler(object): r"""Base class for all Samplers. Every Sampler subclass has to provide an :meth:`__iter__` method, providing a way to iterate over indices of dataset elements, and a :meth:`__len__` method that returns the length of the returned iterators. .. note:: The :meth:`__len__` method isn't strictly required by :class:`~torch.utils.data.DataLoader`, but is expected in any calculation involving the length of a :class:`~torch.utils.data.DataLoader`. """ def __init__(self, data_source): pass def __iter__(self): raise NotImplementedError def __len__(self): return len(self.data_source)所以你要做的就是定義好__iter__(self)函數,不過要注意的是該函數的返回值需要是可迭代的。例如SequentialSampler返回的是iter(range(len(self.data_source)))。
另外BatchSampler與其他Sampler的主要區別是它需要將Sampler作為參數進行打包,進而每次迭代返回以batch size為大小的index列表。也就是說在后面的讀取數據過程中使用的都是batch sampler。
Dataset
Dataset定義方式如下:
class Dataset(object): def __init__(self): ... def __getitem__(self, index): return ... def __len__(self): return ...上面三個方法是最基本的,其中__getitem__是最主要的方法,它規定了如何讀取數據。但是它又不同于一般的方法,因為它是python built-in方法,其主要作用是能讓該類可以像list一樣通過索引值對數據進行訪問。假如你定義好了一個dataset,那么你可以直接通過dataset[0]來訪問第一個數據。在此之前我一直沒弄清楚__getitem__是什么作用,所以一直不知道該怎么進入到這個函數進行調試。現在如果你想對__getitem__方法進行調試,你可以寫一個for循環遍歷dataset來進行調試了,而不用構建dataloader等一大堆東西了,建議學會使用ipdb這個庫,非常實用!!!以后有時間再寫一篇ipdb的使用教程。另外,其實我們通過最前面的Dataloader的__next__函數可以看到DataLoader對數據的讀取其實就是用了for循環來遍歷數據,不用往上翻了,我直接復制了一遍,如下:
class DataLoader(object): ... def __next__(self): if self.num_workers == 0: indices = next(self.sample_iter) batch = self.collate_fn([self.dataset[i] for i in indices]) # this line if self.pin_memory: batch = _utils.pin_memory.pin_memory_batch(batch) return batch我們仔細看可以發現,前面還有一個self.collate_fn方法,這個是干嘛用的呢?在介紹前我們需要知道每個參數的意義:
- indices: 表示每一個iteration,sampler返回的indices,即一個batch size大小的索引列表
- self.dataset[i]: 前面已經介紹了,這里就是對第i個數據進行讀取操作,一般來說self.dataset[i]=(img, label)
看到這不難猜出collate_fn的作用就是將一個batch的數據進行合并操作。默認的collate_fn是將img和label分別合并成imgs和labels,所以如果你的__getitem__方法只是返回 img, label,那么你可以使用默認的collate_fn方法,但是如果你每次讀取的數據有img, box, label等等,那么你就需要自定義collate_fn來將對應的數據合并成一個batch數據,這樣方便后續的訓練步驟。
微信公眾號:AutoML機器學習MARSGGBO?原創如有意合作或學術討論歡迎私戳聯系~
郵箱:marsggbo@foxmail.com
2019-8-6
參考資料
[1]DataLoader.next: https://github.com/pytorch/pytorch/blob/0b868b19063645afed59d6d49aff1e43d1665b88/torch/utils/data/dataloader.py#L557-L563
[2]源碼: https://github.com/pytorch/pytorch/blob/0b868b19063645afed59d6d49aff1e43d1665b88/torch/utils/data/dataloader.py#L157-L182
歡迎關注Smarter,喜歡的可以雙擊點贊在看~~
總結
以上是生活随笔為你收集整理的查看dataloader的大小_一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 华为 WS550 无线路由器上网设置
- 下一篇: 天马微电子的mes工程师_上海天马微电子