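Preface
While reading a recent paper, I noticed that its code was remarkably concise: it expressed the paper's idea efficiently using only the unfold and fold methods. So I am writing down what I learned about the unfold and fold methods.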
1. Method details
torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)
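Parameters:
- kernel_size (int or tuple) – size of the sliding window
- stride (int or tuple, optional) – step of the sliding window along the spatial dimensions. Default: 1
- padding (int or tuple, optional) – implicit zero padding to be added on both sides of input. Default: 0
- dilation (int or tuple, optional) – dilation rate, as in dilated (atrous) convolution. Default: 1

Meaning: Unfold extracts all the values covered by the sliding window at each position. For example, take the following input: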
[[ 0.4009,  0.6350, -0.5197,  0.8148, -0.7235],
 [-1.2102,  0.4621, -0.3421, -0.9261, -2.8376],
 [-1.5553,  0.1713,  0.6820, -2.0880, -0.0204],
 [ 1.1419, -0.4881, -0.9510, -0.0367, -0.8108],
 [ 0.1459, -0.4568,  1.0039, -1.2385, -1.4467]]
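Sliding a window with kernel size = 3 over it first records the values below. Each column corresponds to one window position and each row to one offset inside the kernel; for this 5×5 input the 9×9 result happens to be symmetric, so each row also reads as one flattened window: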
[[ 0.4009,  0.6350, -0.5197, -1.2102,  0.4621, -0.3421, -1.5553,  0.1713,  0.6820],
 [ 0.6350, -0.5197,  0.8148,  0.4621, -0.3421, -0.9261,  0.1713,  0.6820, -2.0880],
 [-0.5197,  0.8148, -0.7235, -0.3421, -0.9261, -2.8376,  0.6820, -2.0880, -0.0204],
 [-1.2102,  0.4621, -0.3421, -1.5553,  0.1713,  0.6820,  1.1419, -0.4881, -0.9510],
 [ 0.4621, -0.3421, -0.9261,  0.1713,  0.6820, -2.0880, -0.4881, -0.9510, -0.0367],
 [-0.3421, -0.9261, -2.8376,  0.6820, -2.0880, -0.0204, -0.9510, -0.0367, -0.8108],
 [-1.5553,  0.1713,  0.6820,  1.1419, -0.4881, -0.9510,  0.1459, -0.4568,  1.0039],
 [ 0.1713,  0.6820, -2.0880, -0.4881, -0.9510, -0.0367, -0.4568,  1.0039, -1.2385],
 [ 0.6820, -2.0880, -0.0204, -0.9510, -0.0367, -0.8108,  1.0039, -1.2385, -1.4467]]
Note: the input to the unfold method must be 4-dimensional, i.e. (N, C, H, W).
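The sliding-window extraction described above can be reproduced in a few lines of plain Python, which may make the layout of the result easier to see. This is only a sketch under simplifying assumptions (one channel, no padding, no dilation, square kernel); the function name `unfold_2d` is made up for illustration and is not part of PyTorch.

```python
def unfold_2d(image, kernel_size, stride=1):
    """Sketch of Unfold for a single channel, without padding or dilation.

    Row r of the result corresponds to the kernel offset (ki, kj) and
    column l to the l-th window position in row-major order, which mirrors
    the (C * k * k, L) layout of nn.Unfold for C = 1.
    """
    height, width = len(image), len(image[0])
    out_h = (height - kernel_size) // stride + 1  # window positions vertically
    out_w = (width - kernel_size) // stride + 1   # window positions horizontally
    rows = []
    for ki in range(kernel_size):
        for kj in range(kernel_size):
            # Collect the element at offset (ki, kj) from every window.
            rows.append([
                image[oi * stride + ki][oj * stride + kj]
                for oi in range(out_h)
                for oj in range(out_w)
            ])
    return rows


# The 5x5 input from the example above.
x = [
    [0.4009, 0.6350, -0.5197, 0.8148, -0.7235],
    [-1.2102, 0.4621, -0.3421, -0.9261, -2.8376],
    [-1.5553, 0.1713, 0.6820, -2.0880, -0.0204],
    [1.1419, -0.4881, -0.9510, -0.0367, -0.8108],
    [0.1459, -0.4568, 1.0039, -1.2385, -1.4467],
]

cols = unfold_2d(x, kernel_size=3)
print(len(cols), len(cols[0]))  # 9 9
print(cols[0])  # first row of the matrix shown above
```

Reading the result column by column gives the individual 3×3 windows; reading it row by row gives one kernel offset across all windows.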
2. How to compute the output size
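import torch
import torch.nn as nn

if __name__ == '__main__':
    # Random batch: 2 samples, 3 channels, 5x5 spatial size
    x = torch.randn(2, 3, 5, 5)
    print(x)
    # 2x2 sliding window; dilation, padding and stride keep their defaults
    unfold = nn.Unfold(2)
    y = unfold(x)
    print(y.size())
    print(y)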
torch.Size([2, 12, 16])
Next, let's work through step by step how this result is computed.
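First, the input must be 4-dimensional, i.e. (B, C, H, W), where B is the batch size, C is the number of channels, H is the height of the feature map, and W is its width. Suppose the size after Unfold is (B, h, w). We first compute h (the height of the output) as:

$$h = C \times \prod_{d} \text{kernel\_size}[d]$$

For example, with 3 input channels (the most common channel count for images, which is why we use it here) and a kernel size of (2, 2), the output height after Unfold is 3 × 2 × 2 = 12.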
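Once h is known, we compute w as:

$$w = \prod_{d} \left\lfloor \frac{\text{spatial\_size}[d] + 2 \times \text{padding}[d] - \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1 \right\rfloor$$

Here d runs over all spatial dimensions; for example, with spatial dimensions (H, W) there are 2 of them. Let's compute w with an example.

For example, if the input H and W are both 5 and the kernel size is 2 (with the default padding 0, dilation 1 and stride 1), each spatial dimension contributes ⌊(5 + 0 − 1 × (2 − 1) − 1) / 1 + 1⌋ = 4, so w = 4 × 4 = 16 and the final output size is [2, 12, 16].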
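The two steps above can be put together in a small helper that predicts the output size without running PyTorch. This is a minimal sketch: `unfold_output_shape` is a hypothetical name, not a PyTorch function, and it simply applies the two size formulas from this section.

```python
import math


def unfold_output_shape(input_shape, kernel_size, dilation=1, padding=0, stride=1):
    """Predict the (B, h, w) size of the Unfold output for input (B, C, H, W)."""
    batch, channels, *spatial = input_shape
    ndim = len(spatial)

    def as_tuple(value):
        # An int applies to every spatial dimension, like in nn.Unfold.
        return tuple(value) if isinstance(value, (tuple, list)) else (value,) * ndim

    k, d, p, s = (as_tuple(v) for v in (kernel_size, dilation, padding, stride))

    # h = C * product of the kernel sizes
    h = channels * math.prod(k)

    # w = product over spatial dims of the number of window positions
    w = 1
    for size, ki, di, pi, si in zip(spatial, k, d, p, s):
        w *= (size + 2 * pi - di * (ki - 1) - 1) // si + 1
    return (batch, h, w)


print(unfold_output_shape((2, 3, 5, 5), kernel_size=2))  # (2, 12, 16)
print(unfold_output_shape((1, 3, 5, 5), kernel_size=3))  # (1, 27, 9)
```

The first call matches the torch.Size([2, 12, 16]) printed earlier in this section.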
3. Example
import torch
import torch.nn as nn

if __name__ == '__main__':
    # One sample, 3 channels, 5x5 spatial size
    x = torch.randn(1, 3, 5, 5)
    print(x)
    # 3x3 sliding window with default dilation, padding and stride
    unfold = nn.Unfold(kernel_size=3)
    output = unfold(x)
    print(output, output.size())
tensor([[[[ 0.4009,  0.6350, -0.5197,  0.8148, -0.7235],
          [-1.2102,  0.4621, -0.3421, -0.9261, -2.8376],
          [-1.5553,  0.1713,  0.6820, -2.0880, -0.0204],
          [ 1.1419, -0.4881, -0.9510, -0.0367, -0.8108],
          [ 0.1459, -0.4568,  1.0039, -1.2385, -1.4467]],

         [[-0.9973, -0.7601, -0.2161,  1.2120, -0.3036],
          [-0.7279,  0.0833, -0.8886, -0.9168,  0.7503],
          [-0.6748,  0.7064,  0.6903, -1.0447,  0.8688],
          [-0.5230, -1.2308, -0.3932,  1.2521, -0.2523],
          [-0.3930,  0.6452,  0.1690,  0.3744,  0.2015]],

         [[ 0.6403,  1.3915, -1.9529,  0.2899, -0.8897],
          [-0.1720,  1.0843, -1.0177, -1.7480, -0.5217],
          [-0.9648, -0.0867, -0.2926,  0.3010,  0.3192],
          [ 0.1181, -0.2218,  0.0766,  0.5914, -0.8932],
          [-0.4508, -0.3964,  1.1163,  0.6776, -0.8948]]]])
tensor([[[ 0.4009,  0.6350, -0.5197, -1.2102,  0.4621, -0.3421, -1.5553,  0.1713,  0.6820],
         [ 0.6350, -0.5197,  0.8148,  0.4621, -0.3421, -0.9261,  0.1713,  0.6820, -2.0880],
         [-0.5197,  0.8148, -0.7235, -0.3421, -0.9261, -2.8376,  0.6820, -2.0880, -0.0204],
         [-1.2102,  0.4621, -0.3421, -1.5553,  0.1713,  0.6820,  1.1419, -0.4881, -0.9510],
         [ 0.4621, -0.3421, -0.9261,  0.1713,  0.6820, -2.0880, -0.4881, -0.9510, -0.0367],
         [-0.3421, -0.9261, -2.8376,  0.6820, -2.0880, -0.0204, -0.9510, -0.0367, -0.8108],
         [-1.5553,  0.1713,  0.6820,  1.1419, -0.4881, -0.9510,  0.1459, -0.4568,  1.0039],
         [ 0.1713,  0.6820, -2.0880, -0.4881, -0.9510, -0.0367, -0.4568,  1.0039, -1.2385],
         [ 0.6820, -2.0880, -0.0204, -0.9510, -0.0367, -0.8108,  1.0039, -1.2385, -1.4467],
         [-0.9973, -0.7601, -0.2161, -0.7279,  0.0833, -0.8886, -0.6748,  0.7064,  0.6903],
         [-0.7601, -0.2161,  1.2120,  0.0833, -0.8886, -0.9168,  0.7064,  0.6903, -1.0447],
         [-0.2161,  1.2120, -0.3036, -0.8886, -0.9168,  0.7503,  0.6903, -1.0447,  0.8688],
         [-0.7279,  0.0833, -0.8886, -0.6748,  0.7064,  0.6903, -0.5230, -1.2308, -0.3932],
         [ 0.0833, -0.8886, -0.9168,  0.7064,  0.6903, -1.0447, -1.2308, -0.3932,  1.2521],
         [-0.8886, -0.9168,  0.7503,  0.6903, -1.0447,  0.8688, -0.3932,  1.2521, -0.2523],
         [-0.6748,  0.7064,  0.6903, -0.5230, -1.2308, -0.3932, -0.3930,  0.6452,  0.1690],
         [ 0.7064,  0.6903, -1.0447, -1.2308, -0.3932,  1.2521,  0.6452,  0.1690,  0.3744],
         [ 0.6903, -1.0447,  0.8688, -0.3932,  1.2521, -0.2523,  0.1690,  0.3744,  0.2015],
         [ 0.6403,  1.3915, -1.9529, -0.1720,  1.0843, -1.0177, -0.9648, -0.0867, -0.2926],
         [ 1.3915, -1.9529,  0.2899,  1.0843, -1.0177, -1.7480, -0.0867, -0.2926,  0.3010],
         [-1.9529,  0.2899, -0.8897, -1.0177, -1.7480, -0.5217, -0.2926,  0.3010,  0.3192],
         [-0.1720,  1.0843, -1.0177, -0.9648, -0.0867, -0.2926,  0.1181, -0.2218,  0.0766],
         [ 1.0843, -1.0177, -1.7480, -0.0867, -0.2926,  0.3010, -0.2218,  0.0766,  0.5914],
         [-1.0177, -1.7480, -0.5217, -0.2926,  0.3010,  0.3192,  0.0766,  0.5914, -0.8932],
         [-0.9648, -0.0867, -0.2926,  0.1181, -0.2218,  0.0766, -0.4508, -0.3964,  1.1163],
         [-0.0867, -0.2926,  0.3010, -0.2218,  0.0766,  0.5914, -0.3964,  1.1163,  0.6776],
         [-0.2926,  0.3010,  0.3192,  0.0766,  0.5914, -0.8932,  1.1163,  0.6776, -0.8948]]]) torch.Size([1, 27, 9])
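To read the 27 × 9 output above, it can help to map an entry back to the input element it was copied from. The sketch below does that index arithmetic in plain Python; `unfold_source_index` is a hypothetical helper (not a PyTorch API) and it assumes a square kernel, with rows ordered channel first and then kernel offset, and columns ordered by window position in row-major order.

```python
def unfold_source_index(row, col, kernel_size, out_w, stride=1, padding=0, dilation=1):
    """Map entry (row, col) of the unfolded output to (channel, i, j) in the input.

    out_w is the number of window positions along the width.
    With padding > 0 the returned i or j can fall outside the input,
    which means the entry comes from the zero padding.
    """
    # Row index -> channel and offset (ki, kj) inside the kernel
    channel, offset = divmod(row, kernel_size * kernel_size)
    ki, kj = divmod(offset, kernel_size)
    # Column index -> window position (oi, oj)
    oi, oj = divmod(col, out_w)
    i = oi * stride + ki * dilation - padding
    j = oj * stride + kj * dilation - padding
    return channel, i, j


# Entries of the example output (kernel_size=3, 3 window positions per row)
print(unfold_source_index(0, 3, kernel_size=3, out_w=3))   # (0, 1, 0)
print(unfold_source_index(9, 0, kernel_size=3, out_w=3))   # (1, 0, 0)
print(unfold_source_index(26, 8, kernel_size=3, out_w=3))  # (2, 4, 4)
```

For instance, row 9, column 0 of the output is -0.9973, which is the element at position (0, 0) of the second channel of the input, and the last entry -0.8948 is the bottom-right element of the third channel.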