文章目錄
一、文章出發點
每個像素點的類別(label)應該是它所屬目標(object)的類別。
所以這篇文章對像素的上下文信息進行建模。
建模方法:求每個像素點和每個類別的相關性。
二、方法
1、prev_output:
2、feats獲得
3、context獲得
4、self-attention
import torch
import torch
.nn
as nn
import torch
.nn
.functional
as F
from mmcv
.cnn
import ConvModule
from mmseg
.ops
import resize
from ..builder
import HEADS
from ..utils
import SelfAttentionBlock
as _SelfAttentionBlock
from .cascade_decode_head
import BaseCascadeDecodeHead
class SpatialGatherModule(nn.Module):
    """Pool pixel features into one context vector per class.

    The per-class score maps are normalised with a softmax over the
    flattened spatial positions and then used as weights for a weighted sum
    of the pixel features (soft-weighted aggregation).

    Args:
        scale: Factor the scores are multiplied by before the softmax.
    """

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, feats, probs):
        """Forward function.

        Args:
            feats: Feature tensor; it is flattened to
                ``(batch_size, channels, -1)`` where ``channels`` is its
                size along dim 1.
            probs: 4-D score tensor ``(batch_size, num_classes, H, W)``.

        Returns:
            Tensor of shape ``(batch_size, channels, num_classes, 1)``.
        """
        # The 4-way unpack also enforces that ``probs`` is 4-D.
        batch_size, num_classes, _, _ = probs.size()
        channels = feats.size(1)

        # (batch_size, num_classes, N): one row of scores per class.
        class_scores = probs.view(batch_size, num_classes, -1)
        # (batch_size, N, channels): one feature vector per position.
        pixel_feats = feats.view(batch_size, channels, -1).transpose(1, 2)

        # Softmax over positions, so each class's weights sum to 1.
        spatial_weights = F.softmax(self.scale * class_scores, dim=2)

        # (batch_size, num_classes, channels)
        gathered = spatial_weights @ pixel_feats
        # -> (batch_size, channels, num_classes, 1)
        return gathered.transpose(1, 2).contiguous().unsqueeze(3)
class ClassRelationGatherModule(nn.Module):
    """Compute a class-to-class relation map from the class score maps.

    Every class's score map is turned into a softmax distribution over the
    flattened spatial positions; that distribution is then used to take a
    weighted sum of every class's raw scores. The result relates each pair
    of classes and does not involve the pixel features.

    Args:
        scale: Factor the scores are multiplied by before the softmax.
    """

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, feats, probs):
        """Forward function.

        Args:
            feats: Unused. Kept so the call signature stays the same as
                ``SpatialGatherModule.forward``.
            probs: 4-D score tensor ``(batch_size, num_classes, H, W)``.

        Returns:
            Tensor of shape ``(batch_size, num_classes, num_classes, 1)``.
            Entry ``[b, j, k, 0]`` is the sum over positions of class
            ``k``'s softmax weight times class ``j``'s raw score.
        """
        # The 4-way unpack also enforces that ``probs`` is 4-D.
        batch_size, num_classes, _, _ = probs.size()

        # (batch_size, num_classes, N): a single flattened view serves both
        # as the softmax input and as the raw scores.
        flat_scores = probs.view(batch_size, num_classes, -1)
        spatial_weights = F.softmax(self.scale * flat_scores, dim=2)

        # (B, K, N) x (B, N, K) -> (B, K, K)
        class_gather = torch.matmul(spatial_weights,
                                    flat_scores.permute(0, 2, 1))
        class_gather = class_gather.permute(0, 2, 1).contiguous().unsqueeze(3)
        return class_gather
