Involution code
Original source on GitHub: https://github.com/d-li14/involution/blob/21c3158fcbb4ecda8ed4626fcae8b01be511a598/cls/mmcls/models/utils/involution_naive.py#L5
A detailed walkthrough will follow in a later post.
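As a reading aid for the class below, here is a summary of what it computes. It is written from the code itself, not taken from the repository, and the notation is introduced here for illustration: indices are 0-based, the batch index is dropped, $K$ = `kernel_size`, $s$ = `stride`, $C$ = `channels`, $G$ = `groups` = $C$ / `group_channels`, and $\tilde{\mathbf{X}}$ is the input zero-padded by $(K-1)/2$ on each side (as `nn.Unfold` does). Each output pixel gets its own $K \times K$ kernel per channel group, and all channels in a group share it.

```latex
\[
\mathbf{H} = \mathrm{conv}_2\big(\mathrm{ReLU}(\mathrm{BN}(\mathrm{conv}_1(\mathbf{X}')))\big),
\qquad
\mathbf{X}' =
\begin{cases}
\mathbf{X}, & s = 1 \\
\mathrm{AvgPool}_s(\mathbf{X}), & s > 1
\end{cases}
\]
\[
\mathbf{Y}_{c,i,j} = \sum_{a=0}^{K-1} \sum_{b=0}^{K-1}
\mathbf{H}_{\,g(c)K^2 + aK + b,\; i,\; j}\;
\tilde{\mathbf{X}}_{c,\; is + a,\; js + b},
\qquad
g(c) = \left\lfloor \frac{c}{C/G} \right\rfloor
\]
```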
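import torch.nn as nn
from mmcv.cnn import ConvModule


class involution(nn.Module):

    def __init__(self, channels, kernel_size, stride):
        super(involution, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.channels = channels
        reduction_ratio = 4
        self.group_channels = 16  # channels per group; a group shares one kernel per pixel
        self.groups = self.channels // self.group_channels
        # Kernel generation, step 1: 1x1 conv that reduces channels, followed by BN + ReLU
        self.conv1 = ConvModule(
            in_channels=channels,
            out_channels=channels // reduction_ratio,
            kernel_size=1,
            conv_cfg=None,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU'))
        # Kernel generation, step 2: 1x1 conv emitting K*K weights for every group at every pixel
        self.conv2 = ConvModule(
            in_channels=channels // reduction_ratio,
            out_channels=kernel_size**2 * self.groups,
            kernel_size=1,
            stride=1,
            conv_cfg=None,
            norm_cfg=None,
            act_cfg=None)
        if stride > 1:
            # Downsample before kernel generation so the kernels match the output resolution
            self.avgpool = nn.AvgPool2d(stride, stride)
        # Extract K x K neighbourhoods: Unfold(kernel_size, dilation=1, padding, stride)
        self.unfold = nn.Unfold(kernel_size, 1, (kernel_size - 1) // 2, stride)

    def forward(self, x):
        # weight: (B, G*K*K, H', W'), one K*K kernel per group per output pixel
        weight = self.conv2(self.conv1(x if self.stride == 1 else self.avgpool(x)))
        b, c, h, w = weight.shape
        weight = weight.view(b, self.groups, self.kernel_size**2, h, w).unsqueeze(2)
        # unfold(x): (B, C*K*K, H'*W') reshaped to (B, G, C/G, K*K, H', W')
        out = self.unfold(x).view(b, self.groups, self.group_channels, self.kernel_size**2, h, w)
        # Weighted sum over the K*K neighbourhood; the kernel broadcasts across a group's channels
        out = (weight * out).sum(dim=3).view(b, self.channels, h, w)
        return out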
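To see what the `view` / `unsqueeze` / `sum` sequence in `forward` actually computes, the sketch below re-implements only that step in NumPy and checks it against a plain per-pixel loop. This is an illustration under stated assumptions, not the repository's code: NumPy is swapped in for PyTorch so that neither torch nor mmcv is needed, the helper `unfold` is a hypothetical stand-in that mimics `nn.Unfold`'s `(B, C*K*K, L)` layout (channel-major, then kernel row, then kernel column), the names `involution_apply` and `involution_naive` are hypothetical, and the kernel tensor is random rather than produced by `conv1`/`conv2`.

```python
import numpy as np


def unfold(x, k, pad, stride):
    """Hypothetical NumPy stand-in for torch.nn.Unfold(k, 1, pad, stride).

    Returns (B, C*k*k, L) with rows ordered channel-major, then kernel row,
    then kernel column, matching the layout PyTorch's Unfold produces.
    """
    b, c, h, w = x.shape
    xp = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)))  # zero padding
    oh = (h + 2 * pad - k) // stride + 1
    ow = (w + 2 * pad - k) // stride + 1
    cols = np.empty((b, c, k, k, oh, ow), dtype=x.dtype)
    for ki in range(k):
        for kj in range(k):
            # The (ki, kj) neighbour of every output pixel, for all channels at once
            cols[:, :, ki, kj] = xp[:, :, ki:ki + stride * oh:stride, kj:kj + stride * ow:stride]
    return cols.reshape(b, c * k * k, oh * ow), oh, ow


def involution_apply(x, weight, k, stride, group_channels):
    """Mirrors the reshape/broadcast/sum steps of involution.forward.

    weight: (B, G*k*k, H', W'), one k*k kernel per channel group per output pixel.
    """
    b, c = x.shape[:2]
    groups = c // group_channels
    cols, oh, ow = unfold(x, k, (k - 1) // 2, stride)
    w = weight.reshape(b, groups, k * k, oh, ow)[:, :, None]      # (B, G, 1, k*k, H', W')
    out = cols.reshape(b, groups, group_channels, k * k, oh, ow)  # (B, G, C/G, k*k, H', W')
    return (w * out).sum(axis=3).reshape(b, c, oh, ow)


def involution_naive(x, weight, k, stride, group_channels):
    """Direct per-pixel evaluation of the same operation, for comparison."""
    b, c, h, w = x.shape
    pad = (k - 1) // 2
    xp = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)))
    oh = (h + 2 * pad - k) // stride + 1
    ow = (w + 2 * pad - k) // stride + 1
    y = np.zeros((b, c, oh, ow), dtype=x.dtype)
    for n in range(b):
        for ch in range(c):
            g = ch // group_channels  # all channels in a group share one kernel
            for i in range(oh):
                for j in range(ow):
                    kern = weight[n, g * k * k:(g + 1) * k * k, i, j].reshape(k, k)
                    patch = xp[n, ch, i * stride:i * stride + k, j * stride:j * stride + k]
                    y[n, ch, i, j] = np.sum(kern * patch)
    return y


# Random stand-ins: in the real module, weight comes from conv2(conv1(...)).
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 32, 8, 8))          # C=32, group_channels=16 -> G=2
weight = rng.standard_normal((2, 2 * 9, 4, 4))  # K=3, stride=2 -> 4x4 output
y_fast = involution_apply(x, weight, k=3, stride=2, group_channels=16)
y_slow = involution_naive(x, weight, k=3, stride=2, group_channels=16)
print(y_fast.shape)                 # (2, 32, 4, 4)
print(np.allclose(y_fast, y_slow))  # True
```

The unfold route trades memory for speed: it materializes a `(B, C*K*K, H'*W')` tensor so the whole weighting becomes one broadcasted multiply and a sum, while the loop version only spells out the semantics. Working through the size formulas also suggests a constraint on the original module when `stride > 1`: for odd `K`, `nn.Unfold` yields `ceil(H / stride)` rows but `nn.AvgPool2d(stride, stride)` yields `floor(H / stride)`, so the `view` calls only line up when the input height and width are divisible by the stride.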