制作pytorch数据集
生活随笔
收集整理的這篇文章主要介紹了
制作pytorch数据集
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
方法1、自定義的dataset類
需要實現必要的魔法方法:
- __init__魔法方法里面進行讀取數據文件
- __getitem__魔法方法進行支持下標訪問
- __len__魔法方法返回自定義數據集的大小,方便后期遍歷
面已經定義好了抽象數據,只需給出自己的dataset和idxs(數據的索引列表))
from torch.utils.data import DataLoader, Dataset class DatasetSplit(Dataset):"""An abstract Dataset class wrapped around Pytorch Dataset class."""def __init__(self, dataset, idxs):self.dataset = datasetself.idxs = [int(i) for i in idxs]def __len__(self):return len(self.idxs)def __getitem__(self, item):image, label = self.dataset[self.idxs[item]]return torch.as_tensor(image), torch.as_tensor(label) train_loader = DataLoader(DatasetSplit(train_dataset, client_idxs),batch_size=args.local_bs, shuffle=True)上面的train_dataset是你的數據集,client_idx是你的數據的索引列表,比如[1,2,345,33,54...........],數字代表數據在dataset中的位置。這樣制作后的數據集就是client_idx索引的數據集。
方法2:直接使用torch.utils.data.TensorDataset()封裝數據集
#劃分數據集 import torch import numpy as np import torch.utils.data as Data from sklearn.model_selection import train_test_split x_train, x_test, y_train,y_test = train_test_split(feature, labels, test_size=0.25) #制作pytorch識別的數據集 train_dataset = Data.TensorDataset(torch.from_numpy(x_train).float(), torch.from_numpy(y_train)) test_dataset = Data.TensorDataset(torch.from_numpy(x_test).float(), torch.from_numpy(y_test)) #制作可迭代的數據集 train_iter = Data.DataLoader(dataset = train_dataset,batch_size = batch_size,shuffle = True, num_workers = 2) test_iter = Data.DataLoader(dataset = test_dataset, batch_size= batch_size,shuffle = True, num_workers = 2)總結
以上是生活随笔為你收集整理的制作pytorch数据集的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: cnn识别cifar10、cifar10
- 下一篇: DCGAN生成cifar10, cifa