【TensorFlow-windows】扩展层之STN
前言
讀TensorFlow相關(guān)代碼看到了STN的應(yīng)用,搜索以后發(fā)現(xiàn)可替代池化,增強(qiáng)網(wǎng)絡(luò)對(duì)圖像變換(旋轉(zhuǎn)、縮放、偏移等)的抗干擾能力,簡(jiǎn)單說(shuō)就是提高卷積神經(jīng)網(wǎng)絡(luò)的空間不變性。
國(guó)際慣例,參考博客:
理解Spatial Transformer Networks
github-STN
Deep Learning Paper Implementations: Spatial Transformer Networks - Part I
Deep Learning Paper Implementations: Spatial Transformer Networks - Part II
將STN加入網(wǎng)絡(luò)訓(xùn)練的一個(gè)關(guān)于圖像隱寫術(shù)的案例:StegaStamp
理論
圖像變換
因?yàn)閳D像的本質(zhì)就是矩陣,那么圖像變換就是矩陣變換,先復(fù)習(xí)一下與圖像相關(guān)的矩陣變換。假設(shè)MMM為變換矩陣,NNN為圖像,為了簡(jiǎn)化表達(dá),設(shè)MMM的維度是(2,2)(2,2)(2,2),NNN代表像素點(diǎn)坐標(biāo),則維度是(2,1)(2,1)(2,1),以下操作均為對(duì)像素位置的調(diào)整操作,而非對(duì)像素值的操作。
-
縮放
M×N=[p00q]×[xy]=[pxqy]M\times N=\begin{bmatrix} p&0\\ 0&q \end{bmatrix}\times \begin{bmatrix} x\\y \end{bmatrix}=\begin{bmatrix} px\\qy \end{bmatrix} M×N=[p0?0q?]×[xy?]=[pxqy?] -
旋轉(zhuǎn):繞原點(diǎn)順時(shí)針旋轉(zhuǎn)θ\thetaθ角
M×N=[cos?θ?sin?θsin?θcos?θ]×[xy]=[xcos?θ?ysin?θxsin?θ+ycos?θ]M\times N=\begin{bmatrix} \cos\theta&-\sin\theta\\ \sin\theta&\cos\theta \end{bmatrix}\times \begin{bmatrix} x\\y \end{bmatrix}=\begin{bmatrix} x\cos\theta-y\sin\theta\\x\sin\theta+y\cos\theta \end{bmatrix} M×N=[cosθsinθ??sinθcosθ?]×[xy?]=[xcosθ?ysinθxsinθ+ycosθ?] -
錯(cuò)切(shear):類似于將字的正體變成斜體
M×N=[1mn1]×[xy]=[x+myy+nx]M\times N=\begin{bmatrix} 1&m\\ n&1 \end{bmatrix}\times \begin{bmatrix} x\\y \end{bmatrix}=\begin{bmatrix} x+my\\y+nx \end{bmatrix} M×N=[1n?m1?]×[xy?]=[x+myy+nx?] -
平移:要轉(zhuǎn)換為齊次矩陣做平移
M′×N′=[10a01b]×[xy1]=[x+ay+b]M'\times N'=\begin{bmatrix} 1&0&a\\ 0&1&b \end{bmatrix}\times \begin{bmatrix} x\\y\\1 \end{bmatrix}=\begin{bmatrix} x+a\\y+b \end{bmatrix} M′×N′=[10?01?ab?]×???xy1????=[x+ay+b?]
盜用參考博客的圖解就是:
注意,我們進(jìn)行多次變換的時(shí)候有多個(gè)變換矩陣,如果每次計(jì)算一個(gè)變換會(huì)比較耗時(shí),參考矩陣的乘法特性,我們可以先將變換矩陣相乘,得到一個(gè)完整的矩陣代表所有變換,最后乘以圖像,就可將圖像按照組合變換順序得到變換圖像。這個(gè)代表一系列的變換的矩陣通常表示為:
M=[abcdef]M=\begin{bmatrix} a&b&c\\d&e&f \end{bmatrix} M=[ad?be?cf?]
因?yàn)橹苯佑?jì)算位置的值,很可能得到小數(shù),比如將(3,3)(3,3)(3,3)的圖像放大到(9,9)(9,9)(9,9),也就是放大3倍,那么新圖像(8,8)(8,8)(8,8)位置的像素就是原圖(8/3,8/3)(8/3,8/3)(8/3,8/3)位置的像素,但是像素位置不可能是小數(shù),因而出現(xiàn)了解決方案:雙線性插值
雙線性插值
先復(fù)習(xí)一下線性插值,直接去看之前寫的這篇博客,知道(x1,y1)(x_1,y_1)(x1?,y1?)與(x2,y2)(x_2,y_2)(x2?,y2?),求(x1,x2)區(qū)間內(nèi)的點(diǎn)(x_1,x_2)區(qū)間內(nèi)的點(diǎn)(x1?,x2?)區(qū)間內(nèi)的點(diǎn)xxx位置的y值,結(jié)果是:
y=x?x2x1?x2y1+x?x1x2?x1y2y=\frac{x-x_2}{x_1-x_2}y_1+\frac{x-x_1}{x_2-x_1}y_2 y=x1??x2?x?x2??y1?+x2??x1?x?x1??y2?
可以發(fā)現(xiàn)線性插值是針對(duì)一維坐標(biāo)的,即給xxx求yyy,但是雙線性插值是針對(duì)二維坐標(biāo)點(diǎn)的,即給(x,y)(x,y)(x,y)求值QQQ。方法是先在xxx軸方向做兩次線性插值,再在yyy軸上做一次線性插值。
設(shè)需要求(x,y)(x,y)(x,y)處的值,我們需要預(yù)先知道其附近四個(gè)坐標(biāo)點(diǎn)及其對(duì)應(yīng)的值,如:
- (x,y)(x,y)(x,y)左下角坐標(biāo)為(x1,y1)(x_1,y_1)(x1?,y1?),值為Q1Q_1Q1?
- (x,y)(x,y)(x,y)右下角坐標(biāo)為(x2,y1)(x_2,y_1)(x2?,y1?), 值為Q2Q_2Q2?
- (x,y)(x,y)(x,y)左上角坐標(biāo)為(x1,y2)(x_1,y_2)(x1?,y2?), 值為Q3Q_3Q3?
- (x,y)(x,y)(x,y)右上角坐標(biāo)為(x2,y2)(x_2,y_2)(x2?,y2?),值為Q4Q_4Q4?
首先對(duì)下面的(x1,y1)(x_1,y_1)(x1?,y1?)和(x2,y1)(x_2,y_1)(x2?,y1?)做線性插值,方法是把它兩看做一維坐標(biāo)(x1,Q1)(x_1,Q_1)(x1?,Q1?)和(x2,Q2)(x_2,Q2)(x2?,Q2),得到:
P1=x?x2x1?x2Q1+x?x1x2?x1Q2P_1=\frac{x-x_2}{x_1-x_2}Q_1+\frac{x-x_1}{x_2-x_1}Q_2 P1?=x1??x2?x?x2??Q1?+x2??x1?x?x1??Q2?
同理得到上面的兩個(gè)坐標(biāo)(x1,y2)(x_1,y_2)(x1?,y2?)與(x2,y2)(x_2,y_2)(x2?,y2?)的插值結(jié)果,也就是(x1,Q3)(x_1,Q_3)(x1?,Q3?)和(x2,Q4)(x_2,Q_4)(x2?,Q4?)的線性插值結(jié)果:
P2=x?x2x1?x2Q3+x?x1x2?x1Q4P_2=\frac{x-x_2}{x_1-x_2}Q_3+\frac{x-x_1}{x_2-x_1}Q_4 P2?=x1??x2?x?x2??Q3?+x2??x1?x?x1??Q4?
再對(duì)(y1,P1)(y_1,P_1)(y1?,P1?)和(y2,P2)(y_2,P_2)(y2?,P2?)做線性插值:
P=x?y2y1?y2P1+y?y1y2?y1P2P=\frac{x-y_2}{y_1-y_2}P_1+\frac{y-y_1}{y_2-y_1}P_2 P=y1??y2?x?y2??P1?+y2??y1?y?y1??P2?
解決上面圖像變換的問(wèn)題,假設(shè)變換后的坐標(biāo)不是整數(shù),那么就選擇這個(gè)坐標(biāo)四個(gè)角的坐標(biāo)的雙線性插值的結(jié)果,比如(8/3,8/3)(8/3,8/3)(8/3,8/3)位置的像素就是(2,2),(3,2),(2,3),(3,3)(2,2),(3,2),(2,3),(3,3)(2,2),(3,2),(2,3),(3,3)位置像素的雙線性插值結(jié)果。
總之就是先計(jì)算目標(biāo)圖像像素在源圖像中的位置,然后得到源圖像位置是小數(shù),針對(duì)小數(shù)位置的四個(gè)頂點(diǎn)做雙線性插值。
上面就是STN做的工作,也可以發(fā)現(xiàn)STN接受的參數(shù)就是6個(gè),接下來(lái)看看為什么STN能提高卷積網(wǎng)絡(luò)的旋轉(zhuǎn)、平移、縮放不變性。
總結(jié)一下:
圖像處理中的仿射變換通常包含三個(gè)步驟:
- 創(chuàng)建由(x,y)(x,y)(x,y)組成的采樣網(wǎng)格,比如(400,400)(400,400)(400,400)的灰度圖對(duì)應(yīng)創(chuàng)建一個(gè)同樣大小的網(wǎng)格。
- 將變換矩陣應(yīng)用到采樣網(wǎng)格上
- 使用插值技術(shù)從原圖中計(jì)算變換圖的像素值
池化
強(qiáng)行翻譯一波這篇文章關(guān)于池化的部分,建議看原文,這里摘取個(gè)人認(rèn)為重要部分:
池化在某種程度上增加了模型的空間不變性,因?yàn)槌鼗且环N下采樣技術(shù),減少了每層特征圖的空間大小,極大減少了參數(shù)數(shù)量,提高了運(yùn)算速度。
池化提供的不變性確切來(lái)說(shuō)是什么?池化的思路是將一個(gè)圖像切分成多個(gè)單元,這些復(fù)雜單元被池化以后得到了可以描述輸出的簡(jiǎn)單的單元。比如有3張不同方向的數(shù)字7的圖像,池化是通過(guò)圖像上的小網(wǎng)格來(lái)檢測(cè)7,不受7的位置影響,因?yàn)橥ㄟ^(guò)聚集的像素值,我們得到的信息大致一樣。個(gè)人覺得,作者的本意是單看小網(wǎng)格,是有很多一樣的塊。
池化的缺點(diǎn)在于:
- 丟失了75%的信息(應(yīng)該是(2,2)(2,2)(2,2)的最大值池化方法),意味著我們一定丟了是精確的位置信息。有人會(huì)問(wèn),這樣可以增加空間魯棒性哇。然而,對(duì)于視覺識(shí)別人物,空間信息是非常重要的。比如分類貓的時(shí)候,知道貓的胡須的位置相對(duì)于鼻子的位置有可能很重要,但是如果使用最大池化,可能丟失了這個(gè)信息。
- 池化是局部的且預(yù)定義好的。一個(gè)小的接受域,池化操作的影響僅僅是針對(duì)更深的網(wǎng)絡(luò)層(越深感受野越大),也就是中間的特征圖可能受到嚴(yán)重的輸入失真的影響。我們不能任意增加接受域,這樣會(huì)過(guò)度下采樣。
主要結(jié)論就是卷積網(wǎng)絡(luò)對(duì)于相對(duì)大的輸入失真不具有不變性。
The pooling operation used in convolutional neural networks is a big mistake and the fact that it works so well is a disaster. (Geoffrey Hinton, Reddit AMA)STN理論
STN的全稱是Spatial Transformer Networks,空間變換網(wǎng)絡(luò)。時(shí)空變換機(jī)制就是通過(guò)給CNN提供顯式的空間變換能力,以解決上述池化出現(xiàn)的問(wèn)題。有三種特性:
- Modular:STN能夠被插入到網(wǎng)絡(luò)的任意地方,僅需很小的調(diào)整
- differentiable:STN可以通過(guò)反向傳播訓(xùn)練
- dynamic:STN是對(duì)每個(gè)輸入樣本的一個(gè)特征圖做空間變換,而池化是針對(duì)所有樣本。
上圖是STN網(wǎng)絡(luò)的主要框架。所以到底什么是空間變換?通過(guò)結(jié)構(gòu)圖可發(fā)現(xiàn)模型包含三部分:localisation network、grid generator、sampler。
Localisation Network
主要是提取被應(yīng)用到輸入特征圖上的仿射變換的參數(shù)θ\thetaθ,網(wǎng)絡(luò)結(jié)構(gòu)是:
- 輸入:大小為(H,W,C)(H,W,C)(H,W,C)的特征圖UUU
- 輸出:大小為(6,1)(6,1)(6,1)的變換矩陣θ\thetaθ
- 結(jié)構(gòu):全連接或者卷積
Parametrised Sampling Grid
輸出參數(shù)化的采樣網(wǎng)格,是一系列的點(diǎn),每個(gè)輸入特征圖能夠產(chǎn)生期望的變換輸出。
具體就是:網(wǎng)格生成器首先產(chǎn)生于輸入圖像UUU大小相同的標(biāo)準(zhǔn)網(wǎng)格,然后將仿射變換應(yīng)用到網(wǎng)格。公式表達(dá)即,假設(shè)輸入圖的索引是(xt,yt)(x^t,y^t)(xt,yt),將θ\thetaθ代表的變換應(yīng)用到坐標(biāo)上得到新的坐標(biāo):
[xsys]=[θ1θ2θ3θ4θ5θ6]×[xtyt1]\begin{bmatrix} x^s\\y^s \end{bmatrix}=\begin{bmatrix} \theta_1&\theta_2&\theta_3\\\theta_4&\theta_5&\theta_6 \end{bmatrix}\times\begin{bmatrix} x^t\\y^t\\1 \end{bmatrix} [xsys?]=[θ1?θ4??θ2?θ5??θ3?θ6??]×???xtyt1????
Differentiable Image Sampling
依據(jù)輸入特征圖和參數(shù)化采樣網(wǎng)格,我們可以利用雙線性插值方法獲得輸出特征圖。注意,這一步我們可以通過(guò)制定采樣網(wǎng)格的大小執(zhí)行上采樣或者下采樣,很像池化。
左圖使用了單位變換,右圖使用了旋轉(zhuǎn)的仿射變換。
【注】因?yàn)殡p線性插值是可微的,所以STN可以作為訓(xùn)練網(wǎng)絡(luò)的一部分。
代碼
利用STN前向過(guò)程做圖像變換
GitHub上有作者提供了源碼,也可以用pip直接安裝。
代碼直接貼了,稍微改了一點(diǎn)點(diǎn):
導(dǎo)入包
import tensorflow as tf import cv2 import numpy as npfrom stn import spatial_transformer_network as transformer讀入圖像,轉(zhuǎn)換為四維矩陣:
img=cv2.imread('test_img.jpg') img=np.array(img) H,W,C=img.shape img=img[np.newaxis,:] print(img.shape)旋轉(zhuǎn)變換的角度
degree=np.deg2rad(45) theta=np.array([[np.cos(degree),-np.sin(degree),0],[np.sin(degree),np.cos(degree),0] ])構(gòu)建網(wǎng)絡(luò)結(jié)構(gòu)
x=tf.placeholder(tf.float32,shape=[None,H,W,C]) with tf.variable_scope('spatial_transformer'):theta=theta.astype('float32')theta=theta.flatten()loc_in=H*W*C #輸入維度loc_out=6 #輸出維度W_loc=tf.Variable(tf.zeros([loc_in,loc_out]),name='W_loc')b_loc=tf.Variable(initial_value=theta,name='b_loc')#運(yùn)算fc_loc=tf.matmul(tf.zeros([1,loc_in]),W_loc)+b_loch_trans=transformer(x,fc_loc)把圖像喂進(jìn)去,并顯示圖像
init=tf.global_variables_initializer() with tf.Session() as sess:sess.run(init)y=sess.run(h_trans,feed_dict={x:img})print(y.shape)y=np.squeeze(np.array(y,dtype=np.uint8)) print(y.shape) cv2.imshow('trasformedimg',y) cv2.waitKey() cv2.destroyAllWindows()重點(diǎn)關(guān)注網(wǎng)絡(luò)構(gòu)建:
權(quán)重w_loc是全零的大小為(HWC,6)(HWC,6)(HWC,6)的矩陣,偏置b_loc是大小為(1,6)(1,6)(1,6)的向量,這樣經(jīng)過(guò)運(yùn)算
fc_loc=tf.matmul(tf.zeros([1,loc_in]),W_loc)+b_loc得到的其實(shí)就是我們指定的旋轉(zhuǎn)角度對(duì)應(yīng)的6維變換參數(shù),最后利用變換函數(shù)transformer執(zhí)行此變換就行了。
將STN加入網(wǎng)絡(luò)中訓(xùn)練
主要參考StegaStamp作者的寫法,這里做STN部分加入網(wǎng)絡(luò)的方法:
輸入一張圖片到如下網(wǎng)絡(luò)結(jié)構(gòu)(Keras網(wǎng)絡(luò)結(jié)構(gòu)搭建語(yǔ)法):
得到(1,128)(1,128)(1,128)維的向量,其實(shí)用一個(gè)網(wǎng)絡(luò)替換上面前向計(jì)算中的loc_in,目的是為了得到二維圖像對(duì)應(yīng)的一維信息
后面的過(guò)程就和前向計(jì)算一樣了,定義權(quán)重和偏置:
然后利用一維信息得到圖像變換所需的6個(gè)值:
x = tf.matmul(stn_params, self.W_fc1) + self.b_fc1最后利用STN庫(kù)將變換應(yīng)用到圖像中,得到下一層網(wǎng)絡(luò)結(jié)構(gòu)的輸入
transformed_image = stn_transformer(image, x, [self.height, self.width, 3])可以看出,STN加入到網(wǎng)絡(luò)后,訓(xùn)練參數(shù)有:
- 二維圖像到一維特征向量的卷積+全連接網(wǎng)絡(luò)的權(quán)重和偏置
- 一維向量到6維變換參數(shù)的權(quán)重和偏置
總結(jié)
通篇就是對(duì)池化方案的改變,使用STN能夠增加網(wǎng)絡(luò)的變換不變性,比池化的效果更好。
代碼:
鏈接:https://pan.baidu.com/s/1kDs9T-Mf1F_mzQyvslcROA
提取碼:crdu
總結(jié)
以上是生活随笔為你收集整理的【TensorFlow-windows】扩展层之STN的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 《恋与制作人》周棋洛生日快乐 协奏情意绵
- 下一篇: 跨越星汉相会七夕 《终结者2》结婚系统浪