PyTorch中的数据输入和预处理
文章目錄
- PyTorch中的數據輸入和預處理
- 數據載入類
- 映射類型的數據集
- torchvision工具包的使用
- 可迭代類型的數據集
- 總結
PyTorch中的數據輸入和預處理
數據載入類
在使用PyTorch構建和訓練模型的過程中,經常需要將原始的數據轉換為張量。為了能夠方便地批量處理圖片數據,PyTorch引入了一系列工具來對這個過程進行包裝。
PyTorch數據的載入使用torch.utils.data.DataLoader類。該類的簽名如下:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)其中,dataset是一個torch.utils.data.Dataset類的實例,batch_size是Mini-batch的大小,shuffle代表數據會不會被隨機打亂,sampler是自定義的采樣器,每次迭代的時候會返回一個數據的下標索引,batch_sampler類似于sampler,不過返回的是一個Mini-batch的數據索引,而sampler僅僅返回下標索引。num_workers是數據載入器使用的進程數目。默認為0,即使用單進程來處理輸入數據,collate_fn定義如何把一批dataset的實例轉化為包含Mini-batch的張量。pin_memory參數會把數據轉移到和GPU相關聯的CPU內存中,從而加快GPU載入數據的速度,drop_last的設置決定了是否要把最后一個Mini-batch的數據丟棄掉,加入最后一個MIni-batch的數據數目小于預先設置的batch_size參數,timeout值如果大于0,就會決定在多進程情況下對數據的等待時間,worker_init_fn決定了每個數據載入的子進程開始時運行的函數,這個函數運行在隨機種子設置之后、數據載入之前。
映射類型的數據集
為了能夠使用DataLoader類,首先需要構造關于單個數據的torch.utils.data.Dataset類。這個類有兩種,第一種是映射類型的,對于這個類型,每個數據有一個對應的索引,通過輸入具體的索引,就能得到對應的數據,其構造方法如下所示:
class Dataset(object):def __getitm__(self, index):# index: 數據索引# ...# 返回數據張量def __len__(self):# 返回數據的數目# ...對于這個類,主要需要重寫兩個方法,第一個方法是__geitem__,該方法是Python內置的操作符方法,對應的操作符是索引操作符[],通過輸入整數數據索引,其大小在0至N-1之間,返回具體的某一條數據記錄。另一個方法是__len__,該方法返回數據的總數,若是一個Dataset類重寫了該方法可以通過使用len內置函數來獲取數據的數目。
torchvision工具包的使用
一個簡單torch.utils.data.Dataset類的實現如下:
class VisionDataset(data.Dataset):def __init__(self, root, transforms=None, transform=None, target_transform=None):# ...def __getitem__(self, index):raise NotImplementedErrordef __len__(self):raise NotImplementedErrorclass DatasetFolder(VisionDataset):def __init__(self, root, loader, extensions=None, transform=None, target_transform=None, is_valid_file=None):super(DatasetFolder, self).__init__(root, transform=transform, target_transform=target_transform)classes, class_to_idx = self._find_classes(self.root)self.samples = make_dataset(self.root, class_to-idx, extensions, is_valid_file)self.loader = loader# ...def __getitem__(self, index):path, target = self.samples[inex]sample = self.loader(path)if self.transform is not None:sample = self.transform(sample)if self.target_transform is not None:target = self.target_transform(target)return sample, targetdef __len__(self):return len(self.samples)從DataFolder類開始看,該類的使用情形是數據集存儲在一個目錄下,每個目錄有很多子目錄,子目錄的個數是圖片類的數目,每個子目錄下都存儲有很多圖片,且這些圖片都屬于一類。DataFolder類繼承了VisionDataset類。在DataFolder類的構造函數中一開始調用了類內部的_find_classes來找到具體的預測目標的類別和類別對應的class_to_idx,得到包含所有數據記錄的一個列表。這個列表記錄著數據的路徑和數據的預測目標。另外這個構造函數還傳入了一個參數loader,用來載入數據。
__ getitem__這個方法會傳入一個index,根據index從self.samples取得一條數據記錄,得到數據記錄的路徑和預測目標,然后使用loader來對數據進行載入,并使用self.transform和self.target_transform對數據進行變換。最后返回變換以后的數據和預測目標。
torchvision包中有一些內置的轉換函數,有一類主要作用于PyTorch張量。首先將張量轉換為圖片的類。其次在生成深度學習訓練模型的時候,轉換圖片為張量以后,使用torchvision.transforms.Nomalize類標準化。這個類需要傳入兩個參數,第一個參數是所有圖片的平均值張量,另一個是所有圖片的標準差張量,輸出的結果是輸入圖片張量減去平均值張量,然后除以標準差張量。最后,前文所述所有的轉換類可以組成一個大的轉換類,構造一個整體的包含所有列表按次序轉換的轉換類,這個類的調用效果是輸出這些轉換一次作用后的結果。
可迭代類型的數據集
可迭代類型的數據集相比于映射類型的數據集,不需要實現__getitem__方法和__len__方法,它本身更像一個Python迭代器。
對于不同的映射類型,因為索引之間相互獨立,在使用多個進程載入數據的情況下,多個進程可以獨立分配索引,迭代器在使用過程中,因為索引之間有先后順序關系,需要考慮如何分割數據,使得不同的進程可以得到不同的數據。對這一類型的數據,可以根據不同工作進程的序號worker_id,設定不同數據迭代器的取值范圍,保證不同的進程獲取不同的迭代器,而且迭代器返回的數據各不相同。
總結
在進行深度學習的過程中,數據的輸入和預處理十分重要。PyTorch提供的數據抽象類以及數據載入器的類,通過繼承數據的抽象類,可以構造出針對某一個特殊數據的實例,然后輸入數據載入器中,數據載入器可以自動對數據進行多進程處理,最后輸出數據的張量供深度學習模型使用。
《新程序員》:云原生和全面數字化實踐50位技術專家共同創作,文字、視頻、音頻交互閱讀總結
以上是生活随笔為你收集整理的PyTorch中的数据输入和预处理的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: PyTorch的损失函数和优化器
- 下一篇: PyTorch模型的保存加载以及数据的可