pytorch自定义数据集和数据加载器
                                                            生活随笔
收集整理的這篇文章主要介紹了
                                pytorch自定义数据集和数据加载器
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.                        
                                假設有一個保存為npy格式的numpy數據集,現在需要將其變為pytorch的數據集,并能夠被數據加載器DataLoader所加載
首先自定義一個數據集類,繼承torch.utils.data.Dataset類
在這個類中要實現__init__,__getitem__,__len__這三個方法,否則會報錯
然后實例化這個類,得到train_data,最后將train_data放入DataLoader數據加載器,到此已經完成
注意,在下面這個代碼中的x(也就是數據加載器加載出來的數據)的類型是tensor。也就是說,上面的實現中自動把numpy數據類型轉化為了tensor類型
from torch.utils.data import Dataset, DataLoaderclass MyDataset(Dataset):"""path:數據集存放路徑"""def __init__(self, path):self.data = np.load(path)def __getitem__(self, index):return self.data[index]def __len__(self):return len(self.data)if __name__ == '__main__':train_data = MyDataset(r"D:\dataset.npy")load1 = DataLoader(train_data, batch_size=128, shuffle=True, pin_memory=True, num_workers=3)for x in load1:print(x.size())有時候我們需要同時加載數據和其對應的標簽,則需要將數據集和標簽定義在同一個數據加載器中,這時可以采用以下方法:
from torch.utils.data import Dataset, DataLoaderclass MyDataset(Dataset):def __init__(self, data, label):self.data = dataself.label = labeldef __getitem__(self, index):return self.data[index], self.label[index]def __len__(self):return len(self.label)if __name__ == '__main__':a = np.array([0,1,2,3,4,5])b = np.array([6,7,8,9,10,11])trainset = MyDataset(a, b)train_loader = torch.utils.data.DataLoader(trainset, batch_size=2, shuffle=True, pin_memory=True,num_workers=3)for x, y in train_loader:print(x, y)重點在__getitem__方法的實現,需要同時返回數據和標簽
總結
以上是生活随笔為你收集整理的pytorch自定义数据集和数据加载器的全部內容,希望文章能夠幫你解決所遇到的問題。
 
                            
                        - 上一篇: OpenCV学习笔记(五):线性滤波-方
- 下一篇: OpenCV学习笔记(十三):霍夫变换:
