camvid数据集使用方法_使用PyTorch处理CIFAR10数据集并显示
生活随笔
收集整理的這篇文章主要介紹了
camvid数据集使用方法_使用PyTorch处理CIFAR10数据集并显示
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
在訓練圖像分類的時候,我們通常會使用CIFAR10數據集,今天就先寫一下如何展示數據集的圖片及預處理。
第一部分代碼,展示原始圖像:
import numpy as npimport torch#導入內置cifarfrom torchvision.datasets import cifar#預處理模塊import torchvision.transforms as transformsfrom torch.utils.data import DataLoaderimport matplotlib.pyplot as pltclasses = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')#Compose將一些轉換函數組合在一起#ToTensor,原始數據是numpy,現在改成Tensor。會將數據從[0,255]歸一化到[0,1] 除以255transforms=transforms.Compose([transforms.ToTensor()])trainData=cifar.CIFAR10('./picdata',train=True,transform=transforms,download=True)testData=cifar.CIFAR10('./picdata',train=False,transform=transforms)x=0for images, labels in trainData: plt.subplot(3,3,x+1) plt.tight_layout() images = images.numpy().transpose(1, 2, 0) # 把channel那一維放到最后 plt.title(str(classes[labels])) plt.imshow(images) plt.xticks([]) plt.yticks([]) x+=1 if x==9: breakplt.show()圖片展示如下:
第二部分代碼,灰度化圖片:
import numpy as npimport torch#導入內置cifarfrom torchvision.datasets import cifar#預處理模塊import torchvision.transforms as transformsfrom torch.utils.data import DataLoaderimport matplotlib.pyplot as pltclasses = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')#Compose將一些轉換函數組合在一起#ToTensor,原始數據是numpy,現在改成Tensor。會將數據從[0,255]歸一化到[0,1] 除以255#Normalize則是將數據按照通道進行標準化,(輸入[通道]-均值[通道])/標準差[通道],將數據歸一化到[-1,1]#如果數據在[0,1]之間,則實際的偏移量bias會很大。而一般模型初始化的時候,bias=0,這樣收斂的就會慢。經過Normalize后加快收斂速度#后面兩個0.5就是制定mean和std,原來[0,1]變成:(0-0.5)/0.5=-1,(1-0.5)/0.5=1。本例是要灰度化,就一個通道,如果是三通道RGB,則應該為[0.5,0.5,0.5] ,transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])transforms=transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5],[0.5])])trainData=cifar.CIFAR10('./picdata',train=True,transform=transforms,download=True)testData=cifar.CIFAR10('./picdata',train=False,transform=transforms)#shuffle隨機打亂trainLoader=DataLoader(trainData,batch_size=64,shuffle=False)testLoader=DataLoader(testData,batch_size=128,shuffle=False)#enumerate組合成一個索引序列,同時列出數據下標和數據examples=enumerate(trainLoader)batchIndex,(imgData,labels)=next(examples)fig=plt.figure()for i in range(9): plt.subplot(3,3,i+1) plt.tight_layout() plt.imshow(imgData[i][0],cmap='gray',interpolation='none') plt.title("{}".format(classes[labels[i]])) plt.xticks([]) plt.yticks([])plt.show()圖片展示如下:
總結
以上是生活随笔為你收集整理的camvid数据集使用方法_使用PyTorch处理CIFAR10数据集并显示的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: argparse模块_Argparse:
- 下一篇: rhel 8.2不识别unicode_基