基于改进注意力机制的U-Net模型实现及应用(keras框架实现)
1.摘要
上節我們基于U-Net模型設計并實現了在醫學細胞分割上的應用(ISBI 挑戰數據集),并給出了模型的詳細代碼解釋,在上個博客中,我們為了快速訓練U-Net模型對其進行了縮減,將龐大的U-Net的轉換為很小&的結構,導致其準確率才達到75%左右。為了進一步提高U-Net模型在細胞分割上的準確率,本文將主要研究兩個方面:一是基于U-Net的原始模型結構進行改進,引入卷積注意力機制模塊(CBAM)和Focal?Tversky損失函數;二是引入深監督方法(DEEP SUPERVISION)及多尺度輸入作為U-Net模型的原始輸入,該模型被命名為DAMU-Net。為了進一步驗證該模型的性能,我們同樣在ISBI 挑戰數據集上進行實驗,并給出相應的實驗結果。
2.相關技術概述
2.1 Focal Tversky損失函數
醫學影像中存在很多的數據不平衡現象,使用不平衡數據進行訓練會導致嚴重偏向高精度但低召回率(sensitivity)的預測,這是我們不希望的,特別是在醫學應用中,假陰性比假陽性多更難容忍。而Tversky廣義損失函數可以有效解決了三維全卷積深神經網絡訓練中數據不平衡的問題,在精度和召回率之間找到更好的平衡。與Focal loss相似,Focal Tversky Loss著重于通過通過調整超參數α和β,我們可以控制假陽性和假陰性之間的權衡。較大的β會使召回的準確性高于精確度(通過更加強調假陰性)。其公式如下:
2.2??深監督方法
? 所謂深監督(Deep Supervision),就是在深度神經網絡的某些中間隱藏層加了一個輔助的分類器作為一種網絡分支來對主干網絡進行監督的技巧,用來解決深度神經網絡訓練梯度消失和收斂速度過慢等問題。?深監督作為一個訓練trick在2014年就已經通過DSN(Deeply-Supervised Nets)提出來了.
?通常而言,增加神經網絡的深度可以一定程度上提高網絡的表征能力,但隨著深度加深,會逐漸出現神經網絡難以訓練的情況,其中就包括像梯度消失和梯度爆炸等現象。為了更好的訓練深度網絡,人們嘗試給神經網絡的某些層添加一些輔助的分支分類器來解決這個問題。這種輔助的分支分類器能夠起到一種判斷隱藏層特征圖質量好壞的作用。其結構如下:
其中各個模塊含義如下:
?可以看到,圖中在第四個卷積塊之后添加了一個監督分類器作為分支。Conv4輸出的特征圖除了隨著主網絡進入Conv5之外,也作為輸入進入了分支分類器。往往分支與主網絡一起訓練。
3.模型實現
為了在精確性和召回性之間實現進一步的平衡,本文設計實現一種基于卷積注意力機制的U-Net模型,?該體 系結構基于流行的UNet,并將輸入圖像的多尺寸特征張量作為輸入,以便更好的提取局部特征。其模型結構如圖所示: 為了進一步細化模型實驗,我們將分三個步驟實現上述最終模型。3.1 基于卷積注意力機制的U-Net模型
該模型只是單純的將注意力機制引入U-Net模型中,目的是將輸入圖像的低級特征映射中識別相關的空間信息,并將其傳播到解碼階段,以達到真正地提取出積極有效的特征。其具體代碼實現可以查看上篇博客:https://haosen.blog.csdn.net/article/details/117755633。在該博客中有模型的具體結構圖及代碼實現。3.2 基于卷積注意力機制和深監督的U-Net模型
其具體代碼實現可以查看上篇博客:https://haosen.blog.csdn.net/article/details/117756027;
3.3 模型代碼實現
def attn_reg(opt,input_size, lossfxn):img_input = Input(shape=input_size, name='input_scale1')scale_img_2 = AveragePooling2D(pool_size=(2, 2), name='input_scale2')(img_input)scale_img_3 = AveragePooling2D(pool_size=(2, 2), name='input_scale3')(scale_img_2)scale_img_4 = AveragePooling2D(pool_size=(2, 2), name='input_scale4')(scale_img_3)conv1 = UnetConv2D(img_input, 32, is_batchnorm=True, name='conv1')pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)input2 = Conv2D(64, (3, 3), padding='same', activation='relu', name='conv_scale2')(scale_img_2)input2 = concatenate([input2, pool1], axis=3)conv2 = UnetConv2D(input2, 64, is_batchnorm=True, name='conv2')pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)input3 = Conv2D(128, (3, 3), padding='same', activation='relu', name='conv_scale3')(scale_img_3)input3 = concatenate([input3, pool2], axis=3)conv3 = UnetConv2D(input3, 128, is_batchnorm=True, name='conv3')pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)input4 = Conv2D(256, (3, 3), padding='same', activation='relu', name='conv_scale4')(scale_img_4)input4 = concatenate([input4, pool3], axis=3)conv4 = UnetConv2D(input4, 64, is_batchnorm=True, name='conv4')pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)center = UnetConv2D(pool4, 512, is_batchnorm=True, name='center')g1 = UnetGatingSignal(center, is_batchnorm=True, name='g1')attn1 = AttnGatingBlock(conv4, g1, 128, '_1')up1 = concatenate([Conv2DTranspose(32, (3,3), strides=(2,2), padding='same', activation='relu', kernel_initializer=kinit)(center), attn1], name='up1')g2 = UnetGatingSignal(up1, is_batchnorm=True, name='g2')attn2 = AttnGatingBlock(conv3, g2, 64, '_2')up2 = concatenate([Conv2DTranspose(64, (3,3), strides=(2,2), padding='same', activation='relu', kernel_initializer=kinit)(up1), attn2], name='up2')g3 = UnetGatingSignal(up1, is_batchnorm=True, name='g3')attn3 = AttnGatingBlock(conv2, g3, 32, '_3')up3 = concatenate([Conv2DTranspose(32, (3,3), strides=(2,2), padding='same', activation='relu', kernel_initializer=kinit)(up2), attn3], name='up3')up4 = concatenate([Conv2DTranspose(32, (3,3), strides=(2,2), padding='same', activation='relu', kernel_initializer=kinit)(up3), conv1], name='up4')conv6 = UnetConv2D(up1, 256, is_batchnorm=True, name='conv6')conv7 = UnetConv2D(up2, 128, is_batchnorm=True, name='conv7')conv8 = UnetConv2D(up3, 64, is_batchnorm=True, name='conv8')conv9 = UnetConv2D(up4, 32, is_batchnorm=True, name='conv9')out6 = Conv2D(1, (1, 1), activation='sigmoid', name='pred1')(conv6)out7 = Conv2D(1, (1, 1), activation='sigmoid', name='pred2')(conv7)out8 = Conv2D(1, (1, 1), activation='sigmoid', name='pred3')(conv8)out9 = Conv2D(1, (1, 1), activation='sigmoid', name='final')(conv9)model = Model(inputs=[img_input], outputs=[out6, out7, out8, out9])loss = {'pred1':lossfxn,'pred2':lossfxn,'pred3':lossfxn,'final': losses.tversky_loss}loss_weights = {'pred1':1,'pred2':1,'pred3':1,'final':1}model.compile(optimizer=opt, loss=loss, loss_weights=loss_weights,metrics=[losses.dsc])model.summary()from keras.utils.vis_utils import plot_modelplot_model(model, to_file='model1.png', show_shapes=True)return model模型參數結構圖(點擊觀看)?
4. 實驗結果
| 模型 | ? DES |
| U-Net | 0.878 |
| ATT-U-Net | ? |
| DATT-U-Net | ? |
| DAMU-Net | ? |
?
還有一些結果正在用CPU運行,太慢了.....
?
?
總結
以上是生活随笔為你收集整理的基于改进注意力机制的U-Net模型实现及应用(keras框架实现)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 泡泡屏保
- 下一篇: 一个空间多个php网站,一个空间多个域名