【小白学习PyTorch教程】五、在 PyTorch 中使用 Datasets 和 DataLoader 自定义数据
「@Author:Runsen」
有時候,在處理大數據集時,一次將整個數據加載到內存中變得非常難。
因此,唯一的方法是將數據分批加載到內存中進行處理,這需要編寫額外的代碼來執行此操作。對此,PyTorch 已經提供了 Dataloader 功能。
DataLoader
下面顯示了 PyTorch 庫中DataLoader函數的語法及其參數信息。
DataLoader(dataset,?batch_size=1,?shuffle=False,?sampler=None,batch_sampler=None,?num_workers=0,?collate_fn=None,pin_memory=False,?drop_last=False,?timeout=0,worker_init_fn=None,?*,?prefetch_factor=2,persistent_workers=False)幾個重要參數
- dataset:必須首先使用數據集構造 DataLoader 類。 
- Shuffle :是否重新整理數據。 
- Sampler :指的是可選的 torch.utils.data.Sampler 類實例。采樣器定義了檢索樣本的策略,順序或隨機或任何其他方式。使用采樣器時應將 Shuffle 設置為 false。 
- Batch_Sampler :批處理級別。 
- num_workers :加載數據所需的子進程數。 
- collate_fn :將樣本整理成批次。Torch 中可以進行自定義整理。 
加載內置 MNIST 數據集
MNIST 是一個著名的包含手寫數字的數據集。下面介紹如何使用DataLoader功能處理 PyTorch 的內置 MNIST 數據集。
import?torch import?matplotlib.pyplot?as?plt from?torchvision?import?datasets,?transforms上面代碼,導入了 torchvision 的torch計算機視覺模塊。通常在處理圖像數據集時使用,并且可以幫助對圖像進行規范化、調整大小和裁剪。
對于 MNIST 數據集,下面使用了歸一化技術。
ToTensor()能夠把灰度范圍從0-255變換到0-1之間。
transform?=?transforms.Compose([transforms.ToTensor()])下面代碼用于加載所需的數據集。使用 PyTorchDataLoader通過給定 batch_size = 64來加載數據。shuffle=True打亂數據。
trainset?=?datasets.MNIST('~/.pytorch/MNIST_data/',?download=True,?train=True,?transform=transform) trainloader?=?torch.utils.data.DataLoader(trainset,?batch_size=64,?shuffle=True)為了獲取數據集的所有圖像,一般使用iter函數和數據加載器DataLoader。
dataiter?=?iter(trainloader) images,?labels?=?dataiter.next() print(images.shape) print(labels.shape) plt.imshow(images[1].numpy().squeeze(),?cmap='Greys_r')自定義數據集
下面的代碼創建一個包含 1000 個隨機數的自定義數據集。
from?torch.utils.data?import?Dataset import?randomclass?SampleDataset(Dataset):def?__init__(self,r1,r2):randomlist=[]for?i?in?range(120):n?=?random.randint(r1,r2)randomlist.append(n)self.samples=randomlist?def?__len__(self):return?len(self.samples)def?__getitem__(self,idx):return(self.samples[idx])dataset=SampleDataset(1,100) dataset[100:120] 在這里插入圖片描述最后,將在自定義數據集上使用 dataloader 函數。將 batch_size 設為 12,并且還啟用了num_workers =2 的并行多進程數據加載。
from?torch.utils.data?import?DataLoader loader?=?DataLoader(dataset,batch_size=12,?shuffle=True,?num_workers=2?) for?i,?batch?in?enumerate(loader):print(i,?batch)寫在后面
通過幾個示例了解了 PyTorch Dataloader 在將大量數據批量加載到內存中的作用。
往期精彩回顧適合初學者入門人工智能的路線及資料下載機器學習及深度學習筆記等資料打印機器學習在線手冊深度學習筆記專輯《統計學習方法》的代碼復現專輯 AI基礎下載機器學習的數學基礎專輯黃海廣老師《機器學習課程》課件合集 本站qq群851320808,加入微信群請掃碼:總結
以上是生活随笔為你收集整理的【小白学习PyTorch教程】五、在 PyTorch 中使用 Datasets 和 DataLoader 自定义数据的全部內容,希望文章能夠幫你解決所遇到的問題。
 
                            
                        - 上一篇: 【Python】Pandas中的宝藏函数
- 下一篇: springMVC,aop管理log4j
