使用pytorch自定义DataSet,以加载图像数据集为例,实现一些骚操作
生活随笔
收集整理的這篇文章主要介紹了
使用pytorch自定义DataSet,以加载图像数据集为例,实现一些骚操作
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
使用pytorch自定義DataSet,以加載圖像數據集為例,實現一些騷操作
總共分為四步
- 構造一個my_dataset類,繼承自torch.utils.data.Dataset
- 重寫__getitem__ 和__len__ 類函數
- 建立兩個函數find_classes、has_file_allowed_extension,直接從這copy過去
- 建立my_make_dataset函數用來構造(path,lable)對
一、構造一個my_dataset類,繼承自torch.utils.data.Dataset
二、 重寫__getitem__ 和__len__ 類函數
要構造Dataset的子類,就必須要實現兩個方法:
- getitem_(self, index):根據index來返回數據集中標號為index的元素及其標簽。
- len_(self):返回數據集的長度。
三、建立兩個函數find_classes、has_file_allowed_extension,直接從這copy過去
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:"""Finds the class folders in a dataset.See :class:`DatasetFolder` for details."""classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())if not classes:raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}return classes, class_to_idxdef has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:"""Checks if a file is an allowed extension.Args:filename (string): path to a fileextensions (tuple of strings): extensions to consider (lowercase)Returns:bool: True if the filename ends with one of given extensions"""return filename.lower().endswith(extensions)- 建立my_make_dataset函數用來構造(path,lable)對
附錄:完整代碼
我這里傳入兩個root_dir,因為我要用一個dataset加載兩個數據集,分別放在data1和data2里
class my_dataset(Dataset):def __init__(self,root_original, root_cdtfed, transform=None):super(my_dataset, self).__init__()self.transform = transformself.root_original = root_originalself.root_cdtfed = root_cdtfedself.original_imgs = []self.cdtfed_imgs = []#add (img_path, label) to listsself.original_imgs = my_make_dataset(root_original, class_to_idx=None, extensions=('.jpg', '.png'), is_valid_file=None)self.cdtfed_imgs = my_make_dataset(root_original, class_to_idx=None, extensions=('.jpg', '.png'), is_valid_file=None)# super(my_dataset, self).__init__()def __getitem__(self, index): #這個方法是必須要有的,用于按照索引讀取每個元素的具體內容fn1, label1 = self.original_imgs[index] #fn是圖片path #fn和label分別獲得imgs[index]也即是剛才每行中word[0]和word[1]的信息fn2, label2 = self.cdtfed_imgs[index]img1 = Image.open(fn1).convert('RGB') #按照path讀入圖片from PIL import Image # 按照路徑讀取圖片img2 = Image.open(fn2).convert('RGB') #按照path讀入圖片from PIL import Image # 按照路徑讀取圖片if self.transform is not None:img1 = self.transform(img1) #是否進行transformimg2 = self.transform(img2) #是否進行transformimg_list = [img1, img2]label = label1name = fn1return img_list,label,name #return很關鍵,return回哪些內容,那么我們在訓練時循環讀取每個batch時,就能獲得哪些內容def __len__(self): #這個函數也必須要寫,它返回的是數據集的長度,也就是多少張圖片,要和loader的長度作區分return len(self.original_imgs)def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:"""Finds the class folders in a dataset.See :class:`DatasetFolder` for details."""classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())if not classes:raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}return classes, class_to_idxdef has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:"""Checks if a file is an allowed extension.Args:filename (string): path to a fileextensions (tuple of strings): extensions to consider (lowercase)Returns:bool: True if the filename ends with one of given extensions"""return filename.lower().endswith(extensions)def my_make_dataset(directory: str,class_to_idx: Optional[Dict[str, int]] = None,extensions: Optional[Tuple[str, ...]] = None,is_valid_file: Optional[Callable[[str], bool]] = None, ) -> List[Tuple[str, int]]:"""Generates a list of samples of a form (path_to_sample, class).See :class:`DatasetFolder` for details.Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` functionby default."""directory = os.path.expanduser(directory)if class_to_idx is None:_, class_to_idx = find_classes(directory)elif not class_to_idx:raise ValueError("'class_to_index' must have at least one entry to collect any samples.")both_none = extensions is None and is_valid_file is Noneboth_something = extensions is not None and is_valid_file is not Noneif both_none or both_something:raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")if extensions is not None:def is_valid_file(x: str) -> bool:return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))is_valid_file = cast(Callable[[str], bool], is_valid_file)instances = []available_classes = set()for target_class in sorted(class_to_idx.keys()):class_index = class_to_idx[target_class]target_dir = os.path.join(directory, target_class)if not os.path.isdir(target_dir):continuefor root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):for fname in sorted(fnames):if is_valid_file(fname):path = os.path.join(root, fname)# item = path, [int(cl) for cl in target_class.split('_')]item = path, target_classinstances.append(item)if target_class not in available_classes:available_classes.add(target_class)empty_classes = set(class_to_idx.keys()) - available_classesif empty_classes:msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "if extensions is not None:msg += f"Supported extensions are: {', '.join(extensions)}"raise FileNotFoundError(msg)return instances #instance:[item:(path, int(class_name)), ]總結
以上是生活随笔為你收集整理的使用pytorch自定义DataSet,以加载图像数据集为例,实现一些骚操作的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 3dsmax导出html,3dsmax导
- 下一篇: 这是一场数学、数学、数学的盛会