PyTorch: Implementing ResNet-18 and Validating It on CIFAR-10
1. Building ResNet-18 in PyTorch
import torch
from torch import nn
from torch.nn import functional as F


class ResBlk(nn.Module):
    """
    ResNet residual block (two 3x3 convolutions plus a shortcut).
    """
    def __init__(self, ch_in, ch_out, stride=1):
        super(ResBlk, self).__init__()

        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()
        # If the input and output channel counts differ, or the stride is not 1,
        # the shortcut must be projected so its shape matches the main path.
        if ch_out != ch_in or stride != 1:
            # [b, ch_in, h, w] => [b, ch_out, h/stride, w/stride]
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out = self.extra(x) + out   # shortcut: identity or 1x1 projection
        out = F.relu(out)
        return out


class ResNet18(nn.Module):
    '''
    Main network: a stem convolution followed by 4 residual blocks.
    '''
    def __init__(self):
        super(ResNet18, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(64)
        )
        # followed by 4 blocks
        self.blk1 = ResBlk(64, 128, stride=2)   # [b, 64, h, w]  => [b, 128, h/2, w/2]
        self.blk2 = ResBlk(128, 256, stride=2)  # [b, 128, h, w] => [b, 256, h/2, w/2]
        self.blk3 = ResBlk(256, 512, stride=2)  # [b, 256, h, w] => [b, 512, h/2, w/2]
        self.blk4 = ResBlk(512, 512, stride=2)  # [b, 512, h, w] => [b, 512, h/2, w/2]

        self.outlayer = nn.Linear(512*1*1, 10)  # fully connected layer, 10 classes

    def forward(self, x):
        x = F.relu(self.conv1(x))

        # [b, 64, h, w] => [b, 512, h', w']
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        x = F.adaptive_avg_pool2d(x, [1, 1])  # [b, 512, h', w'] => [b, 512, 1, 1]
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)

        return x
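To see where the 512*1*1 input to the final linear layer comes from, here is a quick shape walkthrough for a 32x32 CIFAR-10 image. This is just a sketch that reuses the imports and classes above; the shapes in the comments follow from the convolution arithmetic of each stage.

# Shape walkthrough for a 32x32 input (reuses torch, F and ResNet18 from above)
net = ResNet18()
x = torch.randn(2, 3, 32, 32)
x = F.relu(net.conv1(x))   # [2, 64, 10, 10]  (3x3 conv, stride 3, no padding)
x = net.blk1(x)            # [2, 128, 5, 5]
x = net.blk2(x)            # [2, 256, 3, 3]
x = net.blk3(x)            # [2, 512, 2, 2]
x = net.blk4(x)            # [2, 512, 1, 1]
print(x.shape)             # torch.Size([2, 512, 1, 1])

The adaptive average pooling then maps whatever spatial size remains to 1x1, so nn.Linear(512*1*1, 10) always receives a 512-dimensional vector per image.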
A quick test:
if __name__ == '__main__':

    blk = ResBlk(64, 128, stride=4)
    tmp = torch.randn(2, 64, 32, 32)
    out = blk(tmp)
    print('block:', out.shape)   # block: torch.Size([2, 128, 8, 8])

    x = torch.randn(2, 3, 32, 32)
    model = ResNet18()
    out = model(x)
    print('resnet:', out.shape)  # resnet: torch.Size([2, 10])
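As a further sanity check, the number of trainable parameters can be printed. This is a small sketch; the exact count depends on the shortcut convolutions, but for this variant it should come out at roughly 10 million.

import torch
from resnet import ResNet18   # assumes the model above is saved as resnet.py, as the training script below does

model = ResNet18()
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('trainable parameters:', n_params)   # roughly 10 million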
2. Training on the CIFAR-10 Dataset
The chosen dataset is CIFAR-10, which contains 60,000 labelled colour images of size 32x32, split into 10 classes with 6,000 images per class. Of these, 50,000 are used for training (5,000 per class) and the remaining 10,000 for testing (1,000 per class).
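Before training, the dataset can be inspected directly through torchvision. A minimal sketch is shown below; the 'cifar' directory matches the one used by the training script, and the data is downloaded on the first run.

from torchvision import datasets

train_set = datasets.CIFAR10('cifar', train=True, download=True)
test_set = datasets.CIFAR10('cifar', train=False, download=True)
print(len(train_set), len(test_set))   # 50000 10000
print(train_set.classes)
# ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']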
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch import nn, optim

from resnet import ResNet18


def main():
    batchsz = 128

    # training set
    cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    # test set
    cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)

    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)  # x: torch.Size([128, 3, 32, 32]) label: torch.Size([128])

    # model: ResNet-18
    model = ResNet18()

    # loss function and optimiser
    criteon = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    # training loop
    for epoch in range(1000):

        model.train()  # training mode
        for batchidx, (x, label) in enumerate(cifar_train):
            # x: [b, 3, 32, 32]
            # label: [b]

            logits = model(x)              # logits: [b, 10]
            loss = criteon(logits, label)  # scalar

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, 'loss:', loss.item())

        model.eval()  # evaluation mode
        with torch.no_grad():

            total_correct = 0  # number of correct predictions
            total_num = 0
            for x, label in cifar_test:
                # x: [b, 3, 32, 32]
                # label: [b]

                logits = model(x)            # [b, 10]
                pred = logits.argmax(dim=1)  # [b]

                # [b] vs [b] => scalar tensor
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)

            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)


if __name__ == '__main__':
    main()
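The script above runs entirely on the CPU, which is slow for a ResNet. Below is a minimal sketch of the changes needed to train on a GPU, assuming CUDA is available; the random tensors stand in for a real CIFAR-10 batch from the DataLoader.

import torch
from torch import nn, optim
from resnet import ResNet18

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = ResNet18().to(device)                 # move the parameters to the GPU
criteon = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Inside the training loop, move each batch to the same device before the forward pass:
x = torch.randn(128, 3, 32, 32).to(device)    # stand-in for a batch from cifar_train
label = torch.randint(0, 10, (128,)).to(device)

logits = model(x)
loss = criteon(logits, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()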
Training for 1000 epochs takes too long, so only the output of the first 5 epochs is shown:
0 loss: 1.0912220478057861
0 test acc: 0.5583
1 loss: 0.8604468107223511
1 test acc: 0.6592
2 loss: 0.6625195145606995
2 test acc: 0.6827
3 loss: 0.7064175009727478
3 test acc: 0.6904
4 loss: 0.5687283277511597
4 test acc: 0.7059
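Since test accuracy is still climbing after five epochs, it can be worth keeping the best weights seen so far. The snippet below is a hypothetical addition to the evaluation step, not part of the original script; the variable best_acc and the file name 'best.pth' are illustrative.

import torch
from resnet import ResNet18

model = ResNet18()
best_acc = 0.0

# After computing 'acc' for the current epoch in the evaluation block:
acc = 0.7059                       # placeholder; use the accuracy computed above
if acc > best_acc:
    best_acc = acc
    torch.save(model.state_dict(), 'best.pth')   # file name is illustrative

# Later, to restore the best checkpoint:
model.load_state_dict(torch.load('best.pth'))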