【小白学PyTorch】8.实战之MNIST小试牛刀
<<小白學PyTorch>>
小白學PyTorch | 7 最新版本torchvision.transforms常用API翻譯與講解
小白學PyTorch | 6 模型的構建訪問遍歷存儲(附代碼)
小白學PyTorch | 5 torchvision預訓練模型與數據集全覽
小白學PyTorch | 4 構建模型三要素與權重初始化
小白學PyTorch | 3 淺談Dataset和Dataloader
小白學PyTorch | 2 淺談訓練集驗證集和測試集
小白學PyTorch | 1 搭建一個超簡單的網絡
小白學PyTorch | 動態圖與靜態圖的淺顯理解
參考目錄:
1 探索性數據分析
1.1 數據集基本信息
1.2 數據集可視化
1.3 類別是否均衡
2 訓練與推理
2.1 構建dataset
2.2 構建模型類
2.3 訓練模型
2.4 推理預測
1 探索性數據分析
一般在進行模型訓練之前,都要做一個數據集分析的任務。這個在英文中一般縮寫為EDA,也就是Exploring Data Analysis(好像是這個)。
數據集獲取方面,這里本來是要使用之前課程提到的torchvision.datasets.MNIST(),但是考慮到這個torchvision提供的MNIST完整下載下來需要200M的大小,所以我就直接提供了MNIST的數據的CSV文件(包含train.csv和test.csv),大小壓縮成.zip之后只有14M,代碼就基于了這個數據文件。
1.1 數據集基本信息
import?pandas?as?pd #?讀取訓練集 train_df?=?pd.read_csv('./MNIST_csv/train.csv') n_train?=?len(train_df) n_pixels?=?len(train_df.columns)?-?1 n_class?=?len(set(train_df['label'])) print('Number?of?training?samples:?{0}'.format(n_train)) print('Number?of?training?pixels:?{0}'.format(n_pixels)) print('Number?of?classes:?{0}'.format(n_class))#?讀取測試集 test_df?=?pd.read_csv('./MNIST_csv/test.csv') n_test?=?len(test_df) n_pixels?=?len(test_df.columns) print('Number?of?test?samples:?{0}'.format(n_test)) print('Number?of?test?pixels:?{0}'.format(n_pixels))輸出結果:
訓練集有42000個圖片,每個圖片有784個像素(所以變成圖片的話需要將784的像素變成),樣本總共有10個類別,也就是0到9。測試集中有28000個樣本。
1.2 數據集可視化
#?展示一些圖片 import?numpy?as?np from?torchvision.utils?import?make_grid import?torch import?matplotlib.pyplot?as?plt random_sel?=?np.random.randint(len(train_df),?size=8) data?=?(train_df.iloc[random_sel,1:].values.reshape(-1,1,28,28)/255.)grid?=?make_grid(torch.Tensor(data),?nrow=8) plt.rcParams['figure.figsize']?=?(16,?2) plt.imshow(grid.numpy().transpose((1,2,0))) plt.axis('off') plt.show() print(*list(train_df.iloc[random_sel,?0].values),?sep?=?',?')輸出結果有一個圖片:
以及一行打印:
隨機挑選了8個樣本進行可視化,然后打印出來的是樣本對應的標簽值。
1.3 類別是否均衡
然后我們需要檢查一下訓練樣本中類別是否均衡,利用直方圖來檢查:
#?檢查類別是否不均衡 plt.figure(figsize=(8,5)) plt.bar(train_df['label'].value_counts().index,?train_df['label'].value_counts()) plt.xticks(np.arange(n_class)) plt.xlabel('Class',?fontsize=16) plt.ylabel('Count',?fontsize=16) plt.grid('on',?axis='y') plt.show()輸出圖像:
基本沒毛病,是均衡的。
2 訓練與推理
2.1 構建dataset
我們可以重新寫一個python腳本,首先還是導入庫和讀取文件:
import?pandas?as?pd train_df?=?pd.read_csv('./MNIST_csv/train.csv') test_df?=?pd.read_csv('./MNIST_csv/test.csv') n_train?=?len(train_df) n_test?=?len(test_df) n_pixels?=?len(train_df.columns)?-?1 n_class?=?len(set(train_df['label']))然后構建一個Dataset,Dataset和Dataloader的知識前面的課程已經講過了,這里直接構建一個:
import?torch from?torch.utils.data?import?Dataset,DataLoader from?torchvision?import?transformsclass?MNIST_data(Dataset):def?__init__(self,?file_path,transform=transforms.Compose([transforms.ToPILImage(),?transforms.ToTensor(),transforms.Normalize(mean=(0.5,),?std=(0.5,))])):df?=?pd.read_csv(file_path)if?len(df.columns)?==?n_pixels:#?test?dataself.X?=?df.values.reshape((-1,?28,?28)).astype(np.uint8)[:,?:,?:,?None]self.y?=?Noneelse:#?training?dataself.X?=?df.iloc[:,?1:].values.reshape((-1,?28,?28)).astype(np.uint8)[:,?:,?:,?None]self.y?=?torch.from_numpy(df.iloc[:,?0].values)self.transform?=?transformdef?__len__(self):return?len(self.X)def?__getitem__(self,?idx):if?self.y?is?not?None:return?self.transform(self.X[idx]),?self.y[idx]else:return?self.transform(self.X[idx])可以看到,這個dataset中,根據是否有標簽分成返回兩個不同的值。(訓練集的話,同時返回數據和標簽,測試集中僅僅返回數據)。
batch_size?=?64train_dataset?=?MNIST_data('./MNIST_csv/train.csv',transform=?transforms.Compose([transforms.ToPILImage(),transforms.RandomRotation(degrees=20),transforms.ToTensor(),transforms.Normalize(mean=(0.5,),?std=(0.5,))])) test_dataset?=?MNIST_data('./MNIST_csv/test.csv')train_loader?=?torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,?shuffle=True) test_loader?=?torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,?shuffle=False)關于這段代碼:
構建了一個train的dataset和test的dataset,然后再分別構建對應的dataloader
train_dataset中使用了隨機旋轉,因為這個函數是作用在PIL圖片上的,所以需要將數據先轉成PIL再進行旋轉,然后轉成Tensor做標準化,這里標準化就隨便選取了0.5,有需要的可以做進一步的更改。
需要注意的是,轉成PIL之前的數據是numpy的格式,所以數據應該是的形式,因為這里是單通道圖像,所以數據的shape為:(72000,28,28,1).(72000為樣本數量)
像是旋轉、縮放等圖像增強方法在訓練集中才會使用,這是增強模型訓練難度的操作,讓模型增加魯棒性;在測試集中常規情況是不使用旋轉、縮放這樣的圖像增強方法的。(訓練階段是讓模型學到內容,測試階段主要目的是提高預測的準確度,這句話感覺是廢話。。。)
2.2 構建模型類
import?torch.nn?as?nn class?Net(nn.Module):def?__init__(self):super(Net,?self).__init__()self.features1?=?nn.Conv2d(1,?32,?kernel_size=3,?stride=1,?padding=1)self.features?=?nn.Sequential(nn.BatchNorm2d(32),nn.ReLU(inplace=True),nn.Conv2d(32,?32,?kernel_size=3,?stride=1,?padding=1),nn.BatchNorm2d(32),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2,?stride=2),nn.Conv2d(32,?64,?kernel_size=3,?padding=1),nn.BatchNorm2d(64),nn.ReLU(inplace=True),nn.Conv2d(64,?64,?kernel_size=3,?padding=1),nn.BatchNorm2d(64),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2,?stride=2))self.classifier?=?nn.Sequential(nn.Dropout(p=0.5),nn.Linear(64?*?7?*?7,?512),nn.BatchNorm1d(512),nn.ReLU(inplace=True),nn.Dropout(p=0.5),nn.Linear(512,?512),nn.BatchNorm1d(512),nn.ReLU(inplace=True),nn.Dropout(p=0.5),nn.Linear(512,?10),)for?m?in?self.modules():if?isinstance(m,?nn.Conv2d)?or?isinstance(m,?nn.Linear):nn.init.xavier_uniform_(m.weight)elif?isinstance(m,?nn.BatchNorm2d):m.weight.data.fill_(1)m.bias.data.zero_()def?forward(self,?x):x?=?self.features1(x)x?=?self.features(x)x?=?x.view(x.size(0),?-1)x?=?self.classifier(x)return?x這個模型類整體來看中規中矩,都是之前講到的方法。小測試:還記得xavier初始化時怎么回事嗎?xavier初始化方法是一個非常常用的方法,在之前的文章中也詳細的推導了這個。
之后呢,我們對模型實例化,然后給模型的參數傳到優化器中,然后設置一個學習率衰減的策略,學習率衰減就是訓練的epoch越多,學習率就越低的這樣一個方法,在后面的文章中會詳細講述 。
import?torch.optim?as?optimdevice?=?'cuda'?if?torch.cuda.is_available()?else?'cpu' model?=?Net().to(device) #?model?=?torchvision.models.resnet50(pretrained=True).to(device) optimizer?=?optim.Adam(model.parameters(),?lr=0.003) criterion?=?nn.CrossEntropyLoss().to(device) exp_lr_scheduler?=?optim.lr_scheduler.StepLR(optimizer,?step_size=7,?gamma=0.1) print(model)運行結果自然是把整個模型打印出來了:
Net((features1):?Conv2d(1,?32,?kernel_size=(3,?3),?stride=(1,?1),?padding=(1,?1))(features):?Sequential((0):?BatchNorm2d(32,?eps=1e-05,?momentum=0.1,?affine=True,?track_running_stats=True)(1):?ReLU(inplace=True)(2):?Conv2d(32,?32,?kernel_size=(3,?3),?stride=(1,?1),?padding=(1,?1))(3):?BatchNorm2d(32,?eps=1e-05,?momentum=0.1,?affine=True,?track_running_stats=True)(4):?ReLU(inplace=True)(5):?MaxPool2d(kernel_size=2,?stride=2,?padding=0,?dilation=1,?ceil_mode=False)(6):?Conv2d(32,?64,?kernel_size=(3,?3),?stride=(1,?1),?padding=(1,?1))(7):?BatchNorm2d(64,?eps=1e-05,?momentum=0.1,?affine=True,?track_running_stats=True)(8):?ReLU(inplace=True)(9):?Conv2d(64,?64,?kernel_size=(3,?3),?stride=(1,?1),?padding=(1,?1))(10):?BatchNorm2d(64,?eps=1e-05,?momentum=0.1,?affine=True,?track_running_stats=True)(11):?ReLU(inplace=True)(12):?MaxPool2d(kernel_size=2,?stride=2,?padding=0,?dilation=1,?ceil_mode=False))(classifier):?Sequential((0):?Dropout(p=0.5,?inplace=False)(1):?Linear(in_features=3136,?out_features=512,?bias=True)(2):?BatchNorm1d(512,?eps=1e-05,?momentum=0.1,?affine=True,?track_running_stats=True)(3):?ReLU(inplace=True)(4):?Dropout(p=0.5,?inplace=False)(5):?Linear(in_features=512,?out_features=512,?bias=True)(6):?BatchNorm1d(512,?eps=1e-05,?momentum=0.1,?affine=True,?track_running_stats=True)(7):?ReLU(inplace=True)(8):?Dropout(p=0.5,?inplace=False)(9):?Linear(in_features=512,?out_features=10,?bias=True)) )2.3 訓練模型
def?train(epoch):model.train()for?batch_idx,?(data,?target)?in?enumerate(train_loader):#?讀入數據data?=?data.to(device)target?=?target.to(device)#?計算模型預測結果和損失output?=?model(data)loss?=?criterion(output,?target)optimizer.zero_grad()?#?計算圖梯度清零loss.backward()?#?損失反向傳播optimizer.step()#?然后更新參數if?(batch_idx?+?1)?%?50?==?0:print('Train?Epoch:?{}?[{}/{}?({:.0f}%)]\tLoss:?{:.6f}'.format(epoch,?(batch_idx?+?1)?*?len(data),?len(train_loader.dataset),100.?*?(batch_idx?+?1)?/?len(train_loader),?loss.item()))exp_lr_scheduler.step()先定義了一個訓練一個epoch的函數,然后下面是訓練10個epoch的主函數代碼。
log?=?[]?#?記錄一下loss的變化情況 n_epochs?=?2 for?epoch?in?range(n_epochs):train(epoch)#?把log化成折線圖 import?matplotlib.pyplot?as?plt plt.plot(log) plt.show()注意注意,這時候會報一個錯誤,我們來看一下,我詳細標注了我個人看報錯時候的一個習慣:
這時候我大概可以猜到,因為我們這個圖片是灰度圖片,是單通道的,可能這個RandomRotate函數要求輸入圖片是3個通道的(這個官方API上也沒有細說),怎么辦呢?完全可以直接在轉成PIL格式之前,把numpy的那個(72000,28,28,1)復制第四維度,變成(72000,28,28,3).但是這里我想用上一節課教的一個方法torchvision.transforms.GrayScale(num_output_channels), 活學活用嘛.
所以把train_dataset那一塊改成:
train_dataset?=?MNIST_data('./MNIST_csv/train.csv',transform=?transforms.Compose([transforms.ToPILImage(),transforms.Grayscale(num_output_channels=3),transforms.RandomRotation(degrees=20),transforms.ToTensor(),transforms.Normalize(mean=(0.5,),?std=(0.5,))])) test_dataset?=?MNIST_data('./MNIST_csv/test.csv',transform=transforms.Compose([transforms.ToPILImage(),transforms.Grayscale(num_output_channels=3),transforms.ToTensor(),transforms.Normalize(mean=(0.5,),?std=(0.5,))]))然后不要忘記把模型類中的第一個卷積層的輸入通道改成3哦~
#?self.features1?=?nn.Conv2d(1,?32,?kernel_size=3,?stride=1,?padding=1) self.features1?=?nn.Conv2d(3,?32,?kernel_size=3,?stride=1,?padding=1)然后重新運行代碼,發現可以正常訓練了,打印輸出的部分截圖如下:
然后看一下損失下降的情況,算是收斂了,訓練的epoch更多應該會更好:發現訓練是收斂的。這里需要注意的是,現在用全部的數據進行訓練,沒有使用驗證集的做法,是有可能過擬合情況出現的(但是這里只是訓練了10個epoch應該不會過擬合),更穩妥的做法是把數據分成訓練集和驗證機(可以是2:1,3:1,4:1)都可以,4:1比較常用,這也就是n-fold的方法。 在之后的學習中會詳細介紹這個,不過這個知識點也不難,也可以自行查閱。
2.4 推理預測
def?prediciton(data_loader):model.eval()test_pred?=?torch.LongTensor()for?i,?data?in?enumerate(data_loader):data?=?data.to(device)output?=?model(data)pred?=?output.cpu().data.max(1,?keepdim=True)[1]test_pred?=?torch.cat((test_pred,?pred),?dim=0)return?test_predtest_pred?=?prediciton(test_loader)類似trian,寫一個預測的函數,返回預測的值。然后像是在EDA中那樣,抽取測試集的8個數字,看看圖像和預測結果的匹配情況
from?torchvision.utils?import?make_grid random_sel?=?np.random.randint(len(test_df),?size=8) data?=?(test_df.iloc[random_sel,:].values.reshape(-1,1,28,28)/255.)grid?=?make_grid(torch.Tensor(data),?nrow=8) plt.rcParams['figure.figsize']?=?(16,?2) plt.imshow(grid.numpy().transpose((1,2,0))) plt.axis('off') plt.show() print(*list(test_pred[random_sel].numpy()),?sep?=?',?')輸出圖像是:打印輸出:
OK了,恭喜你,完成了MNIST手寫數字集的分類。
- END -往期精彩回顧適合初學者入門人工智能的路線及資料下載機器學習及深度學習筆記等資料打印機器學習在線手冊深度學習筆記專輯《統計學習方法》的代碼復現專輯 AI基礎下載機器學習的數學基礎專輯獲取一折本站知識星球優惠券,復制鏈接直接打開:https://t.zsxq.com/662nyZF本站qq群704220115。加入微信群請掃碼進群(如果是博士或者準備讀博士請說明):總結
以上是生活随笔為你收集整理的【小白学PyTorch】8.实战之MNIST小试牛刀的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 【学术相关】数学公式如何用Markdow
- 下一篇: 【算法】图文并茂,一文了解 8 种常见的