MNIST Handwritten Digit Recognition
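I've just started graduate school, so I'm learning PyTorch from scratch and recording the process on this site. This post uses PyTorch to recognize handwritten digits from the MNIST dataset.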
Step 0: Import the relevant packages and libraries.
```python
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim

import torchvision
from matplotlib import pyplot as plt

from utils import plot_image, plot_curve, one_hot
```
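Here `utils` is a module of helper functions I wrote: `plot_curve` plots how the loss falls during training, `one_hot` does one-hot encoding of the labels, and `plot_image` displays sample images. The code is as follows (utils.py):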
```python
import torch
from matplotlib import pyplot as plt


def plot_curve(data):
    # Plot a sequence of values (e.g. the training loss) against the step index
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()


def plot_image(img, label, name):
    # Show the first 6 images of a batch in a 2x3 grid, each titled "name:label"
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        # Undo Normalize((0.1307,), (0.3081,)) so pixel values are back in [0, 1]
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
        plt.title("{}:{}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()


def one_hot(labels, depth=10):
    # Turn integer labels of shape [N] into a float one-hot matrix of shape [N, depth]
    out = torch.zeros(labels.size(0), depth)
    idx = labels.long().view(-1, 1)          # column index for each row, shape [N, 1]
    out.scatter_(dim=1, index=idx, value=1)  # write a 1 at (row, label) for every row
    return out
```
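As a quick sanity check on `one_hot`, the sketch below compares it with PyTorch's built-in `torch.nn.functional.one_hot` on a few arbitrary labels. It re-declares the same helper so it runs on its own without `utils.py`; the label values are made up for illustration.

```python
import torch
import torch.nn.functional as F


def one_hot(labels, depth=10):
    # Same logic as the helper above: zeros, then a 1 at each label's column
    out = torch.zeros(labels.size(0), depth)
    idx = labels.long().view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out


labels = torch.tensor([3, 0, 9])      # arbitrary example labels
manual = one_hot(labels)
# F.one_hot returns an int64 tensor; cast to float to compare with the helper
builtin = F.one_hot(labels, num_classes=10).float()

print(manual)
print(torch.equal(manual, builtin))  # True
```

Either version works as the target for the loss; the built-in one just needs the `.float()` cast, because `F.mse_loss` expects floating-point targets.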
Step 1: Load the data. The data is loaded with torch's DataLoader. MNIST images are only 28×28 pixels, so batch_size can be set fairly large.
```python
batch_size = 512

# Training set: download MNIST into ./mnist_data, convert to tensors, normalize
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize((0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=True)

# Test set: same preprocessing, kept in a fixed order
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize((0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=False)
```
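transforms.Compose chains two steps: ToTensor converts each image to a tensor with pixel values in [0, 1], and Normalize rescales it to (x − 0.1307) / 0.3081, where 0.1307 and 0.3081 are the commonly used mean and standard deviation of the MNIST training pixels. Setting shuffle=True on the training loader shuffles the training data at every epoch.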
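Step 2: Define the network. A simple model of three fully connected (linear) layers, with ReLU activations between them, is enough for this recognition task.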
```python
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Three fully connected layers: 784 -> 256 -> 64 -> 10
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        # x: [batch, 784]
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # No activation on the last layer: raw scores for the 10 digit classes
        x = self.fc3(x)
        return x
```
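To see the shapes flowing through this network, the sketch below re-declares the same `Net` (so it runs on its own), feeds it a random tensor shaped like a batch of four MNIST images, flattens it the way the training loop does, and counts the trainable parameters. The random input is only a stand-in for real data.

```python
import torch
from torch import nn
from torch.nn import functional as F


class Net(nn.Module):
    # Same architecture as above: 784 -> 256 -> 64 -> 10
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


net = Net()
# Dummy batch with the DataLoader's layout: [batch, channel, height, width]
x = torch.randn(4, 1, 28, 28)
x = x.view(x.size(0), 28 * 28)  # flatten each image into a 784-vector
out = net(x)
print(out.shape)  # torch.Size([4, 10]): one score per digit class

# Weights + biases: (784*256 + 256) + (256*64 + 64) + (64*10 + 10)
n_params = sum(p.numel() for p in net.parameters())
print(n_params)  # 218058
```

Almost all of the parameters sit in the first layer, because every one of the 784 input pixels connects to each of the 256 hidden units.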
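Step 3: Train for 3 epochs, using SGD (lr=0.01, momentum=0.9) and an MSE loss between the network output and the one-hot labels.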
```python
train_loss = []

net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        # x: [b, 1, 28, 28] -> [b, 784]
        x = x.view(x.size(0), 28 * 28)
        out = net(x)               # [b, 10]
        y_onehot = one_hot(y)      # [b, 10]
        loss = F.mse_loss(out, y_onehot)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss.append(loss.item())
        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

plot_curve(train_loss)
```
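The loss here is plain mean squared error between the 10 raw output scores and a one-hot target. The sketch below, on random stand-in outputs and arbitrary labels, checks that `F.mse_loss` with its default `reduction='mean'` equals the average squared difference over every element of the batch. It uses the built-in `F.one_hot` in place of the post's `one_hot` helper; both build the same matrix.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
out = torch.randn(3, 10)        # stand-in for a batch of network outputs
y = torch.tensor([2, 7, 0])     # arbitrary integer class labels
y_onehot = F.one_hot(y, num_classes=10).float()  # float one-hot targets

loss = F.mse_loss(out, y_onehot)
# Default reduction='mean': average of squared differences over all 3 * 10 elements
manual = ((out - y_onehot) ** 2).mean()

print(loss.item(), torch.allclose(loss, manual))
```

MSE on one-hot targets is the choice made in this post. For classification, `F.cross_entropy(out, y)` is the more common loss; it takes the integer labels `y` directly, so no one-hot step is needed.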
Step 4: Test. Finally, measure the accuracy of the trained network on the test set.
```python
total_correct = 0
with torch.no_grad():  # evaluation only: no need to build the autograd graph
    for x, y in test_loader:
        x = x.view(x.size(0), 28 * 28)
        out = net(x)
        pred = out.argmax(dim=1)  # predicted digit = index of the largest score
        correct = pred.eq(y).sum().float().item()
        total_correct += correct

total_num = len(test_loader.dataset)
acc = total_correct / total_num
print("test acc:", acc)
```
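The accuracy computation reduces to argmax-then-compare. A tiny hand-made example (the scores below are invented, not model output) makes the arithmetic concrete:

```python
import torch

# Invented scores for 3 samples over 3 classes
out = torch.tensor([[0.1, 0.9, 0.0],
                    [0.8, 0.1, 0.1],
                    [0.2, 0.3, 0.5]])
y = torch.tensor([1, 0, 1])  # true labels

pred = out.argmax(dim=1)                   # index of the largest score per row
correct = pred.eq(y).sum().float().item()  # number of matching predictions
acc = correct / len(y)

print(pred.tolist(), correct, acc)  # [1, 0, 2] 2.0 0.6666666666666666
```

The third sample is predicted as class 2 while its label is 1, so 2 of 3 predictions are correct.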