(深度学习)构造属于你自己的Pytorch数据集
(深度學習)構造屬于你自己的Pytorch數據集
1.綜述
2.實現原理
3.代碼細節
4.詳細代碼
綜述
Pytorch可以說是一個非常便利的深度學習庫,它甚至在torchvision.datasets中擁有許多一步到位完成數據集下載、解析、讀取的類——然鵝,這樣也就養成了我們懶惰依賴的心理。當我們需要用到torchvision.datasets中不曾擁有的數據集時,我們可能就會不知所措。
這篇文章中,我將以CIFAR-10數據集為例(雖然有torchvision.datasets.CIFAR10了),擺脫對torchvision.datasets的依賴,構建一個自己的數據集。
在開始之前,首先你要有CIFAR-10數據集,直接去官網上下載可能較慢(再次感謝我國著名建筑師方斌新院士 ),可以在https://pan.baidu.com/s/1bGVGeeiw001qz-PUk7q1Uw(提取碼:m35y)中下載python版本的數據集。
數據集解壓后目錄情況如下:
實現原理
首先,torch.utils.data.DataLoader不僅生成迭代數據非常方便,而且它也是經過優化的,效率十分之高(肯定比我們自己寫一個要高多了),因此我們最好不要舍棄。
因此,我們的目標是根據CIFAR-10數據集構造一個Dataset的子類,使之能夠作為torch.utils.data.DataLoader的參數,從而使數據集能被我們用于生成迭代數據進行訓練:
cifar10 = MyCIFAR10.MyCIFAR10('./data/cifar-10-batches-py', train=True) train_loader = torch.utils.data.DataLoader(dataset=cifar10, batch_size=batch_size, shuffle=True)要構造Dataset的子類,就必須要實現兩個方法:
- _getitem_(self, index):根據index來返回數據集中標號為index的元素及其標簽。
- _len_(self):返回數據集的長度。
因此,實質上我們主要是要通過__init__初始化之時讀取數據集,再實現這兩個函數便輕而易舉。
代碼細節
_init_:
- root是存放解壓后的數據集的根目錄,根據上圖我這里是'./data/cifar-10-batches-py'。
- X的類型是numpy數組,Y的類型是List;由于X作為數據要送入網絡中,因此最后需要將其累加值從numpy數組轉為Tensor。
_getitem_:
較為簡單,直接給出:
def __getitem__(self, index):img, label = self.imgs[index], self.labels[index]if self.transform is not None:img = self.transform(img)if self.target_transform is not None:label = self.target_transform(label)return img, label_len_:
極其簡單,直接給出:
def __len__(self):return len(self.imgs)詳細代碼
class MyCIFAR10(Dataset):"""根據CIFAR-10定義的個人數據集類繼承自Dataset類,因此能夠被torch.utils.data.DataLoader使用,從而更高效地在訓練和測試中迭代"""def __init__(self, root, train=True, transform=None, target_transform=None):super(MyCIFAR10, self).__init__()self.transform = transformself.target_transform = target_transformself.imgs = Noneself.labels = []# 根據CIFAR-10官網上下載的數據,訓練集分為5個batch文件,每個里有10000張32*32的圖片;測試集只有1個batch文件,里面有10000張32*32的圖片train_lists = ['data_batch_1','data_batch_2','data_batch_3','data_batch_4','data_batch_5']test_lists = ['test_batch']# 根據train是否為True來選擇測試集或訓練集if train:lists = train_listselse:lists = test_lists# 讀取數據集,構造類中的圖像集和標簽for list in lists:filename = os.path.join(root, list)with open(filename, 'rb') as f: # 這里需要'rb' + 'latin1'才能讀取datadict = pickle.load(f, encoding='latin1')X = datadict['data'].reshape(-1, 3, 32, 32)Y = datadict['labels']if self.imgs is None:self.imgs = np.vstack(X).reshape(-1, 3, 32, 32)else:self.imgs = np.vstack((self.imgs, X)).reshape(-1, 3, 32, 32)self.labels = self.labels + Yself.imgs = torch.from_numpy(self.imgs).type(torch.FloatTensor) # 最后需要將numpy數組轉為Tensor# 繼承的Dataset類需要實現兩個方法之一:__getitem__(self, index)def __getitem__(self, index):img, label = self.imgs[index], self.labels[index]if self.transform is not None:img = self.transform(img)if self.target_transform is not None:label = self.target_transform(label)return img, label# 繼承的Dataset類需要實現兩個方法之一:__len__(self)def __len__(self):return len(self.imgs)總結
以上是生活随笔為你收集整理的(深度学习)构造属于你自己的Pytorch数据集的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: STM32 启动代码分析
- 下一篇: n1怎么进入线刷模式_OPPO N1怎么