建議大家可以實踐下,代碼都很詳細,有不清楚的地方評論區見~
1、前言
ResNet(Residual Neural Network)由微軟研究院的Kaiming He等四名華人提出,通過使用ResNet Unit成功訓練出了152層的神經網絡,并在ILSVRC2015比賽中取得冠軍,在top5上的錯誤率為3.57%,同時參數量比VGGNet低,效果非常突出。ResNet的結構可以極快的加速神經網絡的訓練,模型的準確率也有比較大的提升。同時ResNet的推廣性非常好,甚至可以直接用到InceptionNet網絡中。
下圖是ResNet34層模型的結構簡圖。
2、ResNet詳解
在ResNet網絡中有如下幾個亮點:
在ResNet網絡提出之前,傳統的卷積神經網絡都是通過將一系列卷積層與下采樣層進行堆疊得到的。但是當堆疊到一定網絡深度時,就會出現兩個問題。
梯度消失或梯度爆炸。
退化問題(degradation problem)。
在ResNet論文中說通過數據的預處理以及在網絡中使用BN(Batch Normalization)層能夠解決梯度消失或者梯度爆炸問題。如果不了解BN層可參考這個鏈接。但是對于退化問題(隨著網絡層數的加深,效果還會變差,如下圖所示)并沒有很好的解決辦法。所以ResNet論文提出了residual結構(殘差結構)來減輕退化問題。下圖是使用residual結構的卷積網絡,可以看到隨著網絡的不斷加深,效果并沒有變差,反而變的更好了。殘差結構(residual)
殘差指的是什么?其中ResNet提出了兩種mapping:一種是identity mapping,指的就是下圖中”彎彎的曲線”,另一種residual mapping,指的就是除了”彎彎的曲線“那部分,所以最后的輸出是 y=F(x)+x
顧名思義,就是指本身,也就是公式中的x,而residual mapping指的是“差”,也就是y?x,所以殘差指的就是F(x)部分。
下圖是論文中給出的兩種殘差結構。左邊的殘差結構是針對層數較少網絡,例如ResNet18層和ResNet34層網絡。右邊是針對網絡層數較多的網絡,例如ResNet101,ResNet152等。為什么深層網絡要使用右側的殘差結構呢。因為,右側的殘差結構能夠減少網絡參數與運算量。同樣輸入一個channel為256的特征矩陣,如果使用左側的殘差結構需要大約1170648個參數,但如果使用右側的殘差結構只需要69632個參數。明顯搭建深層網絡時,使用右側的殘差結構更合適。我們先對左側的殘差結構(針對ResNet18/34)進行一個分析。
如下圖所示,該殘差結構的主分支是由兩層3x3的卷積層組成,而殘差結構右側的連接線是shortcut分支也稱捷徑分支(注意為了讓主分支上的輸出矩陣能夠與我們捷徑分支上的輸出矩陣進行相加,必須保證這兩個輸出特征矩陣有相同的shape)。如果剛剛仔細觀察了ResNet34網絡結構圖的同學,應該能夠發現圖中會有一些虛線的殘差結構。在原論文中作者只是簡單說了這些虛線殘差結構有降維的作用,并在捷徑分支上通過1x1的卷積核進行降維處理。而下圖右側給出了詳細的虛線殘差結構,注意下每個卷積層的步距stride,以及捷徑分支上的卷積核的個數(與主分支上的卷積核個數相同)。接著我們再來分析下針對ResNet50/101/152的殘差結構,如下圖所示。在該殘差結構當中,主分支使用了三個卷積層,第一個是1x1的卷積層用來壓縮channel維度,第二個是3x3的卷積層,第三個是1x1的卷積層用來還原channel維度(注意主分支上第一層卷積層和第二次卷積層所使用的卷積核個數是相同的,第三次是第一層的4倍)。該殘差結構所對應的虛線殘差結構如下圖右側所示,同樣在捷徑分支上有一層1x1的卷積層,它的卷積核個數與主分支上的第三層卷積層卷積核個數相同,注意每個卷積層的步距。為什么殘差學習相對更容易,從直觀上看殘差學習需要學習的內容少,因為殘差一般會比較小,學習難度小點。不過我們可以從數學的角度來分析這個問題,首先殘差單元可以表示為:
其中 XL和 XL+1分別表示的是第L個殘差單元的輸入和輸出,注意每個殘差單元一般包含多層結構。F是殘差函數,表示學習到的殘差,而 h(XL)=XL表示恒等映射, F是ReLU激活函數。基于上式,我們求得從淺層 l到深層 L 的學習特征為:式子的第一個因子表示的損失函數到達L的梯度,小括號中的1表明短路機制可以無損地傳播梯度,而另外一項殘差梯度則需要經過帶有weights的層,梯度不是直接傳遞過來的。殘差梯度不會那么巧全為-1,而且就算其比較小,有1的存在也不會導致梯度消失。所以殘差學習會更容易。要注意上面的推導并不是嚴格的證明。
下面這幅圖是原論文給出的不同深度的ResNet網絡結構配置,注意表中的殘差結構給出了主分支上卷積核的大小與卷積核個數,表中的xN表示將該殘差結構重復N次。那到底哪些殘差結構是虛線殘差結構呢。對于我們ResNet18/34/50/101/152,表中conv3_x, conv4_x, conv5_x所對應的一系列殘差結構的第一層殘差結構都是虛線殘差結構。因為這一系列殘差結構的第一層都有調整輸入特征矩陣shape的使命(將特征矩陣的高和寬縮減為原來的一半,將深度channel調整成下一層殘差結構所需要的channel)。為了方便理解,下面給出了ResNet34的網絡結構圖,圖中簡單標注了一些信息。
對于我們ResNet50/101/152,其實在conv2_x所對應的一系列殘差結構的第一層也是虛線殘差結構。因為它需要調整輸入特征矩陣的channel,根據表格可知通過3x3的max pool之后輸出的特征矩陣shape應該是[56, 56, 64],但我們conv2_x所對應的一系列殘差結構中的實線殘差結構它們期望的輸入特征矩陣shape是[56, 56, 256](因為這樣才能保證輸入輸出特征矩陣shape相同,才能將捷徑分支的輸出與主分支的輸出進行相加)。所以第一層殘差結構需要將shape從[56, 56, 64] --> [56, 56, 256]。注意,這里只調整channel維度,高和寬不變(而conv3_x, conv4_x, conv5_x所對應的一系列殘差結構的第一層虛線殘差結構不僅要調整channel還要將高和寬縮減為原來的一半)。
代碼
注:
本次訓練集下載在AlexNet博客有詳細解說:https://blog.csdn.net/weixin_44023658/article/details/105798326
使用遷移學習方法實現收錄在我的這篇blog中:遷移學習 TransferLearning—通俗易懂地介紹(pytorch實例)
#model.pyimport?torch.nn?as?nn
import?torch#18/34
class?BasicBlock(nn.Module):expansion?=?1?#每一個conv的卷積核個數的倍數def?__init__(self,?in_channel,?out_channel,?stride=1,?downsample=None):#downsample對應虛線殘差結構super(BasicBlock,?self).__init__()self.conv1?=?nn.Conv2d(in_channels=in_channel,?out_channels=out_channel,kernel_size=3,?stride=stride,?padding=1,?bias=False)self.bn1?=?nn.BatchNorm2d(out_channel)#BN處理self.relu?=?nn.ReLU()self.conv2?=?nn.Conv2d(in_channels=out_channel,?out_channels=out_channel,kernel_size=3,?stride=1,?padding=1,?bias=False)self.bn2?=?nn.BatchNorm2d(out_channel)self.downsample?=?downsampledef?forward(self,?x):identity?=?x?#捷徑上的輸出值if?self.downsample?is?not?None:identity?=?self.downsample(x)out?=?self.conv1(x)out?=?self.bn1(out)out?=?self.relu(out)out?=?self.conv2(out)out?=?self.bn2(out)out?+=?identityout?=?self.relu(out)return?out#50,101,152
class?Bottleneck(nn.Module):expansion?=?4#4倍def?__init__(self,?in_channel,?out_channel,?stride=1,?downsample=None):super(Bottleneck,?self).__init__()self.conv1?=?nn.Conv2d(in_channels=in_channel,?out_channels=out_channel,kernel_size=1,?stride=1,?bias=False)??#?squeeze?channelsself.bn1?=?nn.BatchNorm2d(out_channel)self.relu?=?nn.ReLU(inplace=True)#?-----------------------------------------self.conv2?=?nn.Conv2d(in_channels=out_channel,?out_channels=out_channel,kernel_size=3,?stride=stride,?bias=False,?padding=1)self.bn2?=?nn.BatchNorm2d(out_channel)self.relu?=?nn.ReLU(inplace=True)#?-----------------------------------------self.conv3?=?nn.Conv2d(in_channels=out_channel,?out_channels=out_channel*self.expansion,#輸出*4kernel_size=1,?stride=1,?bias=False)??#?unsqueeze?channelsself.bn3?=?nn.BatchNorm2d(out_channel*self.expansion)self.relu?=?nn.ReLU(inplace=True)self.downsample?=?downsampledef?forward(self,?x):identity?=?xif?self.downsample?is?not?None:identity?=?self.downsample(x)out?=?self.conv1(x)out?=?self.bn1(out)out?=?self.relu(out)out?=?self.conv2(out)out?=?self.bn2(out)out?=?self.relu(out)out?=?self.conv3(out)out?=?self.bn3(out)out?+=?identityout?=?self.relu(out)return?outclass?ResNet(nn.Module):def?__init__(self,?block,?blocks_num,?num_classes=1000,?include_top=True):#block殘差結構?include_top為了之后搭建更加復雜的網絡super(ResNet,?self).__init__()self.include_top?=?include_topself.in_channel?=?64self.conv1?=?nn.Conv2d(3,?self.in_channel,?kernel_size=7,?stride=2,padding=3,?bias=False)self.bn1?=?nn.BatchNorm2d(self.in_channel)self.relu?=?nn.ReLU(inplace=True)self.maxpool?=?nn.MaxPool2d(kernel_size=3,?stride=2,?padding=1)self.layer1?=?self._make_layer(block,?64,?blocks_num[0])self.layer2?=?self._make_layer(block,?128,?blocks_num[1],?stride=2)self.layer3?=?self._make_layer(block,?256,?blocks_num[2],?stride=2)self.layer4?=?self._make_layer(block,?512,?blocks_num[3],?stride=2)if?self.include_top:self.avgpool?=?nn.AdaptiveAvgPool2d((1,?1))??#?output?size?=?(1,?1)自適應self.fc?=?nn.Linear(512?*?block.expansion,?num_classes)for?m?in?self.modules():if?isinstance(m,?nn.Conv2d):nn.init.kaiming_normal_(m.weight,?mode='fan_out',?nonlinearity='relu')def?_make_layer(self,?block,?channel,?block_num,?stride=1):downsample?=?Noneif?stride?!=?1?or?self.in_channel?!=?channel?*?block.expansion:downsample?=?nn.Sequential(nn.Conv2d(self.in_channel,?channel?*?block.expansion,?kernel_size=1,?stride=stride,?bias=False),nn.BatchNorm2d(channel?*?block.expansion))layers?=?[]layers.append(block(self.in_channel,?channel,?downsample=downsample,?stride=stride))self.in_channel?=?channel?*?block.expansionfor?_?in?range(1,?block_num):layers.append(block(self.in_channel,?channel))return?nn.Sequential(*layers)def?forward(self,?x):x?=?self.conv1(x)x?=?self.bn1(x)x?=?self.relu(x)x?=?self.maxpool(x)x?=?self.layer1(x)x?=?self.layer2(x)x?=?self.layer3(x)x?=?self.layer4(x)if?self.include_top:x?=?self.avgpool(x)x?=?torch.flatten(x,?1)x?=?self.fc(x)return?xdef?resnet34(num_classes=1000,?include_top=True):return?ResNet(BasicBlock,?[3,?4,?6,?3],?num_classes=num_classes,?include_top=include_top)def?resnet101(num_classes=1000,?include_top=True):return?ResNet(Bottleneck,?[3,?4,?23,?3],?num_classes=num_classes,?include_top=include_top)
#train.pyimport?torch
import?torch.nn?as?nn
from?torchvision?import?transforms,?datasets
import?json
import?matplotlib.pyplot?as?plt
import?os
import?torch.optim?as?optim
from?model?import?resnet34,?resnet101
import?torchvision.models.resnetdevice?=?torch.device("cuda:0"?if?torch.cuda.is_available()?else?"cpu")
print(device)data_transform?=?{"train":?transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485,?0.456,?0.406],?[0.229,?0.224,?0.225])]),#來自官網參數"val":?transforms.Compose([transforms.Resize(256),#將最小邊長縮放到256transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485,?0.456,?0.406],?[0.229,?0.224,?0.225])])}data_root?=?os.getcwd()
image_path?=?data_root?+?"/flower_data/"??#?flower?data?set?pathtrain_dataset?=?datasets.ImageFolder(root=image_path?+?"train",transform=data_transform["train"])
train_num?=?len(train_dataset)#?{'daisy':0,?'dandelion':1,?'roses':2,?'sunflower':3,?'tulips':4}
flower_list?=?train_dataset.class_to_idx
cla_dict?=?dict((val,?key)?for?key,?val?in?flower_list.items())
#?write?dict?into?json?file
json_str?=?json.dumps(cla_dict,?indent=4)
with?open('class_indices.json',?'w')?as?json_file:json_file.write(json_str)batch_size?=?16
train_loader?=?torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,?shuffle=True,num_workers=0)validate_dataset?=?datasets.ImageFolder(root=image_path?+?"/val",transform=data_transform["val"])
val_num?=?len(validate_dataset)
validate_loader?=?torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size,?shuffle=False,num_workers=0)
#net?=?resnet34()
net?=?resnet34(num_classes=5)
#?load?pretrain?weights#?model_weight_path?=?"./resnet34-pre.pth"
#?missing_keys,?unexpected_keys?=?net.load_state_dict(torch.load(model_weight_path),?strict=False)#載入模型參數#?for?param?in?net.parameters():
#?????param.requires_grad?=?False
#?change?fc?layer?structure#?inchannel?=?net.fc.in_features
#?net.fc?=?nn.Linear(inchannel,?5)net.to(device)loss_function?=?nn.CrossEntropyLoss()
optimizer?=?optim.Adam(net.parameters(),?lr=0.0001)best_acc?=?0.0
save_path?=?'./resNet34.pth'
for?epoch?in?range(3):#?trainnet.train()running_loss?=?0.0for?step,?data?in?enumerate(train_loader,?start=0):images,?labels?=?dataoptimizer.zero_grad()logits?=?net(images.to(device))loss?=?loss_function(logits,?labels.to(device))loss.backward()optimizer.step()#?print?statisticsrunning_loss?+=?loss.item()#?print?train?processrate?=?(step+1)/len(train_loader)a?=?"*"?*?int(rate?*?50)b?=?"."?*?int((1?-?rate)?*?50)print("\rtrain?loss:?{:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100),?a,?b,?loss),?end="")print()#?validatenet.eval()acc?=?0.0??#?accumulate?accurate?number?/?epochwith?torch.no_grad():for?val_data?in?validate_loader:val_images,?val_labels?=?val_dataoutputs?=?net(val_images.to(device))??#?eval?model?only?have?last?output?layer#?loss?=?loss_function(outputs,?test_labels)predict_y?=?torch.max(outputs,?dim=1)[1]acc?+=?(predict_y?==?val_labels.to(device)).sum().item()val_accurate?=?acc?/?val_numif?val_accurate?>?best_acc:best_acc?=?val_accuratetorch.save(net.state_dict(),?save_path)print('[epoch?%d]?train_loss:?%.3f??test_accuracy:?%.3f'?%(epoch?+?1,?running_loss?/?step,?val_accurate))print('Finished?Training')
在這里插入圖片描述#predict.pyimport?torch
from?model?import?resnet34
from?PIL?import?Image
from?torchvision?import?transforms
import?matplotlib.pyplot?as?plt
import?jsondata_transform?=?transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485,?0.456,?0.406],?[0.229,?0.224,?0.225])])#?load?image
img?=?Image.open("./roses.jpg")
plt.imshow(img)
#?[N,?C,?H,?W]
img?=?data_transform(img)
#?expand?batch?dimension
img?=?torch.unsqueeze(img,?dim=0)#?read?class_indict
try:json_file?=?open('./class_indices.json',?'r')class_indict?=?json.load(json_file)
except?Exception?as?e:print(e)exit(-1)#?create?model
model?=?resnet34(num_classes=5)
#?load?model?weights
model_weight_path?=?"./resNet34.pth"
model.load_state_dict(torch.load(model_weight_path))
model.eval()
with?torch.no_grad():#?predict?classoutput?=?torch.squeeze(model(img))predict?=?torch.softmax(output,?dim=0)predict_cla?=?torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)],?predict[predict_cla].numpy())
plt.show()
往期精彩回顧適合初學者入門人工智能的路線及資料下載機器學習及深度學習筆記等資料打印機器學習在線手冊深度學習筆記專輯《統計學習方法》的代碼復現專輯
AI基礎下載機器學習的數學基礎專輯獲取一折本站知識星球優惠券,復制鏈接直接打開:https://t.zsxq.com/662nyZF本站qq群1003271085。加入微信群請掃碼進群(如果是博士或者準備讀博士請說明):
總結
以上是生活随笔為你收集整理的【深度学习】ResNet——CNN经典网络模型详解(pytorch实现)的全部內容,希望文章能夠幫你解決所遇到的問題。
如果覺得生活随笔網站內容還不錯,歡迎將生活随笔推薦給好友。