【小白学PyTorch】12.SENet详解及PyTorch实现
<<小白學(xué)PyTorch>>
小白學(xué)PyTorch | 11 MobileNet詳解及PyTorch實(shí)現(xiàn)
小白學(xué)PyTorch | 10 pytorch常見(jiàn)運(yùn)算詳解
小白學(xué)PyTorch | 9 tensor數(shù)據(jù)結(jié)構(gòu)與存儲(chǔ)結(jié)構(gòu)
小白學(xué)PyTorch | 8 實(shí)戰(zhàn)之MNIST小試牛刀
小白學(xué)PyTorch | 7 最新版本torchvision.transforms常用API翻譯與講解
小白學(xué)PyTorch | 6 模型的構(gòu)建訪問(wèn)遍歷存儲(chǔ)(附代碼)
小白學(xué)PyTorch | 5 torchvision預(yù)訓(xùn)練模型與數(shù)據(jù)集全覽
小白學(xué)PyTorch | 4 構(gòu)建模型三要素與權(quán)重初始化
小白學(xué)PyTorch | 3 淺談Dataset和Dataloader
小白學(xué)PyTorch | 2 淺談?dòng)?xùn)練集驗(yàn)證集和測(cè)試集
小白學(xué)PyTorch | 1 搭建一個(gè)超簡(jiǎn)單的網(wǎng)絡(luò)
小白學(xué)PyTorch | 動(dòng)態(tài)圖與靜態(tài)圖的淺顯理解
參考目錄:
1 網(wǎng)絡(luò)結(jié)構(gòu)
2 參數(shù)量分析
3 PyTorch實(shí)現(xiàn)與解析
上一節(jié)課講解了MobileNet的一個(gè)DSC深度可分離卷積的概念,希望大家可以在實(shí)際的任務(wù)中使用這種方法,現(xiàn)在再來(lái)介紹EfficientNet的另外一個(gè)基礎(chǔ)知識(shí),Squeeze-and-Excitation Networks壓縮-激活網(wǎng)絡(luò)
1 網(wǎng)絡(luò)結(jié)構(gòu)
可以看出來(lái),左邊的圖是一個(gè)典型的Resnet的結(jié)構(gòu),Resnet這個(gè)殘差結(jié)構(gòu)特征圖求和而不是通道拼接,這一點(diǎn)可以注意一下
這個(gè)SENet結(jié)構(gòu)式融合在殘差網(wǎng)絡(luò)上的,我來(lái)分析一下上圖右邊的結(jié)構(gòu):
輸出特征圖假設(shè)shape是的;
一般的Resnet就是這個(gè)特征圖經(jīng)過(guò)殘差網(wǎng)絡(luò)的基本組塊,得到了輸出特征圖,然后輸入特征圖和輸入特征圖通過(guò)殘差結(jié)構(gòu)連在一起(通過(guò)加和的方式連在一起);
SE模塊就是輸出特征圖先經(jīng)過(guò)一個(gè)全局池化層,shape從變成了,這個(gè)就變成了一個(gè)全連接層的輸入啦
壓縮Squeeze:先放到第一個(gè)全連接層里面,輸入個(gè)元素,輸出,r是一個(gè)事先設(shè)置的參數(shù);
激活Excitation:在接上一個(gè)全連接層,輸入是個(gè)神經(jīng)元,輸出是個(gè)元素,實(shí)現(xiàn)激活的過(guò)程;
現(xiàn)在我們有了一個(gè)個(gè)元素的經(jīng)過(guò)了兩層全連接層的輸出,這個(gè)C個(gè)元素,剛好表示的是原來(lái)輸出特征圖中C個(gè)通道的一個(gè)權(quán)重值,所以我們讓C個(gè)通道上的像素值分別乘上全連接的C個(gè)輸出,這個(gè)步驟在圖中稱(chēng)為Scale。而這個(gè)調(diào)整過(guò)特征圖每一個(gè)通道權(quán)重的特征圖是SE-Resnet的輸出特征圖,之后再考慮殘差接連的步驟。
在原文論文中還有另外一個(gè)結(jié)構(gòu)圖,供大家參考:
2 參數(shù)量分析
每一個(gè)卷積層都增加了額外的兩個(gè)全連接層,不夠好在全連接層的參數(shù)非常小,所以直觀來(lái)看應(yīng)該整體不會(huì)增加很多的計(jì)算量。Resnet50的參數(shù)量為25M的大小,增加了SE模塊,增加了2.5M的參數(shù)量,所以大概增加了10%左右,而且這2.5M的參數(shù)主要集中在final stage的se模塊,因?yàn)樵谧詈笠粋€(gè)卷積模塊中,特征圖擁有最大的通道數(shù),所以這個(gè)final stage的參數(shù)量占據(jù)了增加的2.5M參數(shù)的96%。
這里放一個(gè)幾個(gè)網(wǎng)絡(luò)結(jié)構(gòu)的對(duì)比:
3 PyTorch實(shí)現(xiàn)與解析
先上完整版的代碼,大家可以復(fù)制本地IDE跑一跑,如果代碼有什么問(wèn)題可以聯(lián)系我:
import?torch import?torch.nn?as?nn import?torch.nn.functional?as?Fclass?PreActBlock(nn.Module):def?__init__(self,?in_planes,?planes,?stride=1):super(PreActBlock,?self).__init__()self.bn1?=?nn.BatchNorm2d(in_planes)self.conv1?=?nn.Conv2d(in_planes,?planes,?kernel_size=3,?stride=stride,?padding=1,?bias=False)self.bn2?=?nn.BatchNorm2d(planes)self.conv2?=?nn.Conv2d(planes,?planes,?kernel_size=3,?stride=1,?padding=1,?bias=False)if?stride?!=?1?or?in_planes?!=?planes:self.shortcut?=?nn.Sequential(nn.Conv2d(in_planes,?planes,?kernel_size=1,?stride=stride,?bias=False))#?SE?layersself.fc1?=?nn.Conv2d(planes,?planes//16,?kernel_size=1)self.fc2?=?nn.Conv2d(planes//16,?planes,?kernel_size=1)def?forward(self,?x):out?=?F.relu(self.bn1(x))shortcut?=?self.shortcut(out)?if?hasattr(self,?'shortcut')?else?xout?=?self.conv1(out)out?=?self.conv2(F.relu(self.bn2(out)))#?Squeezew?=?F.avg_pool2d(out,?out.size(2))w?=?F.relu(self.fc1(w))w?=?F.sigmoid(self.fc2(w))#?Excitationout?=?out?*?wout?+=?shortcutreturn?outclass?SENet(nn.Module):def?__init__(self,?block,?num_blocks,?num_classes=10):super(SENet,?self).__init__()self.in_planes?=?64self.conv1?=?nn.Conv2d(3,?64,?kernel_size=3,?stride=1,?padding=1,?bias=False)self.bn1?=?nn.BatchNorm2d(64)self.layer1?=?self._make_layer(block,??64,?num_blocks[0],?stride=1)self.layer2?=?self._make_layer(block,?128,?num_blocks[1],?stride=2)self.layer3?=?self._make_layer(block,?256,?num_blocks[2],?stride=2)self.layer4?=?self._make_layer(block,?512,?num_blocks[3],?stride=2)self.linear?=?nn.Linear(512,?num_classes)def?_make_layer(self,?block,?planes,?num_blocks,?stride):strides?=?[stride]?+?[1]*(num_blocks-1)layers?=?[]for?stride?in?strides:layers.append(block(self.in_planes,?planes,?stride))self.in_planes?=?planesreturn?nn.Sequential(*layers)def?forward(self,?x):out?=?F.relu(self.bn1(self.conv1(x)))out?=?self.layer1(out)out?=?self.layer2(out)out?=?self.layer3(out)out?=?self.layer4(out)out?=?F.avg_pool2d(out,?4)out?=?out.view(out.size(0),?-1)out?=?self.linear(out)return?outdef?SENet18():return?SENet(PreActBlock,?[2,2,2,2])net?=?SENet18() y?=?net(torch.randn(1,3,32,32)) print(y.size()) print(net)輸出和注解我都整理了一下:
- END -往期精彩回顧適合初學(xué)者入門(mén)人工智能的路線及資料下載機(jī)器學(xué)習(xí)及深度學(xué)習(xí)筆記等資料打印機(jī)器學(xué)習(xí)在線手冊(cè)深度學(xué)習(xí)筆記專(zhuān)輯《統(tǒng)計(jì)學(xué)習(xí)方法》的代碼復(fù)現(xiàn)專(zhuān)輯 AI基礎(chǔ)下載機(jī)器學(xué)習(xí)的數(shù)學(xué)基礎(chǔ)專(zhuān)輯獲取一折本站知識(shí)星球優(yōu)惠券,復(fù)制鏈接直接打開(kāi):https://t.zsxq.com/662nyZF本站qq群704220115。加入微信群請(qǐng)掃碼進(jìn)群(如果是博士或者準(zhǔn)備讀博士請(qǐng)說(shuō)明):總結(jié)
以上是生活随笔為你收集整理的【小白学PyTorch】12.SENet详解及PyTorch实现的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 【论文解读】基于图卷积的价格感知推荐
- 下一篇: 【效率】这个神器可以摆脱变量命名纠结!