Recognizing CIFAR-10 and CIFAR-100 with a CNN (PyTorch)
Collected and organized by 生活随笔.
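The code below trains a CNN to classify CIFAR-10. To use CIFAR-100 instead, switch the dataset to CIFAR-100 (torchvision.datasets.CIFAR100) and change the number of output neurons in the model's last layer from 10 to 100.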
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms


# Define the model
class CNNCifar(nn.Module):
    def __init__(self):
        super(CNNCifar, self).__init__()
        # Four Conv-BN-ReLU-MaxPool blocks: 3x32x32 input -> 512x2x2 feature map
        self.feature = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=2), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, padding=2), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(256, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU(), nn.MaxPool2d(2, 2))
        # Fully connected head: 2048 = 512 * 2 * 2 flattened features.
        # For CIFAR-100, change the last layer's 10 outputs to 100.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(2048, 4096), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(4096, 10))

    def forward(self, x):
        x = self.feature(x)
        output = self.classifier(x)
        return output


net = CNNCifar()
print(net)

# Load the dataset (for CIFAR-100, use torchvision.datasets.CIFAR100)
apply_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_dataset = torchvision.datasets.CIFAR10(root='./data/cifar10', train=True, download=True,
                                             transform=apply_transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data/cifar10', train=False, download=False,
                                            transform=apply_transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False)

# Use the GPU if one is available, otherwise the CPU
# (move the model before building the optimizer)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = net.to(device)

# Define the loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=5e-4)

# Train the model
print('training on: ', device)


def test():
    net.eval()
    correct = 0.0
    total = 0.0
    loss_sum = 0.0
    with torch.no_grad():  # gradients are not needed for evaluation
        for batch, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            output = net(data)
            loss = criterion(output, target)
            correct += torch.sum(torch.argmax(output, dim=1) == target).item()
            total += len(target)
            loss_sum += loss.item()
    print('test acc: %.2f%%, loss: %.4f' % (100 * correct / total, loss_sum / (batch + 1)))


def train():
    net.train()
    correct = 0.0
    total = 0.0
    loss_sum = 0.0
    for batch, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = net(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        correct += torch.sum(torch.argmax(output, dim=1) == target).item()
        total += len(target)
        loss_sum += loss.item()
        if batch % 200 == 0:
            print('\tbatch: %d, loss: %.4f' % (batch, loss.item()))
    print('train acc: %.2f%%, loss: %.4f' % (100 * correct / total, loss_sum / (batch + 1)))


for epoch in range(20):
    print('epoch: %d' % epoch)
    train()
    test()

Experimental results:
[Figure: experimental results]
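The classifier's first layer, `nn.Linear(2048, 4096)`, depends on the feature extractor turning a 3×32×32 CIFAR image into a 512×2×2 map. Each 3×3 convolution (stride 1) and each 2×2 max pooling (stride 2) follows the usual output-size formula floor((n + 2p - k) / s) + 1. The sketch below applies that formula stage by stage and then cross-checks it with a dummy forward pass through the same layers. `conv_out` is a hypothetical helper written for this example and is not a PyTorch function.

```python
import torch
import torch.nn as nn


def conv_out(n: int, k: int, p: int = 0, s: int = 1) -> int:
    """Hypothetical helper: Conv2d/MaxPool2d output size (dilation 1, floor mode)."""
    return (n + 2 * p - k) // s + 1


# Apply the formula to the four Conv(3x3) + MaxPool(2x2) stages of CNNCifar
sizes = []
n = 32  # CIFAR images are 32x32
for p in (2, 2, 1, 1):  # paddings used by the four Conv2d layers
    n = conv_out(n, k=3, p=p)  # 3x3 convolution, stride 1
    n = conv_out(n, k=2, s=2)  # 2x2 max pooling, stride 2
    sizes.append(n)
flat_features = 512 * n * n
print(sizes, flat_features)  # [17, 9, 4, 2] 2048

# Cross-check with a forward pass through an identical feature extractor
channels = [3, 64, 128, 256, 512]
layers = []
for c_in, c_out, p in zip(channels[:-1], channels[1:], (2, 2, 1, 1)):
    layers += [nn.Conv2d(c_in, c_out, 3, padding=p), nn.BatchNorm2d(c_out),
               nn.ReLU(), nn.MaxPool2d(2, 2)]
feature = nn.Sequential(*layers).eval()  # eval mode: BatchNorm uses running statistics
with torch.no_grad():
    out = feature(torch.zeros(1, 3, 32, 32))
print(tuple(out.shape))  # (1, 512, 2, 2)
```

The padding of 2 in the first two convolutions enlarges the maps slightly (32 to 34, then 17 to 19), which is why four halvings end at 2×2 rather than at 1×1 or 3×3.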