PyTorch基础(四)-----数据加载和预处理
前言
之前已經簡單講述了PyTorch的Tensor、Autograd、torch.nn和torch.optim包,通過這些我們已經可以簡單的搭建一個網絡模型,但這是不夠的,我們還需要大量的數據,眾所周知,數據是深度學習的靈魂,深度學習的模型是由數據“喂”出來的,這篇我們來講述一下數據的加載和預處理。
- 首先,我們要引入torch包
一、數據的加載
PyTorch通過torch.utils.data對一般常用的數據加載進行了封裝,可以很容易地實現多線程數據預讀和批量加載。
1.1 Dataset
Dataset是一個抽象類,為了能夠方便的讀取,需要將要使用的數據包裝為Dataset類。自定義的Dataset類需要繼承它并且實現2個成員方法:
- 1.__getitem__():該方法定義用索引(0-len(self))獲取一條數據或一個樣本
- 2.__len__():該方法返回數據集的總長度
下面我們使用Kaggle上的一個競賽bluebook for bulldozers自定義一個數據集,為了方便介紹,我們使用里面的數據字典來做說明
- 首先,我們需要引用相關的包
- 自定義一個數據集
- 至此,我們的數據集已經定義完成了,我們可以實例化一個對象來訪問
- 我們可以直接使用如下命令查看數據集數據
- 使用索引可以直接訪問對應的數據
自定義的數據集已經創建好了,下面我們使用官方提供的數據載入器,讀取數據
1.2 DataLoader
DataLoader為我們提供了對Dataset的讀取操作,常用參數有:batch_size(每個batch的大小)、shuffle(是否進行shuffle操作)、num_workers(加載數據時使用幾個子進程)。下面做一個簡單的演示:
dl = torch.utils.data.DataLoader(ds_demo,batch_size = 10,shuffle = True,num_workers = 0)DataLoader返回的是一個可迭代對象,我們可以使用迭代器分次獲取數據
idata=iter(dl) print(next(idata))常見的用法是使用for循環對其進行遍歷
for i, data in enumerate(dl):print(i,data)# 為了節約空間,這里只循環一遍break至此,我們已經可以通過dataset定義數據集,并使用DataLorder載入和遍歷數據集。
二、torchvision包
torchvision 是PyTorch中專門用來處理圖像的庫,PyTorch官網的安裝教程中最后的pip install torchvision 就是安裝這個包。
torchvision已經預先實現了常用圖像數據集,包括前面使用過的CIFAR-10,ImageNet、COCO、MNIST、LSUN等數據集,可通過torchvision.datasets方便的調用。
- 這里總結一下torchvision已經預裝的數據集:
| MNIST |
| COCO |
| CIFAR-10 |
| ImageNet |
| Captions |
| Detection |
| LSUN |
| ImageFolder |
| Imagenet-12 |
| STL10 |
| SVHN |
| PhotoTour |
PyTorch中自帶的數據集由2個上層api提供,分別是torchvision和torchtext
- torchvision提供了對圖像數據處理的相關數據和api
- 數據位置:torchvision.datasets;例如:torchvision.datasets.MNIST
- torchtext提供了對文本數據處理的相關數據和api
- 數據位置:torchtext.datasets;例如:torchtext.datasets.IMDB
下面我們做一個簡單的演示
- 首先,我們要引入torchvision包
2.1 torchvision.models
torchvision不僅提供了常用的圖像數據集,而且還提供了一些訓練好的網絡模型,可以加載之后直接使用,或者繼續進行遷移學習。torchvision.models模塊的子模塊中包含以下模型:
| AlexNet |
| VGG |
| ResNet |
| SqueezeNet |
| DenseNet |
我們直接可以使用訓練好的模型,當然這個與datasets相同,都是需要從服務器下載的。
- 首先,我們需要導入torchvision.models
- 直接使用
2.2 torchvision.tranforms
transforms 模塊提供了一般的圖像轉換操作類,用作數據處理和數據增強
- 首先,我們需要引入torchvision.tranforms,然后做一個簡單的演示
肯定有人會問:(0.485, 0.456, 0.406), (0.2023, 0.1994, 0.2010) 這幾個數字是什么意思?
官方的這個帖子有詳細的說明: https://discuss.pytorch.org/t/normalization-in-the-mnist-example/457/21 這些都是根據ImageNet訓練的歸一化參數,可以直接使用,我們認為這個是固定值就可以。
到這里,我們已經完成了PyTorch的基本內容介紹。
參考文獻
https://github.com/zergtant/pytorch-handbook/blob/master/chapter2
總結
以上是生活随笔為你收集整理的PyTorch基础(四)-----数据加载和预处理的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 手把手教你做用户画像
- 下一篇: BRD、MRD 和 PRD 之间的区别与