05, PyTorch: Handwritten Digit Recognition Example
This section is based on a NetEase Cloud Classroom (網易云課堂) course taught by Long Liangqu (龍良曲).
1、About MNIST
- 70,000 images in total, roughly 7,000 per digit (the classes are not exactly balanced)
- Train/test split: 60k vs 10k images
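The example code in section 8 normalizes every image with `Normalize((0.1307,), (0.3081,))` and undoes this in `plot_image` via `img * 0.3081 + 0.1307`. These two constants are commonly quoted as the mean and standard deviation of MNIST training pixels after `ToTensor()` has scaled them to [0, 1]. Here is a minimal sketch of that round trip on made-up data, so no download is needed. The helper names `normalize` and `denormalize` are hypothetical.

```python
import torch

# Commonly quoted MNIST pixel statistics (pixels scaled to [0, 1]);
# the same constants appear in Normalize(...) and in plot_image.
MEAN, STD = 0.1307, 0.3081


def normalize(img: torch.Tensor) -> torch.Tensor:
    # Per-channel (x - mean) / std, as transforms.Normalize does
    return (img - MEAN) / STD


def denormalize(img: torch.Tensor) -> torch.Tensor:
    # What plot_image does before calling imshow
    return img * STD + MEAN


# A fake 1x28x28 "image" with pixels in [0, 1] (illustrative data only)
img = torch.rand(1, 28, 28)
restored = denormalize(normalize(img))

# Black (0.0) and white (1.0) pixels end up at about -0.42 and 2.82
print(normalize(torch.tensor([0.0, 1.0])))
```

After normalization the pixel values are roughly centred on 0. The plotting helper has to undo this step before it can show the images in their original grey levels.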
2、NO deep learning, just function mapping
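Without any deep-learning machinery, the task can be framed as a plain function mapping. A 28×28 image is flattened to a row vector X of 784 values and pushed through stacked linear functions H = XW + b, which is the `xw+b` in the example code's comments. The sketch below is a minimal version that uses random weights. It takes its layer sizes (784 → 256 → 64 → 10) from the `Net` class.

```python
import torch

torch.manual_seed(0)

# One flattened 28x28 image as a row vector X: [1, 784] (random stand-in data)
X = torch.rand(1, 28 * 28)

# Randomly initialised weights; sizes follow the nn.Linear layers in Net
W1, b1 = torch.randn(784, 256), torch.zeros(256)
W2, b2 = torch.randn(256, 64), torch.zeros(64)
W3, b3 = torch.randn(64, 10), torch.zeros(10)

H1 = X @ W1 + b1   # [1, 256]
H2 = H1 @ W2 + b2  # [1, 64]
H3 = H2 @ W3 + b3  # [1, 10]: one score per digit 0-9
print(H1.shape, H2.shape, H3.shape)
```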
3、Loss?
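The network outputs 10 scores, but a label is a single integer. The label is therefore first turned into a one-hot vector, for example 3 → [0, 0, 0, 1, 0, …, 0]. The loss then measures the squared distance between the scores and that vector. The example code uses `F.mse_loss` for this, which averages the squared differences over all elements. The small sketch below uses hand-picked numbers. It calls PyTorch's built-in `F.one_hot` in place of the article's own `one_hot` helper.

```python
import torch
from torch.nn import functional as F

# Two labels and two made-up network outputs of shape [2, 10]
y = torch.tensor([3, 1])
out = torch.zeros(2, 10)
out[0, 3] = 0.5   # half-confident in the right class for sample 0
out[1, 7] = 1.0   # fully confident in the wrong class for sample 1

# [2] -> [2, 10]; F.one_hot returns int64, so convert to float for the loss
y_onehot = F.one_hot(y, num_classes=10).float()

loss = F.mse_loss(out, y_onehot)                    # mean over all 20 elements
manual = ((out - y_onehot) ** 2).sum() / out.numel()
print(loss.item(), manual.item())                   # 2.25 / 20 = 0.1125
```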
4、In a nutshell
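The three linear maps can be put together using the row-vector convention of the code comments (`xw+b`). The whole model then becomes one nested expression. Expanding it shows that stacking linear layers alone adds no expressive power, because the result is still a single linear map of X:

```latex
\begin{aligned}
\text{pred} &= \big((X W_1 + b_1) W_2 + b_2\big) W_3 + b_3 \\
            &= X \underbrace{W_1 W_2 W_3}_{W}
               + \underbrace{b_1 W_2 W_3 + b_2 W_3 + b_3}_{b}
             = X W + b
\end{aligned}
```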
5、Non-linear Factor
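The non-linear factor in the example code is ReLU, relu(z) = max(0, z). It is applied after the first two layers: h1 = relu(xW1 + b1) and h2 = relu(h1W2 + b2). The last layer stays linear. Because of ReLU, the collapse into a single XW + b no longer holds. The following minimal check shows what `F.relu` does elementwise and why it is not linear:

```python
import torch
from torch.nn import functional as F

z = torch.tensor([[-2.0, -0.5, 0.0, 0.5, 3.0]])
h = F.relu(z)          # negatives become 0, everything else passes through
print(h)

# ReLU is not linear: relu(a + b) differs from relu(a) + relu(b) in general
a, b = torch.tensor([2.0]), torch.tensor([-3.0])
print(F.relu(a + b).item(), (F.relu(a) + F.relu(b)).item())  # 0.0 vs 2.0
```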
6、Gradient Descent
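Gradient descent repeatedly moves every parameter against its gradient, w' = w − lr · grad, as the comment next to `optimizer.step()` says. The sketch below performs one such step by hand with autograd on a tiny, hypothetical linear layer and made-up data. It then checks that the result matches `optim.SGD` with momentum switched off. The article itself uses `momentum=0.9`, and that update also accumulates past gradients, so its steps differ from this plain rule.

```python
import copy

import torch
from torch import nn, optim
from torch.nn import functional as F

torch.manual_seed(0)
x = torch.randn(4, 3)                # made-up batch of 4 inputs
target = torch.randn(4, 2)           # made-up regression targets
model = nn.Linear(3, 2)
model_sgd = copy.deepcopy(model)     # identical copy for optim.SGD
lr = 0.01

# Manual step: compute gradients, then w' = w - lr * grad for every parameter
F.mse_loss(model(x), target).backward()
with torch.no_grad():
    for p in model.parameters():
        p -= lr * p.grad

# The same step through the optimizer (plain SGD, momentum defaults to 0)
optimizer = optim.SGD(model_sgd.parameters(), lr=lr)
optimizer.zero_grad()
F.mse_loss(model_sgd(x), target).backward()
optimizer.step()

for p_manual, p_sgd in zip(model.parameters(), model_sgd.parameters()):
    print(torch.allclose(p_manual, p_sgd))
```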
7、Inference
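At inference time the 10 outputs are read as scores, and the predicted digit is the index of the largest one, pred = argmax(out). Accuracy is the fraction of predictions that equal the labels. The sketch below mirrors the test loop in the example code, using made-up numbers:

```python
import torch

# Made-up outputs for 3 samples, shape [3, 10]
out = torch.zeros(3, 10)
out[0, 7] = 0.9
out[1, 2] = 0.8
out[2, 4] = 0.6
y = torch.tensor([7, 2, 9])          # true labels

pred = out.argmax(dim=1)             # [3, 10] => [3]
correct = pred.eq(y).sum().item()    # number of matches
acc = correct / y.size(0)
print(pred.tolist(), acc)            # [7, 2, 4], 2 of 3 correct
```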
8、The handwritten digit problem
Example code:
# -*- coding: UTF-8 -*-
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim

import torchvision
from matplotlib import pyplot as plt


def plot_curve(data):
    # Plot the training loss, one point per step
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()


def plot_image(img, label, name):
    # Show the first 6 images of a batch, undoing Normalize() for display
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()


def one_hot(label, depth=10):
    # Integer labels [b] => one-hot vectors [b, depth]
    out = torch.zeros(label.size(0), depth)
    idx = label.long().view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out


batch_size = 512

# step1. load dataset
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize((0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize((0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=False)

x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, 'image sample')


# step2. build the network
class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # xw+b
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        # x: [b, 784] (images are flattened before being passed in)
        # h1 = relu(xw1+b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1w2+b2)
        x = F.relu(self.fc2(x))
        # h3 = h2w3+b3
        x = self.fc3(x)
        return x


# step3. train
net = Net()
# [w1, b1, w2, b2, w3, b3]
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

train_loss = []

for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        # x: [b, 1, 28, 28], y: [b] (b = 512 except for the last batch)
        # [b, 1, 28, 28] => [b, 784]
        x = x.view(x.size(0), 28 * 28)
        # => [b, 10]
        out = net(x)
        # [b, 10]
        y_onehot = one_hot(y)
        # loss = mse(out, y_onehot)
        loss = F.mse_loss(out, y_onehot)

        optimizer.zero_grad()
        loss.backward()
        # w' = w - lr*grad
        optimizer.step()

        train_loss.append(loss.item())

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

plot_curve(train_loss)
# we now have trained [w1, b1, w2, b2, w3, b3]

# step4. test (no gradients needed)
total_correct = 0
with torch.no_grad():
    for x, y in test_loader:
        x = x.view(x.size(0), 28 * 28)
        out = net(x)
        # out: [b, 10] => pred: [b]
        pred = out.argmax(dim=1)
        correct = pred.eq(y).sum().float().item()
        total_correct += correct

total_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:', acc)

with torch.no_grad():
    x, y = next(iter(test_loader))
    out = net(x.view(x.size(0), 28 * 28))
    pred = out.argmax(dim=1)
plot_image(x, pred, 'test')