U-Net网络模型(添加通道与空间注意力机制)代码---亲测提高精度
                                                            生活随笔
收集整理的這篇文章主要介紹了
                                U-Net网络模型(添加通道与空间注意力机制)代码---亲测提高精度
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.                        
                                U-Net網絡模型(簡單改進版)
這一段時間做項目用到了U-Net網絡模型,但是原始的U-Net網絡還有很大的改良空間,隨手加了點東西:
每次的下采樣的過程中加入了通道注意力和空間注意力(大概就是這樣)
代碼跑出來后,效果比原來的U-Net大概提升了一個點左右,證明是有效的,改動很少,放出代碼:
class ChannelAttentionModule(nn.Module):def __init__(self, channel, ratio=16):super(ChannelAttentionModule, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.shared_MLP = nn.Sequential(nn.Conv2d(channel, channel // ratio, 1, bias=False),nn.ReLU(),nn.Conv2d(channel // ratio, channel, 1, bias=False))self.sigmoid = nn.Sigmoid()def forward(self, x):avgout = self.shared_MLP(self.avg_pool(x))maxout = self.shared_MLP(self.max_pool(x))return self.sigmoid(avgout + maxout)class SpatialAttentionModule(nn.Module):def __init__(self):super(SpatialAttentionModule, self).__init__()self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)self.sigmoid = nn.Sigmoid()def forward(self, x):avgout = torch.mean(x, dim=1, keepdim=True)maxout, _ = torch.max(x, dim=1, keepdim=True)out = torch.cat([avgout, maxout], dim=1)out = self.sigmoid(self.conv2d(out))return outclass CBAM(nn.Module):def __init__(self, channel):super(CBAM, self).__init__()self.channel_attention = ChannelAttentionModule(channel)self.spatial_attention = SpatialAttentionModule()def forward(self, x):out = self.channel_attention(x) * xout = self.spatial_attention(out) * outreturn out class conv_block(nn.Module):def __init__(self,ch_in,ch_out):super(conv_block,self).__init__()self.conv = nn.Sequential(nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),nn.BatchNorm2d(ch_out),nn.ReLU(inplace=True),nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True),nn.BatchNorm2d(ch_out),nn.ReLU(inplace=True))def forward(self,x):x = self.conv(x)return xclass up_conv(nn.Module):def __init__(self,ch_in,ch_out):super(up_conv,self).__init__()self.up = nn.Sequential(nn.Upsample(scale_factor=2),nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),nn.BatchNorm2d(ch_out),nn.ReLU(inplace=True))def forward(self,x):x = self.up(x)return xclass U_Net_v1(nn.Module): #添加了空間注意力和通道注意力def __init__(self,img_ch=3,output_ch=2):super(U_Net_v1,self).__init__()self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)self.Conv1 = conv_block(ch_in=img_ch,ch_out=64) #64self.Conv2 = conv_block(ch_in=64,ch_out=128) #64 128self.Conv3 = conv_block(ch_in=128,ch_out=256) #128 256self.Conv4 = conv_block(ch_in=256,ch_out=512) #256 512self.Conv5 = conv_block(ch_in=512,ch_out=1024) #512 1024self.cbam1 = CBAM(channel=64)self.cbam2 = CBAM(channel=128)self.cbam3 = CBAM(channel=256)self.cbam4 = CBAM(channel=512)self.Up5 = up_conv(ch_in=1024,ch_out=512) #1024 512self.Up_conv5 = conv_block(ch_in=1024, ch_out=512) self.Up4 = up_conv(ch_in=512,ch_out=256) #512 256self.Up_conv4 = conv_block(ch_in=512, ch_out=256) self.Up3 = up_conv(ch_in=256,ch_out=128) #256 128self.Up_conv3 = conv_block(ch_in=256, ch_out=128) self.Up2 = up_conv(ch_in=128,ch_out=64) #128 64self.Up_conv2 = conv_block(ch_in=128, ch_out=64) self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0) #64def forward(self,x):# encoding pathx1 = self.Conv1(x)x1 = self.cbam1(x1) + x1x2 = self.Maxpool(x1)x2 = self.Conv2(x2)x2 = self.cbam2(x2) + x2x3 = self.Maxpool(x2)x3 = self.Conv3(x3)x3 = self.cbam3(x3) + x3x4 = self.Maxpool(x3)x4 = self.Conv4(x4)x4 = self.cbam4(x4) + x4x5 = self.Maxpool(x4)x5 = self.Conv5(x5)# decoding + concat pathd5 = self.Up5(x5)d5 = torch.cat((x4,d5),dim=1)d5 = self.Up_conv5(d5)d4 = self.Up4(d5)d4 = torch.cat((x3,d4),dim=1)d4 = self.Up_conv4(d4)d3 = self.Up3(d4)d3 = torch.cat((x2,d3),dim=1)d3 = self.Up_conv3(d3)d2 = self.Up2(d3)d2 = torch.cat((x1,d2),dim=1)d2 = self.Up_conv2(d2)d1 = self.Conv_1x1(d2)return d1總結
以上是生活随笔為你收集整理的U-Net网络模型(添加通道与空间注意力机制)代码---亲测提高精度的全部內容,希望文章能夠幫你解決所遇到的問題。
 
                            
                        - 上一篇: 初学者学Python必看的几个练手小项目
- 下一篇: C语言scanf为啥有时候要输入两次(解
