RNN代码解释pytorch
生活随笔
收集整理的這篇文章主要介紹了
RNN代码解释pytorch
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
簡述
還是跟之前的CNN一樣,都是學于莫煩Python的。
解釋
- 關于數據導入部分的代碼含義,其實跟之前的CNN幾乎完全一致。
- 而且還需要部分的源代碼–MNIST(在之前的地方有超鏈接)
- 這些都可以在下面的CNN的鏈接中看到
- 卷積神經網絡CNN入門【pytorch學習】
模型含義
這里使用RNN,這是跟之前的CNN唯一的不同的地方,其他的都是完全一致的。
class RNN(nn.Module):def __init__(self):super(RNN, self).__init__()self.rnn = nn.LSTM(input_size=28,hidden_size=64,num_layers=1,batch_first=True)self.out = nn.Linear(64, 10) # fully connected layer, output 10 classesdef forward(self, x):r_out, (h_n, h_c) = self.rnn(x, None) # None 表示 hidden state 會用全0的 state# r_out = [BATCH_SIZE, input_size, hidden_size]# r_out[:, -1, :] = [BATCH_SIZE, hidden_size] '-1',表示選取最后一個時間點的 r_out 輸出out = self.out(r_out[:, -1, :])# out = [BATCH_SIZE, 10]return outrnn = RNN()LSTM參數解釋
- 輸入參數,其實是表示有多少序列。這里的最小單位,考慮的其實不是整個圖片的完整全部序列。而是每一行為最小單位的。
- 所以說經過LSTM之后,輸出的結果就是r_out = [BATCH_SIZE, input_size, hidden_size]。 第一個input_size其實是恰好這個圖片大小是(input_size, input_size)的
out中輸入的有-1
- 會發現這里有一個數字-1,其實就是表示要選最后的一列作為最后的結果。其實就是說只看最后的一行。
完整代碼
import osimport torch import torch.nn as nn import torch.utils.data as Data import torchvision# Hyper Parameters EPOCH = 1 # train the training data n times, to save time, we just train 1 epoch BATCH_SIZE = 50 LR = 0.001 # learning rate DOWNLOAD_MNIST = False# Mnist digits dataset if not (os.path.exists('./mnist/')) or not os.listdir('./mnist/'):# not mnist dir or mnist is empyt dirDOWNLOAD_MNIST = Truetrain_data = torchvision.datasets.MNIST(root='./mnist/',train=True, # this is training datatransform=torchvision.transforms.ToTensor(), # Converts a PIL.Image or numpy.ndarray to# torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]download=DOWNLOAD_MNIST, )# Data Loader for easy mini-batch return in training, the image batch shape will be (50, 1, 28, 28) train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)# pick 2000 samples to speed up testing test_data = torchvision.datasets.MNIST(root='./mnist/', train=False) test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000] / 255. # shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1) test_y = test_data.test_labels[:2000]class RNN(nn.Module):def __init__(self):super(RNN, self).__init__()self.rnn = nn.LSTM(input_size=28,hidden_size=64,num_layers=1,batch_first=True)self.out = nn.Linear(64, 10) # fully connected layer, output 10 classesdef forward(self, x):r_out, (h_n, h_c) = self.rnn(x, None) # None 表示 hidden state 會用全0的 state# r_out = [BATCH_SIZE, input_size, hidden_size]# r_out[:, -1, :] = [BATCH_SIZE, hidden_size] '-1',表示選取最后一個時間點的 r_out 輸出out = self.out(r_out[:, -1, :])# out = [BATCH_SIZE, 10]return outrnn = RNN()optimizer = torch.optim.Adam(rnn.parameters(), lr=LR) # optimize all cnn parameters loss_func = nn.CrossEntropyLoss() # the target label is not one-hottedfor epoch in range(EPOCH):for step, (x, b_y) in enumerate(train_loader): # gives batch datab_x = x.view(-1, 28, 28) # reshape x to (batch, time_step, input_size)output = rnn(b_x) # rnn outputloss = loss_func(output, b_y) # cross entropy lossoptimizer.zero_grad() # clear gradients for this training steploss.backward() # backpropagation, compute gradientsoptimizer.step()test_output = rnn(test_x[:10].view(-1, 28, 28)) pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()print(pred_y, 'prediction number') print(test_y[:10], 'real number')- 結果:
總結
以上是生活随笔為你收集整理的RNN代码解释pytorch的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 创建表名作为参数的mysq存储过程【pr
- 下一篇: 多彩投网站动态爬取[python+sel