Multi-Scale Boosted Dehazing Network with Dense Feature Fusion: Notes and Code
The two main contributions of this paper are the SOS boosting strategy and dense feature fusion. Both ideas are adapted from other fields.
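Abstract
- The paper proposes a multi-scale boosted dehazing network with dense feature fusion, built on the U-Net architecture.
- The method is designed around two principles, boosting and error feedback, and the paper shows that both are suitable for the dehazing problem.
- By adding the Strengthen-Operate-Subtract (SOS) boosting strategy to the decoder of the model, the authors obtain a simple yet effective boosted decoder that restores the haze-free image progressively.
- To preserve spatial information in the U-Net architecture, the authors design a dense feature fusion (DFF) module that uses a back-projection feedback scheme. Their results show that the DFF module can both remedy the spatial information missing from the high-resolution features and exploit non-adjacent features.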
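Proposed method
Network architecture
The network has three parts: the encoder G_Enc, the boosted decoder G_Dec, and the feature restoration module G_Res. The figure shows the network architecture.
To progressively refine the feature J_L produced by the feature restoration module G_Res, the decoder G_Dec is designed around the SOS boosting strategy.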
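SOS boosting strategy
The SOS boosting strategy is:
$J^{n+1} = g(I + J^n) - J^n$
Here $J^{n+1}$ is the prediction produced by the n-th boosting iteration from the previous estimate $J^n$, $g(\cdot)$ is the dehazing operation, and $I + J^n$ means strengthening $J^n$ with the hazy image $I$.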
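The recurrence above can be traced with a small toy. This is only a sketch of the iteration mechanics: `toy_operator` is a hypothetical stand-in for g (a plain linear shrinkage, not a dehazing network), and starting from an all-zero estimate is an assumption made for the example.

```python
def toy_operator(values, gain=0.5):
    """Hypothetical stand-in for g(.): linear shrinkage, NOT a real dehazer."""
    return [gain * v for v in values]


def sos_boost(hazy, operator, num_iters):
    """Apply J^{n+1} = g(I + J^n) - J^n for num_iters steps.

    The initial estimate J^0 = 0 is an assumption of this sketch.
    """
    estimate = [0.0 for _ in hazy]
    for _ in range(num_iters):
        # Strengthen: add the current estimate to the hazy input
        strengthened = [i + j for i, j in zip(hazy, estimate)]
        # Operate: run the (toy) operator on the strengthened signal
        operated = operator(strengthened)
        # Subtract: remove the previous estimate again
        estimate = [o - j for o, j in zip(operated, estimate)]
    return estimate


hazy = [0.9, 0.6, 0.3]
result = sos_boost(hazy, toy_operator, num_iters=40)
print(result)
```

For this linear toy with gain a, the iteration settles at a·I / (2 − a), which is I / 3 for a = 0.5. That says nothing about a real dehazing network; it only makes the strengthen, operate, subtract order concrete.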
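The figure shows five different boosting modules. For completeness, the paper also lists four alternatives to the SOS boosting module. The diffusion [44] and twicing [6] schemes can be used to design boosting modules, as shown in panels (a) and (b) of the figure. They can be written, respectively, as
$j^n = G^n_{\theta_n}\big((j^{n+1})\uparrow_2\big)$ (5)
$j^n = G^n_{\theta_n}\big(i^n - (j^{n+1})\uparrow_2\big) + (j^{n+1})\uparrow_2$ (6)
The refinement units in (5) and (6) do not fully exploit the encoder feature $i^n$, which carries more structural and spatial information than the upsampled feature $(j^{n+1})\uparrow_2$. For this reason the SOS boosting structure is adopted.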
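The difference between these variants is easier to see side by side. The sketch below writes the diffusion and twicing forms from the text, plus the SOS form as it appears in the decoder code later in this post (`self.dense_4(res8x) + res8x - res16x`, where `res8x` already holds skip feature plus upsampled feature). The function names and the scalar "features" are illustrative assumptions; `refine` stands for the refinement unit G.

```python
def diffusion_module(refine, enc_feat, up_feat):
    # j^n = G(up(j^{n+1})): the encoder feature i^n is ignored entirely
    return refine(up_feat)


def twicing_module(refine, enc_feat, up_feat):
    # j^n = G(i^n - up(j^{n+1})) + up(j^{n+1})
    return refine(enc_feat - up_feat) + up_feat


def sos_module(refine, enc_feat, up_feat):
    # j^n = G(i^n + up(j^{n+1})) - up(j^{n+1})
    return refine(enc_feat + up_feat) - up_feat


# Toy refinement unit and scalar features, purely for illustration
refine = lambda x: 2 * x
enc_feat, up_feat = 3, 1

outputs = {
    "diffusion": diffusion_module(refine, enc_feat, up_feat),
    "twicing": twicing_module(refine, enc_feat, up_feat),
    "sos": sos_module(refine, enc_feat, up_feat),
}
print(outputs)
```

Only the twicing and SOS forms see the encoder feature at all, and only the SOS form passes it to the refinement unit together with the upsampled feature rather than as a difference.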
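Dense feature fusion module
As the architecture figure shows, two DFF modules are introduced at each level: one before the residual group in the encoder, and one after the SOS boosting module in the decoder. The enhanced output of each DFF module in the encoder (or decoder) is connected directly to all following DFF modules in the encoder (or decoder) for feature fusion.
Compared with fusion by sampling and concatenation, the feedback mechanism lets this module extract high-frequency information from the high-resolution features of preceding layers more effectively. Progressively fusing these differences back into the downsampled latent feature remedies the missing spatial information. The module can also use all preceding high-level features as an error-correcting feedback mechanism that refines the boosted features for better results.
The dense feature fusion module is shown in the figure.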
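The feedback idea can be sketched on 1-D lists: project the current feature to the resolution of an earlier enhanced feature, take the difference, back-project it, and add it to the current feature. Everything here is an assumption for illustration. Average pooling and nearest-neighbour repetition stand in for the learned projection and back-projection layers, the sign convention is chosen for the example, and the real `Encoder_MDCBlock1` / `Decoder_MDCBlock1` classes are not reproduced in this post.

```python
def downsample(values, factor):
    """Average-pool a 1-D list (stand-in for a learned projection)."""
    return [sum(values[k:k + factor]) / factor
            for k in range(0, len(values), factor)]


def upsample(values, factor):
    """Nearest-neighbour repeat (stand-in for a learned back-projection)."""
    return [v for v in values for _ in range(factor)]


def dense_feature_fusion(feature, previous_features):
    """Fuse every earlier enhanced feature into `feature`, one at a time."""
    fused = list(feature)
    for prev in previous_features:
        factor = len(fused) // len(prev)
        projected = downsample(fused, factor)                # match prev's size
        error = [p - q for p, q in zip(projected, prev)]     # difference
        correction = upsample(error, factor)                 # back-project
        fused = [f + c for f, c in zip(fused, correction)]   # feed back
    return fused


fused = dense_feature_fusion([1.0, 1.0, 3.0, 3.0], [[1.0, 2.0], [2.0]])
print(fused)
```

Because the loop runs over all stored features, each level is corrected by every earlier enhanced feature, not just the adjacent one. This mirrors how `feature_mem` and `feature_mem_up` are passed to every fusion block in the code.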
Code reading
MSBDN-RDFF.py
This file defines the network architecture of the whole model. The paper states that all images are 3 × 256 × 256, so I fed an all-ones tensor of that size through the network to observe the output of each module. Following the architecture figure in the paper, I added print calls that label each module and print the shape of the feature map, including its channel count, after each layer. For example:
print('Residual Group', x.shape)
Code (with per-module output):
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
# The fusion blocks are imported from networks/base_networks.py in the model's repository
from networks.base_networks import Encoder_MDCBlock1, Decoder_MDCBlock1


def make_model(args, parent=False):
    return Net()


class make_dense(nn.Module):
    def __init__(self, nChannels, growthRate, kernel_size=3):
        super(make_dense, self).__init__()
        self.conv = nn.Conv2d(nChannels, growthRate, kernel_size=kernel_size,
                              padding=(kernel_size - 1) // 2, bias=False)

    def forward(self, x):
        out = F.relu(self.conv(x))
        # Concatenate input and new features along the channel axis
        out = torch.cat((x, out), 1)
        return out


# Residual dense block (RDB) architecture
class RDB(nn.Module):
    def __init__(self, nChannels, nDenselayer, growthRate, scale=1.0):
        super(RDB, self).__init__()
        nChannels_ = nChannels
        self.scale = scale
        modules = []
        for i in range(nDenselayer):
            modules.append(make_dense(nChannels_, growthRate))
            nChannels_ += growthRate
        self.dense_layers = nn.Sequential(*modules)
        # 1x1 conv brings the channel count back to nChannels
        self.conv_1x1 = nn.Conv2d(nChannels_, nChannels, kernel_size=1, padding=0, bias=False)

    def forward(self, x):
        out = self.dense_layers(x)
        out = self.conv_1x1(out) * self.scale
        out = out + x
        return out


class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(ConvLayer, self).__init__()
        reflection_padding = kernel_size // 2
        self.reflection_pad = nn.ReflectionPad2d(reflection_padding)  # pad on all four sides
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        out = self.reflection_pad(x)
        out = self.conv2d(out)
        return out


class UpsampleConvLayer(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(UpsampleConvLayer, self).__init__()
        self.conv2d = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride)

    def forward(self, x):
        out = self.conv2d(x)
        return out


class ResidualBlock(torch.nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.relu = nn.PReLU()

    def forward(self, x):
        residual = x
        out = self.relu(self.conv1(x))
        out = self.conv2(out) * 0.1
        out = torch.add(out, residual)
        return out


class Net(nn.Module):
    def __init__(self, res_blocks=18):
        super(Net, self).__init__()

        self.conv_input = ConvLayer(3, 16, kernel_size=11, stride=1)
        self.dense0 = nn.Sequential(ResidualBlock(16), ResidualBlock(16), ResidualBlock(16))

        self.conv2x = ConvLayer(16, 32, kernel_size=3, stride=2)
        self.conv1 = RDB(16, 4, 16)
        self.fusion1 = Encoder_MDCBlock1(16, 2, mode='iter2')
        self.dense1 = nn.Sequential(ResidualBlock(32), ResidualBlock(32), ResidualBlock(32))

        self.conv4x = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.conv2 = RDB(32, 4, 32)
        self.fusion2 = Encoder_MDCBlock1(32, 3, mode='iter2')
        self.dense2 = nn.Sequential(ResidualBlock(64), ResidualBlock(64), ResidualBlock(64))

        self.conv8x = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.conv3 = RDB(64, 4, 64)
        self.fusion3 = Encoder_MDCBlock1(64, 4, mode='iter2')
        self.dense3 = nn.Sequential(ResidualBlock(128), ResidualBlock(128), ResidualBlock(128))

        self.conv16x = ConvLayer(128, 256, kernel_size=3, stride=2)
        self.conv4 = RDB(128, 4, 128)
        self.fusion4 = Encoder_MDCBlock1(128, 5, mode='iter2')

        self.dehaze = nn.Sequential()
        for i in range(0, res_blocks):
            self.dehaze.add_module('res%d' % i, ResidualBlock(256))

        self.convd16x = UpsampleConvLayer(256, 128, kernel_size=3, stride=2)
        self.dense_4 = nn.Sequential(ResidualBlock(128), ResidualBlock(128), ResidualBlock(128))
        self.conv_4 = RDB(64, 4, 64)
        self.fusion_4 = Decoder_MDCBlock1(64, 2, mode='iter2')

        self.convd8x = UpsampleConvLayer(128, 64, kernel_size=3, stride=2)
        self.dense_3 = nn.Sequential(ResidualBlock(64), ResidualBlock(64), ResidualBlock(64))
        self.conv_3 = RDB(32, 4, 32)
        self.fusion_3 = Decoder_MDCBlock1(32, 3, mode='iter2')

        self.convd4x = UpsampleConvLayer(64, 32, kernel_size=3, stride=2)
        self.dense_2 = nn.Sequential(ResidualBlock(32), ResidualBlock(32), ResidualBlock(32))
        self.conv_2 = RDB(16, 4, 16)
        self.fusion_2 = Decoder_MDCBlock1(16, 4, mode='iter2')

        self.convd2x = UpsampleConvLayer(32, 16, kernel_size=3, stride=2)
        self.dense_1 = nn.Sequential(ResidualBlock(16), ResidualBlock(16), ResidualBlock(16))
        self.conv_1 = RDB(8, 4, 8)
        self.fusion_1 = Decoder_MDCBlock1(8, 5, mode='iter2')

        self.conv_output = ConvLayer(16, 3, kernel_size=3, stride=1)

    def forward(self, x):
        # Encoder
        res1x = self.conv_input(x)
        print('Conv_stride1', res1x.shape)
        # Each level splits its channels in half: one half goes to DFF, the other to an RDB
        res1x_1, res1x_2 = res1x.split([(res1x.size()[1] // 2), (res1x.size()[1] // 2)], dim=1)
        feature_mem = [res1x_1]
        x = self.dense0(res1x) + res1x
        print('Residual Group', x.shape)

        res2x = self.conv2x(x)
        print('Conv_stride2', res2x.shape)
        res2x_1, res2x_2 = res2x.split([(res2x.size()[1] // 2), (res2x.size()[1] // 2)], dim=1)
        res2x_1 = self.fusion1(res2x_1, feature_mem)
        res2x_2 = self.conv1(res2x_2)
        print('Dense Feature', res2x_2.shape)
        feature_mem.append(res2x_1)
        res2x = torch.cat((res2x_1, res2x_2), dim=1)
        res2x = self.dense1(res2x) + res2x
        print('Residual Group', res2x.shape)

        res4x = self.conv4x(res2x)
        print('Conv_stride2', res4x.shape)
        res4x_1, res4x_2 = res4x.split([(res4x.size()[1] // 2), (res4x.size()[1] // 2)], dim=1)
        res4x_1 = self.fusion2(res4x_1, feature_mem)
        res4x_2 = self.conv2(res4x_2)
        print('Dense Feature', res4x_2.shape)
        feature_mem.append(res4x_1)
        res4x = torch.cat((res4x_1, res4x_2), dim=1)
        res4x = self.dense2(res4x) + res4x
        print('Residual Group', res4x.shape)

        res8x = self.conv8x(res4x)
        print('Conv_stride2', res8x.shape)
        res8x_1, res8x_2 = res8x.split([(res8x.size()[1] // 2), (res8x.size()[1] // 2)], dim=1)
        res8x_1 = self.fusion3(res8x_1, feature_mem)
        res8x_2 = self.conv3(res8x_2)
        print('Dense Feature', res8x_2.shape)
        feature_mem.append(res8x_1)
        res8x = torch.cat((res8x_1, res8x_2), dim=1)
        res8x = self.dense3(res8x) + res8x

        res16x = self.conv16x(res8x)
        print('Encoder Output', res16x.shape)

        # Gres
        res16x_1, res16x_2 = res16x.split([(res16x.size()[1] // 2), (res16x.size()[1] // 2)], dim=1)
        res16x_1 = self.fusion4(res16x_1, feature_mem)
        res16x_2 = self.conv4(res16x_2)
        res16x = torch.cat((res16x_1, res16x_2), dim=1)

        res_dehaze = res16x
        in_ft = res16x * 2
        res16x = self.dehaze(in_ft) + in_ft - res_dehaze
        res16x_1, res16x_2 = res16x.split([(res16x.size()[1] // 2), (res16x.size()[1] // 2)], dim=1)
        feature_mem_up = [res16x_1]

        # Boosted Decoder
        print('Decoder Input', res16x.shape)
        res16x = self.convd16x(res16x)
        # F.interpolate replaces the deprecated F.upsample
        res16x = F.interpolate(res16x, res8x.size()[2:], mode='bilinear', align_corners=False)
        res8x = torch.add(res16x, res8x)
        print('Deconv_stride2', res8x.shape)
        res8x = self.dense_4(res8x) + res8x - res16x
        print('Residual Group', res8x.shape)
        res8x_1, res8x_2 = res8x.split([(res8x.size()[1] // 2), (res8x.size()[1] // 2)], dim=1)
        res8x_1 = self.fusion_4(res8x_1, feature_mem_up)
        res8x_2 = self.conv_4(res8x_2)
        feature_mem_up.append(res8x_1)
        res8x = torch.cat((res8x_1, res8x_2), dim=1)
        print('Dense Feature', res8x.shape)

        res8x = self.convd8x(res8x)
        res8x = F.interpolate(res8x, res4x.size()[2:], mode='bilinear', align_corners=False)
        print('Deconv_stride2', res8x.shape)
        res4x = torch.add(res8x, res4x)
        res4x = self.dense_3(res4x) + res4x - res8x
        print('Residual Group', res4x.shape)
        res4x_1, res4x_2 = res4x.split([(res4x.size()[1] // 2), (res4x.size()[1] // 2)], dim=1)
        res4x_1 = self.fusion_3(res4x_1, feature_mem_up)
        res4x_2 = self.conv_3(res4x_2)
        feature_mem_up.append(res4x_1)
        res4x = torch.cat((res4x_1, res4x_2), dim=1)
        print('Dense Feature', res4x.shape)

        res4x = self.convd4x(res4x)
        res4x = F.interpolate(res4x, res2x.size()[2:], mode='bilinear', align_corners=False)
        print('Deconv_stride2', res4x.shape)
        res2x = torch.add(res4x, res2x)
        res2x = self.dense_2(res2x) + res2x - res4x
        print('Residual Group', res2x.shape)
        res2x_1, res2x_2 = res2x.split([(res2x.size()[1] // 2), (res2x.size()[1] // 2)], dim=1)
        res2x_1 = self.fusion_2(res2x_1, feature_mem_up)
        res2x_2 = self.conv_2(res2x_2)
        feature_mem_up.append(res2x_1)
        res2x = torch.cat((res2x_1, res2x_2), dim=1)
        print('Dense Feature', res2x.shape)

        res2x = self.convd2x(res2x)
        res2x = F.interpolate(res2x, x.size()[2:], mode='bilinear', align_corners=False)
        x = torch.add(res2x, x)
        x = self.dense_1(x) + x - res2x
        print('Residual Group', x.shape)
        x_1, x_2 = x.split([(x.size()[1] // 2), (x.size()[1] // 2)], dim=1)
        x_1 = self.fusion_1(x_1, feature_mem_up)
        x_2 = self.conv_1(x_2)
        x = torch.cat((x_1, x_2), dim=1)
        print('Dense Feature', x.shape)

        x = self.conv_output(x)
        print('Conv_stride1', x.shape)
        return x


# Feed an all-ones 3 x 256 x 256 image through the network and print every shape
image_example = np.ones(shape=(3, 256, 256))
image = torch.Tensor(image_example).unsqueeze(0)
print('Input:', image.shape)
net = Net()
out = net(image)
print('Output:', out.shape)
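
The spatial sizes printed by the script can be cross-checked by hand. The sketch below applies the standard output-size formulas for an unpadded `nn.Conv2d` placed after reflection padding of `kernel_size // 2` (as in `ConvLayer`) and for `nn.ConvTranspose2d` with no padding (as in `UpsampleConvLayer`). The helper names `conv_out` and `deconv_out` are made up for this example.

```python
def conv_out(size, kernel, stride):
    """Spatial size after ConvLayer: reflection pad kernel//2 per side, then Conv2d."""
    padded = size + 2 * (kernel // 2)
    return (padded - kernel) // stride + 1


def deconv_out(size, kernel, stride):
    """Spatial size after ConvTranspose2d without padding or output_padding."""
    return (size - 1) * stride + kernel


# Encoder: conv_input (k=11, s=1) followed by four stride-2 convs (k=3)
encoder_sizes = [conv_out(256, kernel=11, stride=1)]
for _ in range(4):
    encoder_sizes.append(conv_out(encoder_sizes[-1], kernel=3, stride=2))
print('encoder sizes:', encoder_sizes)

# First decoder step: the transposed conv overshoots the skip feature by one pixel
raw_deconv = deconv_out(encoder_sizes[-1], kernel=3, stride=2)
print('deconv output:', raw_deconv, 'target:', encoder_sizes[-2])
```

The transposed convolution yields 2H + 1 pixels per side rather than 2H, which is presumably why the forward pass resizes every upsampled feature to the skip feature's size with bilinear interpolation before adding them.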