[pytorch学习笔记] 3.Datasets Dataloaders
目錄
介紹
加載數據集
迭代和可視化數據集
?自定義數據集
__init__
__len__
__getitem__
使用 DataLoaders 準備訓練數據
遍歷 DataLoader
參考
介紹
理想情況下,我們希望我們的數據集代碼與我們的模型訓練代碼分離,以獲得更好的可讀性和模塊化。 PyTorch 提供了兩個數據模塊:torch.utils.data.DataLoader 和 torch.utils.data.Dataset,
允許使用預加載的數據集以及自己的數據。 Dataset 存儲樣本及其對應的標簽,DataLoader 在 Dataset 周圍包裝了一個可迭代對象,可以輕松訪問樣本。
PyTorch 域庫提供了許多預加載的數據集(例如 FashionMNIST),這些數據集是 torch.utils.data.Dataset 的子類,并實現了特定于特定數據的功能。 它們可用于對模型進行原型設計和基準測試。 可以在此處找到它們:?Image Datasets,?Text Datasets, and?Audio Datasets.
加載數據集
Fashion-MNIST數據集 由 60,000 個訓練示例和 10,000 個測試示例組成。 每個示例都包含 28×28 灰度圖像和來自 10 個類別之一的相關標簽。我們通過加載此數據集演示加載數據集的操作。
我們使用以下參數加載 FashionMNIST 數據集:
- root 是存儲訓練/測試數據的路徑,
- train 指定訓練或測試數據集,
- download=True 如果數據在根目錄下不可用,則從 Internet 下載數據。
- transform 和 target_transform 指定特征和標簽轉換
注:transforms.ToTensor()函數的作用是將原始的PILImage格式或者numpy.array格式的數據格式化為可被pytorch快速處理的張量類型。
輸入模式為(L、LA、P、I、F、RGB、YCbCr、RGBA、CMYK、1)的PIL Image 或 numpy.ndarray (形狀為H x W x C)數據范圍是[0, 255] 到一個 Torch.FloatTensor,其形狀 (C x H x W) 在 [0.0, 1.0] 范圍內。
?
迭代和可視化數據集
我們可以像列表一樣手動索引數據集:training_data[index]。
我們使用 matplotlib 來可視化我們訓練數據中的一些樣本。
labels_map = {0: "T-Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot", } figure = plt.figure(figsize=(8, 8)) cols, rows = 3, 3 for i in range(1, cols * rows + 1):sample_idx = torch.randint(len(training_data), size=(1,)).item()img, label = training_data[sample_idx]figure.add_subplot(rows, cols, i)plt.title(labels_map[label])plt.axis("off")plt.imshow(img.squeeze(), cmap="gray") plt.show()?自定義數據集
自定義數據集類必須實現三個函數:__init__、__len__ 和 __getitem__。 以FashionMNIST為例; FashionMNIST 圖像存儲在目錄 img_dir 中,它們的標簽分別存儲在 CSV 文件 annotations_file 中。
import os import pandas as pd from torchvision.io import read_imageclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_labels)def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label__init__
__init__ 函數在實例化 Dataset 對象時運行一次。 始化化圖像、注釋文件和兩種轉換的目錄。
labels.csv 文件如下所示:
tshirt1.jpg, 0 tshirt2.jpg, 0 ...... ankleboot999.jpg, 9 def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transform__len__
__len__ 函數返回數據集中的樣本數。
def __len__(self):return len(self.img_labels)__getitem__
__getitem__ 函數從給定索引 idx 的數據集中加載并返回一個樣本。 根據索引識別圖像在磁盤上的位置,使用 read_image函數 將其轉換為張量,從 self.img_labels 中的 csv 數據中檢索相應的標簽,調用它們的轉換函數(如果適用),并返回張量圖像 和元組中的相應標簽。
def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label使用 DataLoaders 準備訓練數據
檢索數據集的特征(Dataset)并一次標記一個樣本。 在訓練模型時,我們通常希望以“小批量”的形式傳遞樣本,在每個 epoch 重新洗牌以減少模型過擬合,并使用 Python 的多通道處理(multiprocessing)來加速數據檢索。
DataLoader 是一個迭代器,它通過一個簡單的 API 為我們實現這種功能。
from torch.utils.data import DataLoadertrain_dataloader = DataLoader(training_data, batch_size=64, shuffle=True) test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)遍歷 DataLoader
我們已經將該數據集加載到 DataLoader 中,并且可以根據需要遍歷數據集。 下面的每次迭代都會返回一批 train_features 和 train_labels(分別包含 batch_size=64 個特征和標簽)。 因為我們指定了 shuffle=True,所以在我們遍歷所有批次之后,數據會被打亂(為了更精準地控制數據加載順序,請查看 Samplers)。
# Display image and label. train_features, train_labels = next(iter(train_dataloader)) print(f"Feature batch shape: {train_features.size()}") print(f"Labels batch shape: {train_labels.size()}") img = train_features[0].squeeze() label = train_labels[0] plt.imshow(img, cmap="gray") plt.show() print(f"Label: {label}")out: Feature batch shape: torch.Size([64, 1, 28, 28]) Labels batch shape: torch.Size([64]) Label: 2官方文檔:
torch.squeeze
python 函數 iter()
python 函數 next()
不想看官方文檔,就看這個:(11條消息) next()函數___泡泡茶壺的博客-CSDN博客_next()
參考
(10條消息) pytorch數據處理之 transforms.ToTensor()解釋_菜根檀的博客-CSDN博客_totensor()
總結
以上是生活随笔為你收集整理的[pytorch学习笔记] 3.Datasets Dataloaders的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Paper Reading(1) : I
- 下一篇: HTML - label标签