手撕 CNN 经典网络之 VGGNet(PyTorch实战篇)
大家好,我是紅色石頭!
在上一篇文章:
手撕 CNN 經典網絡之 VGGNet(理論篇)
詳細介紹了 VGGNet 的網絡結構,今天我們將使用 PyTorch 來復現VGGNet網絡,并用VGGNet模型來解決一個經典的Kaggle圖像識別比賽問題。
正文開始!
1. 數據集制作
在論文中AlexNet作者使用的是ILSVRC 2012比賽數據集,該數據集非常大(有138G),下載、訓練都很消耗時間,我們在復現的時候就不用這個數據集了。由于MNIST、CIFAR10、CIFAR100這些數據集圖片尺寸都較小,不符合AlexNet網絡輸入尺寸227x227的要求,因此我們改用kaggle比賽經典的“貓狗大戰”數據集了。
該數據集包含的訓練集總共25000張圖片,貓狗各12500張,帶標簽;測試集總共12500張,不帶標簽。我們僅使用帶標簽的25000張圖片,分別拿出2500張貓和狗的圖片作為模型的驗證集。我們按照以下目錄層級結構,將數據集圖片放好。
為了方便大家訓練,我們將該數據集放在百度云盤,下載鏈接:?
鏈接:https://pan.baidu.com/s/1UEOzxWWMLCUoLTxdWUkB4A
提取碼:cdue
1.1 制作圖片數據的索引
準備好數據集之后,我們需要用PyTorch來讀取并制作可以用來訓練和測試的數據集。對于訓練集和測試集,首先要分別制作對應的圖片數據索引,即train.txt和test.txt兩個文件,每個txt中包含每個圖片的目錄和對應類別class(cat對應的label=0,dog對應的label=1)。示意圖如下:
制作圖片數據索引train.txt和test.txt兩個文件的python腳本程序如下:
import?ostrain_txt_path?=?os.path.join("data",?"catVSdog",?"train.txt") train_dir?=?os.path.join("data",?"catVSdog",?"train_data") valid_txt_path?=?os.path.join("data",?"catVSdog",?"test.txt") valid_dir?=?os.path.join("data",?"catVSdog",?"test_data")def?gen_txt(txt_path,?img_dir):f?=?open(txt_path,?'w')for?root,?s_dirs,?_?in?os.walk(img_dir,?topdown=True):??#?獲取?train文件下各文件夾名稱for?sub_dir?in?s_dirs:i_dir?=?os.path.join(root,?sub_dir)?????????????#?獲取各類的文件夾?絕對路徑img_list?=?os.listdir(i_dir)????????????????????#?獲取類別文件夾下所有png圖片的路徑for?i?in?range(len(img_list)):if?not?img_list[i].endswith('jpg'):?????????#?若不是png文件,跳過continue#label?=?(img_list[i].split('.')[0]?==?'cat')??0?:?1?label?=?img_list[i].split('.')[0]#?將字符類別轉為整型類型表示if?label?==?'cat':label?=?'0'else:label?=?'1'img_path?=?os.path.join(i_dir,?img_list[i])line?=?img_path?+?'?'?+?label?+?'\n'f.write(line)f.close()if?__name__?==?'__main__':gen_txt(train_txt_path,?train_dir)gen_txt(valid_txt_path,?valid_dir)運行腳本之后就在./data/catVSdog/目錄下生成train.txt和test.txt兩個索引文件。
1.2 構建Dataset子類
PyTorch 加載自己的數據集,需要寫一個繼承自torch.utils.data中Dataset類,并修改其中的__init__方法、__getitem__方法、__len__方法。默認加載的都是圖片,__init__的目的是得到一個包含數據和標簽的list,每個元素能找到圖片位置和其對應標簽。然后用__getitem__方法得到每個元素的圖像像素矩陣和標簽,返回img和label。
from?PIL?import?Image from?torch.utils.data?import?Datasetclass?MyDataset(Dataset):def?__init__(self,?txt_path,?transform?=?None,?target_transform?=?None):fh?=?open(txt_path,?'r')imgs?=?[]for?line?in?fh:line?=?line.rstrip()words?=?line.split()imgs.append((words[0],?int(words[1])))?#?類別轉為整型intself.imgs?=?imgs?self.transform?=?transformself.target_transform?=?target_transformdef?__getitem__(self,?index):fn,?label?=?self.imgs[index]img?=?Image.open(fn).convert('RGB')?#img?=?Image.open(fn)if?self.transform?is?not?None:img?=?self.transform(img)?return?img,?labeldef?__len__(self):return?len(self.imgs)getitem是核心函數。self.imgs是一個list,self.imgs[index]是一個str,包含圖片路徑,圖片標簽,這些信息是從上面生成的txt文件中讀取;利用Image.open對圖片進行讀取,注意這里的img是單通道還是三通道的;self.transform(img)對圖片進行處理,這個transform里邊可以實現減均值、除標準差、隨機裁剪、旋轉、翻轉、放射變換等操作。
1.3 加載數據集和數據預處理
當Mydataset構建好,剩下的操作就交給DataLoder來加載數據集。在DataLoder中,會觸發Mydataset中的getiterm函數讀取一張圖片的數據和標簽,并拼接成一個batch返回,作為模型真正的輸入。
pipline_train?=?transforms.Compose([#transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),??#隨機旋轉圖片#將圖片尺寸resize到224x224transforms.Resize((224,224)),#將圖片轉化為Tensor格式transforms.ToTensor(),#正則化(當模型出現過擬合的情況時,用來降低模型的復雜度)transforms.Normalize((0.5,?0.5,?0.5),?(0.5,?0.5,?0.5))#transforms.Normalize(mean?=?[0.485,?0.456,?0.406],std?=?[0.229,?0.224,?0.225]) ]) pipline_test?=?transforms.Compose([#將圖片尺寸resize到224x224transforms.Resize((224,224)),transforms.ToTensor(),transforms.Normalize((0.5,?0.5,?0.5),?(0.5,?0.5,?0.5))#transforms.Normalize(mean?=?[0.485,?0.456,?0.406],std?=?[0.229,?0.224,?0.225]) ]) train_data?=?MyDataset('./data/catVSdog/train.txt',?transform=pipline_train) test_data?=?MyDataset('./data/catVSdog/test.txt',?transform=pipline_test)#train_data?和test_data包含多有的訓練與測試數據,調用DataLoader批量加載 trainloader?=?torch.utils.data.DataLoader(dataset=train_data,?batch_size=64,?shuffle=True) testloader?=?torch.utils.data.DataLoader(dataset=test_data,?batch_size=32,?shuffle=False) #?類別信息也是需要我們給定的 classes?=?('cat',?'dog')?#?對應label=0,label=1在數據預處理中,我們將圖片尺寸調整到224x224,符合VGGNet網絡的輸入要求。均值mean = [0.5, 0.5, 0.5],方差std = [0.5, 0.5, 0.5],然后使用transforms.Normalize進行歸一化操作。?
我們來看一下最終制作的數據集圖片和它們對應的標簽:
examples?=?enumerate(trainloader) batch_idx,?(example_data,?example_label)?=?next(examples) #?批量展示圖片 for?i?in?range(4):plt.subplot(1,?4,?i?+?1)plt.tight_layout()??#自動調整子圖參數,使之填充整個圖像區域img?=?example_data[i]img?=?img.numpy()?#?FloatTensor轉為ndarrayimg?=?np.transpose(img,?(1,2,0))?#?把channel那一維放到最后img?=?img?*?[0.5,?0.5,?0.5]?+?[0.5,?0.5,?0.5]#img?=?img?*?[0.229,?0.224,?0.225]?+?[0.485,?0.456,?0.406]plt.imshow(img)plt.title("label:{}".format(example_label[i]))plt.xticks([])plt.yticks([]) plt.show()2. 搭建VGGNet神經網絡結構
class?VGG(nn.Module):def?__init__(self,?features,?num_classes=2,?init_weights=False):super(VGG,?self).__init__()self.features?=?featuresself.classifier?=?nn.Sequential(nn.Linear(512*7*7,?500),nn.ReLU(True),nn.Dropout(p=0.5),nn.Linear(500,?20),nn.ReLU(True),nn.Dropout(p=0.5),nn.Linear(20,?num_classes))if?init_weights:self._initialize_weights()def?forward(self,?x):#?N?x?3?x?224?x?224x?=?self.features(x)#?N?x?512?x?7?x?7x?=?torch.flatten(x,?start_dim=1)#?N?x?512*7*7x?=?self.classifier(x)return?xdef?_initialize_weights(self):for?m?in?self.modules():if?isinstance(m,?nn.Conv2d):#?nn.init.kaiming_normal_(m.weight,?mode='fan_out',?nonlinearity='relu')nn.init.xavier_uniform_(m.weight)if?m.bias?is?not?None:nn.init.constant_(m.bias,?0)elif?isinstance(m,?nn.Linear):nn.init.xavier_uniform_(m.weight)#?nn.init.normal_(m.weight,?0,?0.01)nn.init.constant_(m.bias,?0)def?make_features(cfg:?list):layers?=?[]in_channels?=?3for?v?in?cfg:if?v?==?"M":layers?+=?[nn.MaxPool2d(kernel_size=2,?stride=2)]else:conv2d?=?nn.Conv2d(in_channels,?v,?kernel_size=3,?padding=1)layers?+=?[conv2d,?nn.ReLU(True)]in_channels?=?vreturn?nn.Sequential(*layers)cfgs?=?{'vgg11':?[64,?'M',?128,?'M',?256,?256,?'M',?512,?512,?'M',?512,?512,?'M'],'vgg13':?[64,?64,?'M',?128,?128,?'M',?256,?256,?'M',?512,?512,?'M',?512,?512,?'M'],'vgg16':?[64,?64,?'M',?128,?128,?'M',?256,?256,?256,?'M',?512,?512,?512,?'M',?512,?512,?512,?'M'],'vgg19':?[64,?64,?'M',?128,?128,?'M',?256,?256,?256,?256,?'M',?512,?512,?512,?512,?'M',?512,?512,?512,?512,?'M'], }def?vgg(model_name="vgg16",?**kwargs):assert?model_name?in?cfgs,?"Warning:?model?number?{}?not?in?cfgs?dict!".format(model_name)cfg?=?cfgs[model_name]model?=?VGG(make_features(cfg),?**kwargs)return?model首先,我們從VGG 6個結構中選擇了A、B、D、E這四個來搭建模型,建立的cfg字典包含了這4個結構。例如對于vgg16,[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']表示了卷積層的結構。64表示conv3-64,'M'表示maxpool,128表示conv3-128,256表示conv3-256,512表示conv3-512。
選定好哪個VGG結構之后,將該列表傳入到函數make_features()中,構建VGG的卷積層,函數返回實例化模型。例如我們來構建vgg16的卷積層結構并打印看看:
cfg?=?cfgs['vgg16'] make_features(cfg)定義VGG類的時候,參數num_classes指的是類別的數量,由于我們這里的數據集只有貓和狗兩個類別,因此這里的全連接層的神經元個數做了微調。num_classes=2,輸出層也是兩個神經元,不是原來的1000個神經元。FC4096由原來的4096個神經元分別改為500、20個神經元。這里的改動大家注意一下,根據實際數據集的類別數量進行調整。整個網絡的其它結構跟論文中的完全一樣。?
函數initialize_weights()是對網絡參數進行初始化操作,這里我們默認選擇關閉初始化操作。?
函數forward()定義了VGG網絡的完整結構,這里注意最后的卷積層輸出的featureMap是N x 512 x 7 x 7,N表示batchsize,需要將其展開為一維向量,方便與全連接層連接。
3. 將定義好的網絡結構搭載到GPU/CPU,并定義優化器
#創建模型,部署gpu device?=?torch.device("cuda"?if?torch.cuda.is_available()?else?"cpu") model_name?=?"vgg16" model?=?vgg(model_name=model_name,?num_classes=2,?init_weights=True) model.to(device) #定義優化器 loss_function?=?nn.CrossEntropyLoss() optimizer?=?optim.Adam(model.parameters(),?lr=0.0001)4. 定義訓練過程
def?train_runner(model,?device,?trainloader,?loss_function,?optimizer,?epoch):#訓練模型,?啟用?BatchNormalization?和?Dropout,?將BatchNormalization和Dropout置為Truemodel.train()total?=?0correct?=0.0#enumerate迭代已加載的數據集,同時獲取數據和數據下標for?i,?data?in?enumerate(trainloader,?0):inputs,?labels?=?data#把模型部署到device上inputs,?labels?=?inputs.to(device),?labels.to(device)#初始化梯度optimizer.zero_grad()#保存訓練結果outputs?=?model(inputs)#計算損失和#loss?=?F.cross_entropy(outputs,?labels)loss?=?loss_function(outputs,?labels)#獲取最大概率的預測結果#dim=1表示返回每一行的最大值對應的列下標predict?=?outputs.argmax(dim=1)total?+=?labels.size(0)correct?+=?(predict?==?labels).sum().item()#反向傳播loss.backward()#更新參數optimizer.step()if?i?%?100?==?0:#loss.item()表示當前loss的數值print("Train?Epoch{}?\t?Loss:?{:.6f},?accuracy:?{:.6f}%".format(epoch,?loss.item(),?100*(correct/total)))Loss.append(loss.item())Accuracy.append(correct/total)return?loss.item(),?correct/total5. 定義測試過程
def?test_runner(model,?device,?testloader):#模型驗證,?必須要寫,?否則只要有輸入數據,?即使不訓練,?它也會改變權值#因為調用eval()將不啟用?BatchNormalization?和?Dropout,?BatchNormalization和Dropout置為Falsemodel.eval()#統計模型正確率,?設置初始值correct?=?0.0test_loss?=?0.0total?=?0#torch.no_grad將不會計算梯度,?也不會進行反向傳播with?torch.no_grad():for?data,?label?in?testloader:data,?label?=?data.to(device),?label.to(device)output?=?model(data)test_loss?+=?F.cross_entropy(output,?label).item()predict?=?output.argmax(dim=1)#計算正確數量total?+=?label.size(0)correct?+=?(predict?==?label).sum().item()#計算損失值print("test_avarage_loss:?{:.6f},?accuracy:?{:.6f}%".format(test_loss/total,?100*(correct/total)))6. 運行
#調用 epoch?=?20 Loss?=?[] Accuracy?=?[] for?epoch?in?range(1,?epoch+1):print("start_time",time.strftime('%Y-%m-%d?%H:%M:%S',time.localtime(time.time())))loss,?acc?=?train_runner(model,?device,?trainloader,?loss_function,?optimizer,?epoch)Loss.append(loss)Accuracy.append(acc)test_runner(model,?device,?testloader)print("end_time:?",time.strftime('%Y-%m-%d?%H:%M:%S',time.localtime(time.time())),'\n')print('Finished?Training') plt.subplot(2,1,1) plt.plot(Loss) plt.title('Loss') plt.show() plt.subplot(2,1,2) plt.plot(Accuracy) plt.title('Accuracy') plt.show()經歷 20 次 epoch 的 loss 和 accuracy 曲線如下:
經過20個epoch的訓練之后,accuracy達到了94.68%。
注意,由于 VGGNet網絡比較大,用CPU會跑得很慢甚至直接卡頓,建議使用GPU訓練。
7. 保存模型
print(model) torch.save(model,?'./models/vgg-catvsdog.pth')?#保存模型VGGNet 的模型會打印出來,并將模型模型命令為 vgg-catvsdog.pth 保存在固定目錄下。
8. 模型測試
下面使用一張貓狗大戰測試集的圖片進行模型的測試。
from?PIL?import?Image import?numpy?as?npif?__name__?==?'__main__':device?=?torch.device('cuda'?if?torch.cuda.is_available()?else?'cpu')model?=?torch.load('./models/vgg-catvsdog.pth')?#加載模型model?=?model.to(device)model.eval()????#把模型轉為test模式#讀取要預測的圖片#?讀取要預測的圖片img?=?Image.open("./images/test_dog.jpg")?#?讀取圖像#img.show()plt.imshow(img)?#?顯示圖片plt.axis('off')?#?不顯示坐標軸plt.show()#?導入圖片,圖片擴展后為[1,1,32,32]trans?=?transforms.Compose([transforms.Resize((227,227)),transforms.ToTensor(),transforms.Normalize((0.5,?0.5,?0.5),?(0.5,?0.5,?0.5))#transforms.Normalize(mean?=?[0.485,?0.456,?0.406],std?=?[0.229,?0.224,?0.225])])img?=?trans(img)img?=?img.to(device)img?=?img.unsqueeze(0)??#圖片擴展多一維,因為輸入到保存的模型中是4維的[batch_size,通道,長,寬],而普通圖片只有三維,[通道,長,寬]#?預測?#?預測?classes?=?('cat',?'dog')output?=?model(img)prob?=?F.softmax(output,dim=1)?#prob是2個分類的概率print("概率:",prob)value,?predicted?=?torch.max(output.data,?1)predict?=?output.argmax(dim=1)pred_class?=?classes[predicted.item()]print("預測類別:",pred_class)輸出:
概率: tensor([[7.6922e-08, 1.0000e+00]], device='cuda:0', grad_fn=<SoftmaxBackward>)?
預測類別: dog
模型預測結果正確!
好了,以上就是使用 PyTorch?復現 VGGNet 網絡的核心代碼。建議大家根據文章內容完整碼一下代碼,可以根據實際情況使用自己的數據集,并對網絡結構進行微調。
完整代碼我已經放在了?GitHub 上,地址:
https://github.com/RedstoneWill/CNN_PyTorch_Beginner/blob/main/VGGNet/VGGNet.ipynb
手撕 CNN 系列:
手撕 CNN 經典網絡之 LeNet-5(理論篇)
手撕 CNN 經典網絡之 LeNet-5(MNIST 實戰篇)
手撕 CNN 經典網絡之 LeNet-5(CIFAR10 實戰篇)
手撕 CNN 經典網絡之 LeNet-5(自定義實戰篇)
手撕 CNN 經典網絡之 AlexNet(理論篇)
手撕 CNN 經典網絡之 AlexNet(PyTorch 實戰篇)
手撕 CNN 經典網絡之 VGGNet(理論篇)
如果覺得這篇文章有用的話,麻煩點個在看或轉發朋友圈!
推薦閱讀
(點擊標題可跳轉閱讀)
干貨 | 公眾號歷史文章精選
我的深度學習入門路線
我的機器學習入門路線圖
重磅!
AI有道年度技術文章電子版PDF來啦!
掃描下方二維碼,添加?AI有道小助手微信,可申請入群,并獲得2020完整技術文章合集PDF(一定要備注:入群?+ 地點 + 學校/公司。例如:入群+上海+復旦。?
長按掃碼,申請入群
(添加人數較多,請耐心等待)
感謝你的分享,點贊,在看三連??
總結
以上是生活随笔為你收集整理的手撕 CNN 经典网络之 VGGNet(PyTorch实战篇)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 在 CTreeCtrl 中枚举系统中的所
- 下一篇: 对于我这个软妹子来说,为什么python