CNN卷积层图像和矩阵转换函数
生活随笔
收集整理的這篇文章主要介紹了
CNN卷积层图像和矩阵转换函数
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
將圖像維度轉換為矩陣,和將矩陣轉換為圖像維度。深度學習框架都會有這樣的功能。
import numpy as np


def _conv_output_size(input_size, filter_size, stride, pad):
    """Return how many filter positions fit along one spatial axis.

    Computes ``(input_size + 2*pad - filter_size) // stride + 1`` and rejects
    configurations for which no window fits.

    Raises
    ------
    ValueError
        If ``stride`` is smaller than 1, or if the filter is larger than the
        padded input (the computed size would be smaller than 1).
    """
    if stride < 1:
        raise ValueError("stride must be >= 1, got {}".format(stride))
    out_size = (input_size + 2 * pad - filter_size) // stride + 1
    if out_size < 1:
        raise ValueError(
            "filter size {} does not fit into input size {} with pad {}".format(
                filter_size, input_size, pad
            )
        )
    return out_size


def im2col(input_data, filter_h, filter_w, stride=1, pad=0):
    """Unfold a 4-D image batch into a 2-D matrix of filter-sized patches.

    Parameters
    ----------
    input_data : 4-D array of shape (batch size, channels, height, width)
    filter_h : filter height
    filter_w : filter width
    stride : step between neighbouring windows
    pad : number of zero rows/columns added on every side of the
        height and width axes

    Returns
    -------
    col : 2-D array of shape (N*out_h*out_w, C*filter_h*filter_w); each row
        holds one window, flattened in (channel, filter row, filter column)
        order.

    Raises
    ------
    ValueError
        If ``stride`` < 1 or the filter does not fit into the padded input.
    """
    N, C, H, W = input_data.shape
    out_h = _conv_output_size(H, filter_h, stride, pad)
    out_w = _conv_output_size(W, filter_w, stride, pad)

    # Pad only the two spatial axes; batch and channel axes stay untouched.
    img = np.pad(input_data, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
    col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))

    # For every offset (y, x) inside the filter, copy the strided grid of
    # pixels that this offset touches across all window positions at once.
    for y in range(filter_h):
        y_max = y + stride * out_h
        for x in range(filter_w):
            x_max = x + stride * out_w
            col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]

    # (N, C, fh, fw, oh, ow) -> (N, oh, ow, C, fh, fw) -> one row per window.
    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N * out_h * out_w, -1)
    return col


def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):
    """Fold a 2-D patch matrix back into a 4-D image batch (inverse layout of im2col).

    Values belonging to overlapping windows are summed, so
    ``col2im(im2col(x))`` equals ``x`` only when the windows do not overlap.

    Parameters
    ----------
    col : 2-D array laid out like the result of ``im2col``, i.e. reshapeable
        to (N, out_h, out_w, C, filter_h, filter_w)
    input_shape : shape of the image batch to rebuild, e.g. (10, 1, 28, 28)
    filter_h : filter height
    filter_w : filter width
    stride : step between neighbouring windows
    pad : padding that was applied on every side of the spatial axes

    Returns
    -------
    img : 4-D array of shape ``input_shape`` (padding removed)

    Raises
    ------
    ValueError
        If ``stride`` < 1 or the filter does not fit into the padded input.
    """
    N, C, H, W = input_shape
    out_h = _conv_output_size(H, filter_h, stride, pad)
    out_w = _conv_output_size(W, filter_w, stride, pad)

    # Undo im2col's reshape/transpose: back to (N, C, fh, fw, oh, ow).
    col = col.reshape(N, out_h, out_w, C, filter_h, filter_w).transpose(0, 3, 4, 5, 1, 2)

    # The extra (stride - 1) rows/columns give the strided slices below room
    # to end past the padded image without changing their length.
    img = np.zeros((N, C, H + 2 * pad + stride - 1, W + 2 * pad + stride - 1))
    for y in range(filter_h):
        y_max = y + stride * out_h
        for x in range(filter_w):
            x_max = x + stride * out_w
            # Accumulate: a pixel covered by several windows receives the sum.
            img[:, :, y:y_max:stride, x:x_max:stride] += col[:, :, y, x, :, :]

    # Strip the padding (and the slack added above).
    return img[:, :, pad:H + pad, pad:W + pad]


x1 = np.random.rand(1, 3, 7, 7)
col1 = im2col(x1, 5, 5, stride=1, pad=0)
print(col1.shape)  # 4-D image batch converted to a 2-D matrix
x11 = col2im(col1, (1, 3, 7, 7), 5, 5, stride=1, pad=0)
print(x11.shape)  # 2-D matrix converted back to the 4-D image layout

x2 = np.random.rand(10, 3, 7, 7)  # batch of 10 samples
col2 = im2col(x2, 5, 5, stride=1, pad=0)
print(col2.shape)  # 2-D, 9 rows per sample
x22 = col2im(col2, (10, 3, 7, 7), 5, 5, stride=1, pad=0)
print(x22.shape)

# Output:
(9, 75)
(1, 3, 7, 7)
(90, 75)
(10, 3, 7, 7)
卷積層中的應用:
class Convolution:
    """Convolution layer computed as a matrix product via im2col / col2im."""

    def __init__(self, W, b, stride=1, pad=0):
        """Store the layer parameters.

        W is unpacked as (FN, C, FH, FW) by forward/backward; b is added to
        the (N*out_h*out_w, FN) product in forward (presumably shape (FN,)
        -- confirm against the caller).
        """
        self.W = W
        self.b = b
        self.stride = stride
        self.pad = pad

        # Intermediate data cached by forward() for use in backward()
        self.x = None
        self.col = None
        self.col_W = None

        # Gradients of the weight and bias parameters
        self.dW = None
        self.db = None

    def forward(self, x):
        """Apply the convolution to ``x`` of shape (N, C, H, W).

        Returns an array of shape (N, FN, out_h, out_w) and caches ``x``,
        its im2col matrix and the flattened weights for backward().

        Raises
        ------
        ValueError
            If the channel count of ``x`` differs from that of the filters.
        """
        FN, C, FH, FW = self.W.shape
        N, in_channels, H, W = x.shape
        if in_channels != C:
            # Fail early with a readable message instead of an opaque
            # shape-alignment error from np.dot further down.
            raise ValueError(
                "input has {} channels but the filters expect {}".format(in_channels, C)
            )
        out_h = 1 + int((H + 2*self.pad - FH) / self.stride)
        out_w = 1 + int((W + 2*self.pad - FW) / self.stride)

        col = im2col(x, FH, FW, self.stride, self.pad)
        # One column per filter: (C*FH*FW, FN)
        col_W = self.W.reshape(FN, -1).T

        out = np.dot(col, col_W) + self.b
        # (N*out_h*out_w, FN) -> (N, out_h, out_w, FN) -> (N, FN, out_h, out_w)
        out = out.reshape(N, out_h, out_w, -1).transpose(0, 3, 1, 2)

        self.x = x
        self.col = col
        self.col_W = col_W

        return out

    def backward(self, dout):
        """Back-propagate ``dout`` (same shape as forward's output).

        Stores the parameter gradients in ``self.dW`` / ``self.db`` and
        returns the gradient with respect to the input of forward().
        Relies on the values cached by the preceding forward() call.
        """
        FN, C, FH, FW = self.W.shape
        # Bring dout into the same (rows, FN) layout as forward's matrix product.
        dout = dout.transpose(0, 2, 3, 1).reshape(-1, FN)

        self.db = np.sum(dout, axis=0)
        self.dW = np.dot(self.col.T, dout)
        self.dW = self.dW.transpose(1, 0).reshape(FN, C, FH, FW)

        dcol = np.dot(dout, self.col_W.T)
        dx = col2im(dcol, self.x.shape, FH, FW, self.stride, self.pad)

        return dx


# Key point: the formula for computing the convolution output size
out_h = 1 + int((H + 2*self.pad - FH) / self.stride)
out_w = 1 + int((W + 2*self.pad - FW) / self.stride)
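假設輸入大小為(H,W),卷積核大小為(FH,FW),輸出大小為(OH,OW),填充padding為P,步幅stride為S,則輸出大小可通過下面兩個公式計算:

$$OH = \left\lfloor \frac{H + 2P - FH}{S} \right\rfloor + 1$$

$$OW = \left\lfloor \frac{W + 2P - FW}{S} \right\rfloor + 1$$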
池化層中的應用:池化層不改變通道大小、沒有參數要學習、對微小變化更具有魯棒性。
class Pooling:
    """Max-pooling layer implemented with im2col / col2im."""

    def __init__(self, pool_h, pool_w, stride=1, pad=0):
        """Store the pooling window size, stride and padding."""
        self.pool_h = pool_h
        self.pool_w = pool_w
        self.stride = stride
        self.pad = pad

        # Cached by forward() for use in backward()
        self.x = None
        self.arg_max = None

    def forward(self, x):
        """Max-pool ``x`` of shape (N, C, H, W); returns (N, C, out_h, out_w).

        NOTE: padding is delegated to im2col; if that helper zero-pads, the
        padded positions take part in the max with value 0 -- confirm this is
        intended before using pad > 0 with negative inputs.
        """
        N, C, H, W = x.shape
        # The padding term must be included here because im2col below is
        # called with self.pad; without it the reshape further down fails
        # for any pad > 0.
        out_h = int(1 + (H + 2*self.pad - self.pool_h) / self.stride)
        out_w = int(1 + (W + 2*self.pad - self.pool_w) / self.stride)

        col = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad)
        # One row per (sample, window position, channel), holding that window's values.
        col = col.reshape(-1, self.pool_h*self.pool_w)

        arg_max = np.argmax(col, axis=1)
        out = np.max(col, axis=1)
        out = out.reshape(N, out_h, out_w, C).transpose(0, 3, 1, 2)

        self.x = x
        self.arg_max = arg_max

        return out

    def backward(self, dout):
        """Route each upstream gradient to the position that won the max.

        Relies on ``self.x`` and ``self.arg_max`` cached by forward().
        Returns the gradient with respect to the input of forward().
        """
        dout = dout.transpose(0, 2, 3, 1)

        pool_size = self.pool_h * self.pool_w
        dmax = np.zeros((dout.size, pool_size))
        # Only the arg-max slot of every window receives gradient; the rest stay 0.
        dmax[np.arange(self.arg_max.size), self.arg_max.flatten()] = dout.flatten()
        dmax = dmax.reshape(dout.shape + (pool_size,))

        # Merge (N, out_h, out_w) into rows to match the layout col2im expects.
        dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)
        dx = col2im(dcol, self.x.shape, self.pool_h, self.pool_w, self.stride, self.pad)

        return dx
總結
以上是生活随笔為你收集整理的CNN卷积层图像和矩阵转换函数的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 神经网络最优化方法比较(代码理解)
- 下一篇: Ubuntu下浏览Json文件