cnn识别mnist、Fashion-MNIST(pytorch)
生活随笔
收集整理的這篇文章主要介紹了
cnn识别mnist、Fashion-MNIST(pytorch)
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
?下面的代碼是cnn是被MNIST,如果識別Fashion-MNIST,可以將數據集換成Fashion-MNIST即可。
第一個全連接的輸入神經元個數如何確定,可以參考我的另一篇博客。即nn.lInear(1600,128)的中數字1600如何確定的?
import torch,torchvision import torch.nn as nn#定義模型 class CNNMnist(nn.Module):def __init__(self):super(CNNMnist,self).__init__()self.feature = nn.Sequential(nn.Conv2d(1,32,3), nn.ReLU(), nn.MaxPool2d(2,2),nn.Conv2d(32,64,3), nn.ReLU(), nn.MaxPool2d(2,2))self.classifier=nn.Sequential(nn.Flatten(),nn.Linear(1600, 128),nn.ReLU(),nn.Linear(128,10))def forward(self, x):x = self.feature(x)output = self.classifier(x)return outputnet = CNNMnist()#加載數據集 apply_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])train_dataset = torchvision.datasets.MNIST(root='./data/mnist', train=True, download=True,transform=apply_transform) test_dataset = torchvision.datasets.MNIST(root='./data/mnist', train=False, download=False,transform=apply_transform)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False)#定義損失函數和優化器 criterion = torch.nn.CrossEntropyLoss() optimizer = torch.optim.Adam(net.parameters(), lr=0.001)#如果有gpu就使用gpu,否則使用cpu device = torch.device('cuda'if torch.cuda.is_available() else 'cpu') net = net.to(device)#訓練模型 print('training on: ',device)def test(test_loader): net.eval()acc = 0.0sum = 0.0loss_sum = 0for batch, (data, target) in enumerate(test_loader):data, target = data.to(device), target.to(device)output = net(data)loss = criterion(output, target)acc+=torch.sum(torch.argmax(output,dim=1)==target).item()sum+=len(target)loss_sum+=loss.item()print('test acc: %.2f%%, loss: %.4f'%(100*acc/sum, loss_sum/(batch+1)))def train(): net.train()loss_sum = 0for batch, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = net(data)loss = criterion(output, target)loss.backward()optimizer.step()if batch%200==0:print('\tbatch: %d, loss: %.4f'%(batch, loss.item()))for epoch in range(5):print('epoch: %d'%epoch)train()test(test_loader)實驗結果:
總結
以上是生活随笔為你收集整理的cnn识别mnist、Fashion-MNIST(pytorch)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: KDD_cup99 pytorch
- 下一篇: cnn识别cifar10、cifar10