unet 层_UNet解释及Python实现
介紹
在圖像分割中,機器必須將圖像分割成不同的segments,每個segment代表不同的實體。
圖像分割示例
正如你在上面看到的,圖像如何變成兩個部分,一個代表貓,另一個代表背景。圖像分割在從自動駕駛汽車到衛星的許多領域都很有用。也許其中最重要的是醫學影像。醫學圖像的微妙之處是相當復雜的。一臺能夠理解這些細微差別并識別出必要區域的機器,可以對醫療保健產生深遠的影響。
卷積神經網絡在簡單的圖像分割問題上取得了不錯的效果,但在復雜的圖像分割問題上卻沒有取得任何進展。這就是UNet的作用。UNet最初是專門為醫學圖像分割而設計的。該方法取得了良好的效果,并在以后的許多領域得到了應用。在本文中,我們將討論UNet工作的原因和方式
UNet背后的直覺
卷積神經網絡(CNN)背后的主要思想是學習圖像的特征映射,并利用它進行更細致的特征映射。這在分類問題中很有效,因為圖像被轉換成一個向量,這個向量用于進一步的分類。但是在圖像分割中,我們不僅需要將feature map轉換成一個向量,還需要從這個向量重建圖像。這是一項巨大的任務,因為要將向量轉換成圖像比反過來更困難。UNet的整個理念都圍繞著這個問題。
在將圖像轉換為向量的過程中,我們已經學習了圖像的特征映射,為什么不使用相同的映射將其再次轉換為圖像呢?這就是UNet背后的秘訣。用同樣的 feature maps,將其用于contraction 來將矢量擴展成segmented image。這將保持圖像的結構完整性,這將極大地減少失真。讓我們更簡單地理解架構。
UNet架構
UNet架構
該架構看起來像一個'U'。該體系結構由三部分組成:contraction,bottleneck和expansion 部分。contraction部分由許多contraction塊組成。每個塊接受一個輸入,應用兩個3X3的卷積層,然后是一個2X2的最大池化。在每個塊之后,核或特征映射的數量會加倍,這樣體系結構就可以有效地學習復雜的結構。最底層介于contraction層和expansion 層之間。它使用兩個3X3 CNN層,然后是2X2 up convolution層。
這種架構的核心在于expansion 部分。與contraction層類似,它也包含幾個expansion 塊。每個塊將輸入傳遞到兩個3X3 CNN層,然后是2X2上采樣層。此外,卷積層使用的每個塊的feature map數量得到一半,以保持對稱性。每次輸入也被相應的收縮層的 feature maps所附加。這個動作將確保在contracting 圖像時學習到的特征將被用于重建圖像。expansion 塊的數量與contraction塊的數量相同。之后,生成的映射通過另一個3X3 CNN層,feature map的數量等于所需的segment的數量。
UNet中的損失計算
UNet對每個像素使用了一種新穎的損失加權方案,使得分割對象的邊緣具有更高的權重。這種損失加權方案幫助U-Net模型以不連續的方式分割生物醫學圖像中的細胞,以便在binary segmentation map中容易識別單個細胞。
首先,在所得圖像上應用pixel-wise softmax,然后是交叉熵損失函數。所以我們將每個像素分類為一個類。我們的想法是,即使在分割中,每個像素都必須存在于某個類別中,我們只需要確保它們可以。因此,我們只是將分段問題轉換為多類分類問題,與傳統的損失函數相比,它表現得非常好。
UNet實現的Python代碼
Python代碼如下:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
class UNet(nn.Module):
def contracting_block(self, in_channels, out_channels, kernel_size=3):
block = torch.nn.Sequential(
torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels),
torch.nn.ReLU(),
torch.nn.BatchNorm2d(out_channels),
torch.nn.Conv2d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels),
torch.nn.ReLU(),
torch.nn.BatchNorm2d(out_channels),
)
return block
def expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
block = torch.nn.Sequential(
torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),
torch.nn.ReLU(),
torch.nn.BatchNorm2d(mid_channel),
torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),
torch.nn.ReLU(),
torch.nn.BatchNorm2d(mid_channel),
torch.nn.ConvTranspose2d(in_channels=mid_channel, out_channels=out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
)
return block
def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
block = torch.nn.Sequential(
torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),
torch.nn.ReLU(),
torch.nn.BatchNorm2d(mid_channel),
torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),
torch.nn.ReLU(),
torch.nn.BatchNorm2d(mid_channel),
torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=out_channels, padding=1),
torch.nn.ReLU(),
torch.nn.BatchNorm2d(out_channels),
)
return block
def __init__(self, in_channel, out_channel):
super(UNet, self).__init__()
#Encode
self.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=64)
self.conv_maxpool1 = torch.nn.MaxPool2d(kernel_size=2)
self.conv_encode2 = self.contracting_block(64, 128)
self.conv_maxpool2 = torch.nn.MaxPool2d(kernel_size=2)
self.conv_encode3 = self.contracting_block(128, 256)
self.conv_maxpool3 = torch.nn.MaxPool2d(kernel_size=2)
# Bottleneck
self.bottleneck = torch.nn.Sequential(
torch.nn.Conv2d(kernel_size=3, in_channels=256, out_channels=512),
torch.nn.ReLU(),
torch.nn.BatchNorm2d(512),
torch.nn.Conv2d(kernel_size=3, in_channels=512, out_channels=512),
torch.nn.ReLU(),
torch.nn.BatchNorm2d(512),
torch.nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1, output_padding=1)
)
# Decode
self.conv_decode3 = self.expansive_block(512, 256, 128)
self.conv_decode2 = self.expansive_block(256, 128, 64)
self.final_layer = self.final_block(128, 64, out_channel)
def crop_and_concat(self, upsampled, bypass, crop=False):
if crop:
c = (bypass.size()[2] - upsampled.size()[2]) // 2
bypass = F.pad(bypass, (-c, -c, -c, -c))
return torch.cat((upsampled, bypass), 1)
def forward(self, x):
# Encode
encode_block1 = self.conv_encode1(x)
encode_pool1 = self.conv_maxpool1(encode_block1)
encode_block2 = self.conv_encode2(encode_pool1)
encode_pool2 = self.conv_maxpool2(encode_block2)
encode_block3 = self.conv_encode3(encode_pool2)
encode_pool3 = self.conv_maxpool3(encode_block3)
# Bottleneck
bottleneck1 = self.bottleneck(encode_pool3)
# Decode
decode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=True)
cat_layer2 = self.conv_decode3(decode_block3)
decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=True)
cat_layer1 = self.conv_decode2(decode_block2)
decode_block1 = self.crop_and_concat(cat_layer1, encode_block1, crop=True)
final_layer = self.final_layer(decode_block1)
return final_layer
以上Python代碼中的UNet模塊代表了UNet的整體架構。使用contracaction_block和expansive_block分別創建contraction部分和expansion部分。crop_and_concat函數的作用是將contraction層的輸出添加到新的expansion層輸入中。訓練部分的Python代碼可以寫成
unet = Unet(in_channel=1,out_channel=2)
#out_channel represents number of segments desired
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(unet.parameters(), lr = 0.01, momentum=0.99)
optimizer.zero_grad()
outputs = unet(inputs)
# permute such that number of desired segments would be on 4th dimension
outputs = outputs.permute(0, 2, 3, 1)
m = outputs.shape[0]
# Resizing the outputs and label to caculate pixel wise softmax loss
outputs = outputs.resize(m*width_out*height_out, 2)
labels = labels.resize(m*width_out*height_out)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
結論
圖像分割是一個重要的問題,每天都有一些新的研究論文發表。UNet在這類研究中做出了重大貢獻。許多新架構的靈感都來自UNet。在業界,這種體系結構有很多變體,因此有必要理解第一個變體,以便更好地理解它們。
總結
以上是生活随笔為你收集整理的unet 层_UNet解释及Python实现的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: soap响应报文拼装_SOAP学习,构建
- 下一篇: jackson 驼峰注解_jackson