class ObjectAttentionBlock(_SelfAttentionBlock):
    """Make a OCR used SelfAttentionBlock.

    Wraps ``_SelfAttentionBlock`` with a fixed configuration and adds a 1x1
    ``ConvModule`` that fuses the attention output with the query features.

    Args:
        in_channels: Channels of the query/key inputs and of the output.
        channels: Passed to the parent block as ``channels``.
        scale: When greater than 1, the query is down-sampled with
            ``nn.MaxPool2d(kernel_size=scale)``; otherwise there is no
            down-sampling.
        conv_cfg: Forwarded to the parent block and to the fusion conv.
        norm_cfg: Forwarded to the parent block and to the fusion conv.
        act_cfg: Forwarded to the parent block and to the fusion conv.
    """

    def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg,
                 act_cfg):
        if scale > 1:
            query_downsample = nn.MaxPool2d(kernel_size=scale)
        else:
            query_downsample = None
        super(ObjectAttentionBlock, self).__init__(
            key_in_channels=in_channels,
            query_in_channels=in_channels,
            channels=channels,
            out_channels=in_channels,
            share_key_query=False,
            query_downsample=query_downsample,
            key_downsample=None,
            key_query_num_convs=2,
            key_query_norm=True,
            value_out_num_convs=1,
            value_out_norm=True,
            matmul_norm=True,
            with_out=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # 1x1 conv that maps cat([context, query]) (2 * in_channels) back to
        # in_channels.
        self.bottleneck = ConvModule(
            in_channels * 2,
            in_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, query_feats, key_feats):
        """Forward function.

        Args:
            query_feats: Query feature map (pixel features).
            key_feats: Key feature map (object/context features).

        Returns:
            The fused output of the 1x1 bottleneck over the concatenation of
            the attention context and ``query_feats``.
        """
        context = super(ObjectAttentionBlock,
                        self).forward(query_feats, key_feats)
        if self.query_downsample is not None:
            # The previous code overwrote the fused output with
            # ``resize(query_feats)``, which discarded the attention result
            # and passed no target size. Instead, bring the context back to
            # the query resolution so it can be concatenated and fused.
            # NOTE(review): presumably the parent returns ``context`` at the
            # pooled query resolution, and ``resize`` accepts
            # size/mode/align_corners like F.interpolate - TODO confirm.
            context = resize(
                context,
                size=query_feats.shape[2:],
                mode='bilinear',
                align_corners=False)
        output = self.bottleneck(torch.cat([context, query_feats], dim=1))
        return output


@HEADS.register_module()
class OCRHead(BaseCascadeDecodeHead):
    """Object-Contextual Representations for Semantic Segmentation.

    This head is the implementation of `OCRNet
    <https://arxiv.org/abs/1909.11065>`_.

    Args:
        ocr_channels (int): The intermediate channels of OCR block.
        scale (int): The scale of probability map in SpatialGatherModule in
            Default: 1.
    """

    def __init__(self, ocr_channels, scale=1, **kwargs):
        super(OCRHead, self).__init__(**kwargs)
        self.ocr_channels = ocr_channels
        self.scale = scale
        self.object_context_block = ObjectAttentionBlock(
            self.channels,
            self.ocr_channels,
            self.scale,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.spatial_gather_module = SpatialGatherModule(self.scale)
        # Constructed here but not called in ``forward`` below.
        self.class_relation_gather_module = ClassRelationGatherModule(
            self.scale)

        # 3x3 conv mapping the selected input features to ``self.channels``.
        self.bottleneck = ConvModule(
            self.in_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs, prev_output):
        """Forward function.

        Args:
            inputs: Inputs handed to ``self._transform_inputs``.
            prev_output: Output of the previous stage; used as the class
                score maps for the spatial gather step.

        Returns:
            Result of ``self.cls_seg`` applied to the object-contextual
            features.
        """
        x = self._transform_inputs(inputs)
        feats = self.bottleneck(x)
        # Per-class context vectors gathered from the pixel features.
        context = self.spatial_gather_module(feats, prev_output)
        # Pixels attend to the class contexts and are fused with them.
        object_context = self.object_context_block(feats, context)
        output = self.cls_seg(object_context)

        return output
三、效果
經過 OCR 頭後的效果對比如下圖,每個類別的響應比較全面且穩定。
Cityscapes 類別和通道的對應:
總結
以上是生活随笔為你收集整理的【语义分割】OCRNet:Object-Context Representations for Semantic Segmentation的全部內容,希望文章能夠幫你解決所遇到的問題。
如果覺得生活随笔網站內容還不錯,歡迎將生活随笔推薦給好友。