Pytorch中的grid_sample算子功能解析
? ? ? ? ?pytorch中的grid_sample是一種特殊的采樣算法。
調用接口為:
torch.nn.functional.grid_sample(input,grid,mode='bilinear',padding_mode='zeros',align_corners=None)。
???????? input參數是輸入特征圖tensor,也就是特征圖,可以是四維或者五維張量,以四維形式為例(N,C,Hin,Win),N可以理解為Batch_size,C可以理解為通道數,Hin和Win也就是特征圖高和寬。
???????? grid包含輸出特征圖特征圖的格網大小以及每個格網對應到輸入特征圖的采樣點位,對應四維input,其張量形式為(N,Hout,Wout,2),其中最后一維大小必須為2,如果輸入為五維張量,那么最后一維大小必須為3。為什么最后一維必須為2或者3?因為grid的最后一個維度實際上代表一個坐標(x,y)或者(xy,z),對應到輸入特征圖的二維或三維特征圖的坐標維度,xy取值范圍一般為[-1,1],該范圍映射到輸入特征圖的全圖。
???????? mode為選擇采樣方法,有三種內插算法可選,分別是'bilinear'雙線性差值、'nearest'最鄰近插值、'bicubic' 雙三次插值。
???????? padding_mode為填充模式,即當(x,y)取值超過輸入特征圖采樣范圍,返回一個特定值,有'zeros' 、 'border' 、 'reflection'三種可選,一般用zero。
???????? align_corners為bool類型,指設定特征圖坐標與特征值對應方式,設定為TRUE時,特征值位于像素中心。
???????? 要理解grid_sample是如何工作的,最好就是進行簡單的復現。假設輸入shape為(N,C,H,W),grid的shape設定為(N,H,W,2),以雙線性差值為例進行處理。首先根據input和grid設定,輸出特征圖tensor的shape為(N,C,H,W),輸出特征圖上每一個cell上的值由grid最后一維(x,y)確定。那么如何計算輸出tensor上每一個點的值?首先,通過(x,y)找到輸入特征圖上的采樣位置,由于xy取值范圍為[-1,1],為了便于計算,先將xy取值范圍調整為[0,1]。通過(w-1)*(x+1)/2、(wh-1)*(y+1)/2將xy映射為輸入特征圖的具體坐標位置。將xy映射到特征圖實際坐標后,取該坐標附近四個角點特征值,通過四個特征值坐標與采樣點坐標相對關系進行雙線性插值,得到采樣點的值。
注意:xy映射后的坐標可能是輸入特征圖上任意位置。假設輸出特征圖上(2,2)坐標位置上的值采樣位置可能為輸入特征圖上(3,4)位置,xy越小越靠近輸入特征圖左上角,越大則越靠近右下角。
?????????基于上面的思路,可以進行一個簡單的自定義實現。根據指定shape生成input和grid,使用pytorch中的grid_sample算子生成output。之后取grid中的第一個位置中的xy,根據xy從input中通過雙線性插值計算出output第一個位置的值。
import torch import numpy as np def grid_sample(input, grid):N, C, H_in, W_in = input.shapeN, H_out, W_out, _ = grid.shapeoutput = np.random.random((N,C,H,W))for i in range(N):for j in range(C):for k in range(H_out):for l in range(W_out):param = [0.0, 0.0]param[0] = (W_in - 1) * (grid[i][k][l][0] + 1) / 2param[1] = (H_in - 1) * (grid[i][k][l][1] + 1) / 2x0 = int(param[0])x1 = x0 + 1y0 = int(param[1])y1 = y0 + 1param[0] -= x0param[1] -= y0left_top = input[i][j][y0][x0] * (1 - param[0]) * (1 - param[1])left_bottom = input[i][j][y1][x0] * (1 - param[0]) * param[1]right_top = input[i][j][y0][x1] * param[0] * (1 - param[1])right_bottom = input[i][j][y1][x1] * param[0] * param[1]result = left_bottom + left_top + right_bottom + right_topoutput[i][j][k][l] = resultreturn outputN, C, H, W = 1, 1, 4, 4 input = np.random.random((N,C,H,W)) grid = np.random.random((N,H,W,2)) out = grid_sample(input, grid) print(f'自定義實現輸出結果:\n{out}') input = torch.from_numpy(input) grid = torch.from_numpy(grid) output = torch.nn.functional.grid_sample(input,grid,mode='bilinear', padding_mode='zeros',align_corners=True) print(f'grid_sample輸出結果:\n{output}')運行結果:
?????????從輸出結果上看,與pytorch基本一致,由于僅僅做簡單驗證,這里沒有對超出[-1,1]范圍的xy值做處理,只能處理四維input,五維input的實現思路與這里基本一致。
? ? ? ? 考慮到(x,y)取值范圍可能越界,pytorch中的padding_mode設置就是對(x,y)落在輸入特征圖外邊緣情況進行處理,一般設置'zero',也就是對靠近輸入特征圖范圍以外的采樣點進行0填充,如果不進行處理顯然會造成索引越界。要解決(x,y)越界問題,可以進行如下修改:
import torch import numpy as npdef grid_sample(input, grid):N, C, H_in, W_in = input.shapeN, H_out, W_out, _ = grid.shapeoutput = np.random.random((N, C, H_out, W_out))for i in range(N):for j in range(C):for k in range(H_out):for l in range(W_out):x, y = grid[i][k][l][0], grid[i][k][l][1]param = [0.0, 0.0]param[0] = (W_in - 1) * (x + 1) / 2param[1] = (H_in - 1) * (y + 1) / 2x1 = int(param[0] + 1)x0 = x1 - 1y1 = int(param[1] + 1)y0 = y1 - 1param[0] = abs(param[0] - x0)param[1] = abs(param[1] - y0)left_top_value, left_bottom_value, right_top_value, right_bottom_value = 0, 0, 0, 0if 0 <= x0 < W_in and 0 <= y0 < H_in:left_top_value = input[i][j][y0][x0]if 0 <= x1 < W_in and 0 <= y0 < H_in:right_top_value = input[i][j][y0][x1]if 0 <= x0 < W_in and 0 <= y1 < H_in:left_bottom_value = input[i][j][y1][x0]if 0 <= x1 < W_in and 0 <= y1 < H_in:right_bottom_value = input[i][j][y1][x1]left_top = left_top_value * (1 - param[0]) * (1 - param[1])left_bottom = left_bottom_value * (1 - param[0]) * param[1]right_top = right_top_value * param[0] * (1 - param[1])right_bottom = right_bottom_value * param[0] * param[1]result = left_bottom + left_top + right_bottom + right_topoutput[i][j][k][l] = resultreturn outputN, C, H_in, W_in = 1, 1, 4, 4 H_out, W_out = 4, 4 input = np.random.random((N, C, H_in, W_in)) grid = np.random.random((N, H_out, W_out, 2)) grid[0][0][0] = [-1.2, 1.3] out = grid_sample(input, grid) print(f'自定義實現輸出結果:\n{out}') input = torch.from_numpy(input) grid = torch.from_numpy(grid) output = torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=True) print(f'grid_sample輸出結果:\n{output}')? ? ?測試結果:
? ?
總結
以上是生活随笔為你收集整理的Pytorch中的grid_sample算子功能解析的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 如何查看 Python 版本
- 下一篇: AI圣经《深度学习》作者斩获2018年图