[Knowledge Distillation] ICCV21: Channel-wise Knowledge Distillation for Dense Prediction
Table of Contents
- 1. Background
- 2. Motivation
- 3. Method
- 3.1 Review of Spatial Distillation
- 3.2 Channel-wise Distillation
- 4. Results
- 5. Training and Testing
- 6. Code Walkthrough
Paper: https://arxiv.org/pdf/2011.13256.pdf
Code: https://github.com/irfanICMLL/TorchDistiller
MMDetection: https://github.com/pppppM/mmdetection-distiller
MMSegmentation: https://github.com/pppppM/mmsegmentation-distiller
1. Background

Dense prediction tasks such as semantic segmentation and object detection are fundamental to computer vision and rely on learning good feature representations. The best-performing methods demand large amounts of computation, which makes them hard to deploy on mobile devices.

Distillation brings clear gains on classification [16, 2], but it cannot be transferred directly to semantic segmentation: strictly aligning the per-pixel classification outputs makes the student over-fit the teacher's output and yields sub-optimal results.

Some methods [25, 24, 18] therefore focus on strengthening the relations between different spatial locations, as in Figure 2a:
- First, the feature map at every spatial location is normalized.
- Then, task-specific relations are modeled by aggregating subsets of spatial locations, e.g., pair-wise relations [25, 35] and inter-class relations [18].
2. Motivation
- Spatial distillation: normalize the values at the same position across all channels, then let the student network learn the normalized distribution; this can be understood as distilling over categories.
- Channel distillation: normalize within each single channel, then let the student network learn the normalized distribution; this can be understood as distilling over locations (see the snippet below).
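A minimal sketch of the two normalizations on an activation map of shape [C, H, W] (illustrative only; the tensor here is a random stand-in):

```python
import torch
import torch.nn.functional as F

y = torch.randn(19, 128, 256)       # [C, H, W] activation map

# Spatial distillation: softmax across channels at every spatial position,
# i.e. a per-pixel distribution over the C categories.
spatial_prob = F.softmax(y, dim=0)  # still [C, H, W]

# Channel-wise distillation: softmax across all H*W positions within a channel,
# i.e. a per-channel distribution over spatial locations.
channel_prob = F.softmax(y.flatten(1), dim=1).view_as(y)
```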
Although these methods work better than point-by-point comparison, every spatial location in the feature map contributes equally to the knowledge transfer, which may carry redundant information over from the teacher.

Some methods also distill along the channel dimension: [50] aggregates each channel's activation into a single vector, which suits image-level classification but not dense prediction, where spatial information is needed.

This paper instead normalizes each channel's feature map to obtain a soft probability map, as in Figure 2b, and then minimizes the asymmetric Kullback-Leibler (KL) divergence between the channel-wise probability maps of the two networks, i.e., between the per-channel distributions of teacher and student. As the example in Figure 2c shows, each channel's activation map tends to focus on that channel's salient regions, which are the salient regions of each class, and these are exactly the regions that matter for dense prediction.
- On COCO, RetinaNet (ResNet-50) gains 3.4% mAP.
- On Cityscapes, PSPNet gains 5.81% mIoU.
3. Method
*The activation values in this work include both the final logits and the inner feature maps.
3.1 Review of Spatial Distillation
The usual distillation methods align activations point-wise; a typical form is sketched below.
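As a rough sketch (the exact formulations differ across [25, 24, 18]; this point-wise MSE form is an assumption for illustration):

$$
\ell\left(y^{T}, y^{S}\right)=\frac{1}{W \cdot H} \sum_{i=1}^{W \cdot H}\left(y_{i}^{T}-y_{i}^{S}\right)^{2}
$$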
3.2 Channel-wise Distillation
To make better use of the knowledge within each channel, the authors softly align the corresponding channel activations of the teacher and student networks.
- First, each channel's activation is converted into a probability distribution, so that the difference can be measured with a probability metric such as the KL divergence. As Figure 2c shows, each channel's activation tends to encode the salient features of one class.
- Then, the trained teacher model provides clear category-specific masks, as shown on the right of Figure 1, from which the student network learns.
The channel-wise distillation loss is as follows:
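The formulas below are reconstructed from the paper together with the implementation in Section 6 (a per-channel softmax over the $W \cdot H$ positions, then a KL divergence scaled by $\mathcal{T}^2$ and averaged over channels):

$$
\phi\left(y_{c, i}\right)=\frac{\exp \left(\frac{y_{c, i}}{\mathcal{T}}\right)}{\sum_{i=1}^{W \cdot H} \exp \left(\frac{y_{c, i}}{\mathcal{T}}\right)}, \qquad
\varphi\left(y^{T}, y^{S}\right)=\frac{\mathcal{T}^{2}}{C} \sum_{c=1}^{C} \sum_{i=1}^{W \cdot H} \phi\left(y_{c, i}^{T}\right) \log \left[\frac{\phi\left(y_{c, i}^{T}\right)}{\phi\left(y_{c, i}^{S}\right)}\right]
$$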
- $y^T$: the teacher's activation map
- $y^S$: the student's activation map
- $\phi$: converts the activation values into a probability distribution via the softmax normalization above; this removes the difference in magnitude between the large and the small network
- $c = 1, 2, \ldots, C$: the channel index
- $i$: the pixel position within a channel
- $\mathcal{T}$: the temperature, a hyper-parameter; the larger $\mathcal{T}$ is, the softer the output distribution, i.e., the larger the spatial region each channel attends to
- Mismatched channel counts between teacher and student are resolved by a 1x1 convolution that upsamples the student's channel count
- $\varphi$: measures the difference between the per-channel distributions of teacher and student; this paper uses the KL divergence
  - the KL divergence is an asymmetric measure
  - where $\phi(y_{c,i}^T)$ is large, $\phi(y_{c,i}^S)$ must also be large to minimize the KL divergence
  - where $\phi(y_{c,i}^T)$ is small, the KL divergence does not keep pushing $\phi(y_{c,i}^S)$ smaller
  - as a result, the student tends to mimic the teacher's distribution at the salient foreground positions, while the background regions of the teacher's distribution have little influence on learning (a toy numeric check follows this list)
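A toy numeric illustration of this asymmetry (not from the paper; the probabilities are made up):

```python
import torch

# Per-position KL contribution: p_T * (log p_T - log p_S).
p_T = torch.tensor([0.90, 0.09, 0.01])  # teacher: one salient position
p_S = torch.tensor([0.10, 0.45, 0.45])  # student: mass elsewhere

contrib = p_T * (p_T.log() - p_S.log())
print(contrib)        # the first (salient) position dominates the loss
print(contrib.sum())  # positions where p_T is tiny contribute almost nothing
```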
4. Results

Hyper-parameters used:
- temperature: $\mathcal{T} = 4$
- logits map: $\alpha = 3$
- feature map: $\alpha = 50$

Ablation studies:
5. Training and Testing

Using the mmsegmentation training code as an example.
1. Install mmsegmentation.
2. Symlink the data into the repo (typically something like `ln -s /path/to/cityscapes data/cityscapes`; the exact paths depend on your setup).
3. Download the trained teacher model pspnet_r101 and place it under pretrained_model (see the model download link).
4. Train and test:
```bash
# Single-GPU training
python tools/train.py configs/distiller/cwd/cwd_psp_r101-d8_distill_psp_r18_d8_512_1024_80k_cityscapes.py

# Train the teacher network
python tools/train.py configs/ocrnet/ocrnet_hr48_512x1024_80k_cityscapes.py

# Multi-GPU training
bash tools/dist_train.sh configs/distillers/cwd/cwd_psp_r101-d8_distill_psp_r18_d8_512_1024_80k_cityscapes.py 8

# Single-GPU testing
python tools/test.py configs/distillers/cwd/cwd_psp_r101-d8_distill_psp_r18_d8_512_1024_80k_cityscapes.py $CHECKPOINT --eval mIoU

# Multi-GPU testing
bash tools/dist_test.sh configs/distillers/cwd/cwd_psp_r101-d8_distill_psp_r18_d8_512_1024_80k_cityscapes.py $CHECKPOINT 8 --eval mIoU
```

5. Understand the config
config/distiller/cwd/cwd_psp_r101-d8_distill_psp_d8_512_1024_80k_cityscapes.py:

```python
_base_ = [
    '../../_base_/datasets/cityscapes.py',
    '../../_base_/default_runtime.py',
    '../../_base_/schedules/schedule_80k.py'
]
find_unused_parameters = True
weight = 5.0
tau = 1.0
distiller = dict(
    type='SegmentationDistiller',
    teacher_pretrained='pretrained_model/pspnet_r101b-d8_512x1024_80k_cityscapes_20201226_170012-3a4d38ab.pth',
    distill_cfg=[
        dict(student_module='decode_head.conv_seg',
             teacher_module='decode_head.conv_seg',
             output_hook=True,
             methods=[dict(type='ChannelWiseDivergence',
                           name='loss_cwd',
                           student_channels=19,
                           teacher_channels=19,
                           tau=tau,
                           weight=weight)])
    ])
student_cfg = 'configs/pspnet/pspnet_r18-d8_512x1024_80k_cityscapes.py'
teacher_cfg = 'configs/pspnet/pspnet_r101-d8_512x1024_80k_cityscapes.py'
```

- Teacher network decode_head.conv_seg:
- Student network decode_head.conv_seg:
6. The PSP teacher network's decode head structure:
```
(decode_head): PSPHead(
  input_transform=None, ignore_index=255, align_corners=False
  (loss_decode): CrossEntropyLoss()
  (conv_seg): Conv2d(512, 19, kernel_size=(1, 1), stride=(1, 1))
  (dropout): Dropout2d(p=0.1, inplace=False)
  (psp_modules): PPM(
    (0): Sequential(
      (0): AdaptiveAvgPool2d(output_size=1)
      (1): ConvModule(
        (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
    (1): Sequential(
      (0): AdaptiveAvgPool2d(output_size=2)
      (1): ConvModule(
        (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
    (2): Sequential(
      (0): AdaptiveAvgPool2d(output_size=3)
      (1): ConvModule(
        (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
    (3): Sequential(
      (0): AdaptiveAvgPool2d(output_size=6)
      (1): ConvModule(
        (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
  )
  (bottleneck): ConvModule(
    (conv): Conv2d(4096, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activate): ReLU(inplace=True)
  )
)
(auxiliary_head): FCNHead(
  input_transform=None, ignore_index=255, align_corners=False
  (loss_decode): CrossEntropyLoss()
  (conv_seg): Conv2d(256, 19, kernel_size=(1, 1), stride=(1, 1))
  (dropout): Dropout2d(p=0.1, inplace=False)
  (convs): Sequential(
    (0): ConvModule(
      (conv): Conv2d(1024, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
  )
)
```

7. The PSP student network's decode head structure:
```
(decode_head): PSPHead(
  input_transform=None, ignore_index=255, align_corners=False
  (loss_decode): CrossEntropyLoss()
  (conv_seg): Conv2d(128, 19, kernel_size=(1, 1), stride=(1, 1))
  (dropout): Dropout2d(p=0.1, inplace=False)
  (psp_modules): PPM(
    (0): Sequential(
      (0): AdaptiveAvgPool2d(output_size=1)
      (1): ConvModule(
        (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
    (1): Sequential(
      (0): AdaptiveAvgPool2d(output_size=2)
      (1): ConvModule(
        (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
    (2): Sequential(
      (0): AdaptiveAvgPool2d(output_size=3)
      (1): ConvModule(
        (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
    (3): Sequential(
      (0): AdaptiveAvgPool2d(output_size=6)
      (1): ConvModule(
        (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
  )
  (bottleneck): ConvModule(
    (conv): Conv2d(1024, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activate): ReLU(inplace=True)
  )
)
(auxiliary_head): FCNHead(
  input_transform=None, ignore_index=255, align_corners=False
  (loss_decode): CrossEntropyLoss()
  (conv_seg): Conv2d(64, 19, kernel_size=(1, 1), stride=(1, 1))
  (dropout): Dropout2d(p=0.1, inplace=False)
  (convs): Sequential(
    (0): ConvModule(
      (conv): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
  )
)
(distill_losses): ModuleDict(
  (loss_cwd): ChannelWiseDivergence()
)
```

Here decode_head.conv_seg is in fact the output of the last layer, i.e., the final result produced by the PSP head; each channel is the prediction for one class.
8. How to adapt the distillation to another architecture

Take OCRNet as an example. For PSP, the network's decode_head.conv_seg served as the distillation input; here we first inspect the OCR network's decode_head structure, then likewise take the output of the last head, i.e., the last head's conv_seg, as the distillation input. hr48 serves as the teacher network and hr18s as the student:

Teacher network decode_head:
```
ModuleList(
  (0): FCNHead(
    input_transform=resize_concat, ignore_index=255, align_corners=False
    (loss_decode): CrossEntropyLoss()
    (conv_seg): Conv2d(270, 19, kernel_size=(1, 1), stride=(1, 1))
    (convs): Sequential(
      (0): ConvModule(
        (conv): Conv2d(270, 270, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(270, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
  )
  (1): OCRHead(
    input_transform=resize_concat, ignore_index=255, align_corners=False
    (loss_decode): CrossEntropyLoss()
    (conv_seg): Conv2d(512, 19, kernel_size=(1, 1), stride=(1, 1))
    (object_context_block): ObjectAttentionBlock(
      (key_project): Sequential(
        (0): ConvModule(
          (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (1): ConvModule(
          (conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
      )
      (query_project): Sequential(
        (0): ConvModule(
          (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (1): ConvModule(
          (conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
      )
      (value_project): ConvModule(
        (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
      (out_project): ConvModule(
        (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
      (bottleneck): ConvModule(
        (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
    (spatial_gather_module): SpatialGatherModule()
    (bottleneck): ConvModule(
      (conv): Conv2d(270, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
  )
)
```

Based on this, the distillation inputs for the OCR network are:
- Teacher network
- Student network

So only the config needs to change. The teacher model is downloaded from the official mmsegmentation repo; the final config is:
```python
_base_ = [
    '../../_base_/datasets/cityscapes.py',
    '../../_base_/default_runtime.py',
    '../../_base_/schedules/schedule_80k.py'
]
find_unused_parameters = True
weight = 5.0
tau = 1.0
distiller = dict(
    type='SegmentationDistiller',
    teacher_pretrained='pretrained_model/ocrnet_hr48_512x1024_160k_cityscapes.pth',
    distill_cfg=[
        dict(student_module='decode_head.1.conv_seg',
             teacher_module='decode_head.1.conv_seg',
             output_hook=True,
             methods=[dict(type='ChannelWiseDivergence',
                           name='loss_cwd',
                           student_channels=19,
                           teacher_channels=19,
                           tau=tau,
                           weight=weight)])
    ])
student_cfg = 'configs/ocrnet/ocrnet_hr18s_512x1024_80k_cityscapes.py'
teacher_cfg = 'configs/ocrnet/ocrnet_hr48_512x1024_160k_cityscapes.py'
```

Training command:
```bash
python tools/train.py configs/distiller/cwd/cwd_ocr_hr48-d8_distill_ocr_hr18s-d8_512_1024_80k_cityscapes.py
```

Recorded training results:
Cityscapes val, 512x1024, 80k iterations:

| Teacher (size) | mIoU | Student (size) | mIoU | Distilled mIoU |
| --- | --- | --- | --- | --- |
| psp_r101 (272.4M) | 79.74 | psp_r18 (51.2M) | 74.86 | |
| ocr_hr48 (282.2M) | 81.35 | ocr_hr18s (25.8M) | 79.68 | 77.29 |
6. Code Walkthrough

Without a distiller config, training proceeds in the normal way. The distiller config looks like this:
```python
distiller_cfg = cfg.get('distiller', None)

# $ p distiller_cfg
# {'type': 'SegmentationDistiller',
#  'teacher_pretrained': 'pretrained_model/ocrnet_hr48_512x1024_160k_cityscapes.pth',
#  'distill_cfg': [{'student_module': 'decode_head.1.conv_seg',
#                   'teacher_module': 'decode_head.1.conv_seg',
#                   'output_hook': True,
#                   'methods': [{'type': 'ChannelWiseDivergence', 'name': 'loss_cwd',
#                                'student_channels': 19, 'teacher_channels': 19,
#                                'tau': 1.0, 'weight': 5.0}]}]}
```

Config.fromfile() pulls the contents out of a config file:
```python
teacher_cfg = Config.fromfile(cfg.teacher_cfg)
student_cfg = Config.fromfile(cfg.student_cfg)
```

Training uses the student model's train_cfg and test_cfg:
```python
# tools/train.py, line 137
model = build_distiller(cfg.distiller, teacher_cfg, student_cfg,
                        train_cfg=student_cfg.get('train_cfg'),
                        test_cfg=student_cfg.get('test_cfg'))
```

One way distillation training differs from ordinary training: the optimizer updates different parameters. With distillation, only the student's parameters and the distillation-loss parameters are trained.
```python
# mmseg/apis/train.py, line 72
# build runner
distiller_cfg = cfg.get('distiller', None)
if distiller_cfg is None:
    optimizer = build_optimizer(model, cfg.optimizer)
else:
    # base_parameters() is defined in segmentation_distiller.py, line 69;
    # it covers the student and the distillation losses
    optimizer = build_optimizer(model.module.base_parameters(), cfg.optimizer)
```

The parameters that are excluded from training can be inspected like this:
```python
# parameters that PyTorch trains
model.named_parameters()
# parameters that are not trained
model.named_buffers()
```

PyTorch's register_buffer() is what makes a tensor not participate in training.
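A minimal standalone sketch of register_buffer (the module here is made up for illustration):

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)                     # trainable weights
        # tracked and moved with the model (gpu/cpu), but never updated by the optimizer
        self.register_buffer('scale', torch.ones(2))

m = Demo()
print([k for k, _ in m.named_parameters()])  # ['fc.weight', 'fc.bias']
print([k for k, _ in m.named_buffers()])     # ['scale']
```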
In the distiller, the registered buffers look like this:

```python
# register_buffer(name, tensor): name is the key, tensor is the registered
# tensor that does not participate in training
buffer_key = [k for k, v in self.named_buffers()]
# ['student_decode_head_1_conv_seg', 'teacher_decode_head_1_conv_seg',
#  'teacher.backbone.bn1.running_mean', 'teacher.backbone.bn1.running_var',
#  'teacher.backbone.bn1.num_batches_tracked', 'teacher.backbone.bn2.running_mean',
#  'teacher.backbone.bn2.running_var', 'teacher.backbone.bn2.num_batches_tracked', ...
```

Distillation training proceeds in two steps: first compute the losses of the layers not involved in distillation, then compute the loss of the distilled layers.
```python
# mmseg/distillation/distillers/segmentation_distiller.py
def forward_train(self, img, img_metas, gt_semantic_seg):
    with torch.no_grad():
        self.teacher.eval()
        # mmseg/models/segmentors/encoder_decoder.py(136) forward_train()
        teacher_loss = self.teacher.forward_train(img, img_metas, gt_semantic_seg)
    student_loss = self.student.forward_train(img, img_metas, gt_semantic_seg)
    # overall loss so far:
    # {'decode_0.loss_seg': tensor(1.1701, device='cuda:0', grad_fn=<MulBackward0>),
    #  'decode_0.acc_seg': tensor([3.7306], device='cuda:0'),
    #  'decode_1.loss_seg': tensor(2.9231, device='cuda:0', grad_fn=<MulBackward0>),
    #  'decode_1.acc_seg': tensor([6.0701], device='cuda:0')}

    # named_buffers() lists the parameters that are not updated;
    # parameters() lists the ones that are
    buffer_dict = dict(self.named_buffers())
    for item_loc in self.distill_cfg:
        student_module = 'student_' + item_loc.student_module.replace('.', '_')  # 'student_decode_head_1_conv_seg'
        teacher_module = 'teacher_' + item_loc.teacher_module.replace('.', '_')  # 'teacher_decode_head_1_conv_seg'
        # the two key lines: fetch the hooked teacher and student outputs that
        # feed the distillation loss, as shown in the figure below
        student_feat = buffer_dict[student_module]  # [b, 19, 128, 256]
        teacher_feat = buffer_dict[teacher_module]  # [b, 19, 128, 256]
        # item_loc.methods: [{'type': 'ChannelWiseDivergence', 'name': 'loss_cwd',
        #                     'student_channels': 19, 'teacher_channels': 19,
        #                     'tau': 1.0, 'weight': 5.0}]
        for item_loss in item_loc.methods:
            loss_name = item_loss.name  # 'loss_cwd'
            student_loss[loss_name] = self.distill_losses[loss_name](student_feat, teacher_feat)
    # loss with the distillation term added:
    # {'decode_0.loss_seg': tensor(1.1701, ...), 'decode_0.acc_seg': tensor([3.7306], ...),
    #  'decode_1.loss_seg': tensor(2.9231, ...), 'decode_1.acc_seg': tensor([6.0701], ...),
    #  'loss_cwd': tensor(51.9439, device='cuda:0', grad_fn=<DivBackward0>)}
    return student_loss
```

The two sets of features below are visualized in the figures that follow; the student's feature map is from the first iteration, so it has not yet learned anything.
```python
student_feat = buffer_dict[student_module]  # [b, 19, 128, 256]
teacher_feat = buffer_dict[teacher_module]  # [b, 19, 128, 256]
```

teacher_feat:

student_feat:
Where do these two features come from? Hooks capture the outputs of the two layers: every time the SegmentationDistiller class is instantiated, its __init__ runs the hook-registration pass, so after every iteration the features land in the buffers.

Hooks come in two kinds:
- register_forward_hook(hook)
- register_backward_hook(hook)

A hook retrieves the intermediate results of certain variables: PyTorch automatically discards the intermediate values of the computation graph, so capturing them requires a hook function. A hook should be removed promptly after use, so it does not add overhead on every run.
```python
# a helper that registers the hooks
def regitster_hooks(student_module, teacher_module):
    def hook_teacher_forward(module, input, output):
        # input and output are this layer's input and output;
        # tensors registered via register_buffer() become part of the model:
        # they move with it (gpu/cpu) but are not updated by gradients
        self.register_buffer(teacher_module, output)
    def hook_student_forward(module, input, output):
        self.register_buffer(student_module, output)
    return hook_teacher_forward, hook_student_forward

for item_loc in distill_cfg:
    student_module = 'student_' + item_loc.student_module.replace('.', '_')  # 'student_decode_head_1_conv_seg'
    teacher_module = 'teacher_' + item_loc.teacher_module.replace('.', '_')  # 'teacher_decode_head_1_conv_seg'
    # registration happens here
    hook_teacher_forward, hook_student_forward = regitster_hooks(student_module, teacher_module)
    teacher_modules[item_loc.teacher_module].register_forward_hook(hook_teacher_forward)
    student_modules[item_loc.student_module].register_forward_hook(hook_student_forward)
```

register_forward_hook(hook) works as follows: for the layer you want (say conv2), the hook receives the module, the layer's input, and the layer's output, so the output can be captured.
Its biggest use is, once a model is trained, showing how a particular layer influences the final result.
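A minimal standalone sketch of register_forward_hook (a made-up model, not the repo's code):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
captured = {}

def save_output(module, inputs, output):
    captured['conv_out'] = output.detach()  # keep the intermediate activation

handle = model[0].register_forward_hook(save_output)
model(torch.randn(1, 3, 16, 16))
print(captured['conv_out'].shape)  # torch.Size([1, 8, 16, 16])
handle.remove()  # remove the hook once it is no longer needed
```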
The loss is computed as follows:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import weight_reduce_loss
from ..builder import DISTILL_LOSSES


@DISTILL_LOSSES.register_module()
class ChannelWiseDivergence(nn.Module):
    """PyTorch version of `Channel-wise Distillation for Semantic Segmentation
    <https://arxiv.org/abs/2011.13256>`_.

    Args:
        student_channels (int): Number of channels in the student's feature map.
        teacher_channels (int): Number of channels in the teacher's feature map.
        name (str):
        tau (float, optional): Temperature coefficient. Defaults to 1.0.
        weight (float, optional): Weight of loss. Defaults to 1.0.
    """

    def __init__(self, student_channels, teacher_channels, name, tau=1.0, weight=1.0):
        super(ChannelWiseDivergence, self).__init__()
        self.tau = tau
        self.loss_weight = weight
        if student_channels != teacher_channels:
            self.align = nn.Conv2d(student_channels, teacher_channels,
                                   kernel_size=1, stride=1, padding=0)
        else:
            self.align = None

    def forward(self, preds_S, preds_T):
        """Forward function."""
        assert preds_S.shape[-2:] == preds_T.shape[-2:], \
            'the output dim of teacher and student differ'
        N, C, W, H = preds_S.shape  # [2, 19, 128, 256]
        if self.align is not None:
            preds_S = self.align(preds_S)
        softmax_pred_T = F.softmax(preds_T.view(-1, W * H) / self.tau, dim=1)
        softmax_pred_S = F.softmax(preds_S.view(-1, W * H) / self.tau, dim=1)
        logsoftmax = torch.nn.LogSoftmax(dim=1)
        loss = torch.sum(-softmax_pred_T *
                         logsoftmax(preds_S.view(-1, W * H) / self.tau)) * (self.tau ** 2)
        return self.loss_weight * loss / (C * N)
```
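A quick smoke test of this loss, assuming the ChannelWiseDivergence class above can be imported (its registry decorator and relative imports come from the repo; the shapes here are illustrative):

```python
import torch

loss_fn = ChannelWiseDivergence(student_channels=19, teacher_channels=19,
                                name='loss_cwd', tau=1.0, weight=5.0)
s = torch.randn(2, 19, 128, 256)
t = torch.randn(2, 19, 128, 256)
print(loss_fn(s, t))  # scalar; identical inputs give the teacher-entropy floor, not zero
```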
The KL divergence formula above, expanded, looks like this:

$$D_{KL} = \sum p \log p - p \log q = \sum T \log T - T \log S$$

The first term depends only on the teacher's output and is constant, so the final form reduces to $\sum -T \log S$, which is exactly what the code above computes.
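A small numeric check of this equivalence (a toy sketch; the tensors are made up): the cross-entropy term differs from the true KL only by the teacher's constant entropy, so the gradients with respect to the student are identical.

```python
import torch
import torch.nn.functional as F

t_logits = torch.randn(19, 32)  # toy teacher logits, [C, W*H]
s_logits = torch.randn(19, 32)

p_T = F.softmax(t_logits, dim=1)
log_S = F.log_softmax(s_logits, dim=1)

cross_entropy = torch.sum(-p_T * log_S)        # what the code computes
kl = torch.sum(p_T * (p_T.log() - log_S))      # true KL divergence
teacher_entropy = torch.sum(-p_T * p_T.log())  # constant w.r.t. the student

print(torch.allclose(cross_entropy, kl + teacher_entropy))  # True
```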
Taking OCR as an example, the total loss is composed of: FCN loss + OCR loss + distillation loss.

1. Computing the original losses:
- OCR uses a cascade decode head, since its decode head consists of an FCN head and an OCR head.
- The FCN head's input is the backbone output: it receives a set of backbone feature maps (four maps of different sizes, 270 channels in total) and produces a [N, 19, 128, 256] map for the loss computation; this is the 'decode_0.loss_seg' term in the total loss.

So when segmentation_distiller.py computes the original loss, the call ends up in mmseg/models/segmentors/cascade_encoder_decoder.py, which computes the forward-pass losses:
```python
def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg):
    """Run forward function and calculate loss for decode head in training."""
    losses = dict()
    # first compute the loss of decode_head[0], i.e. the FCN head's loss;
    # the first decode head goes through the plain forward_train (decode_head.py)
    loss_decode = self.decode_head[0].forward_train(x, img_metas, gt_semantic_seg, self.train_cfg)
    # loss_decode: {'loss_seg': tensor(1.1506, device='cuda:0', grad_fn=<MulBackward0>),
    #               'acc_seg': tensor([1.5568], device='cuda:0')}
    losses.update(add_prefix(loss_decode, 'decode_0'))
    # losses: {'decode_0.loss_seg': tensor(1.1506, ...), 'decode_0.acc_seg': tensor([1.5568], ...)}

    for i in range(1, self.num_stages):  # config/models/ocrnet_hr18.py sets num_stages=2
        # forward test again, maybe unnecessary for most methods.
        # prev_outputs runs the backbone output through the FCN head once more,
        # i.e. the output of decode_head[0], shape [N, 19, 128, 256]
        prev_outputs = self.decode_head[i - 1].forward_test(x, img_metas, self.test_cfg)
        # the FCN output then feeds the next head's loss; the second and later decode
        # heads go through cascade_decode_head's forward_train into ocr_head.py
        # (mmseg/models/decode_heads/cascade_decode_head.py, line 18).
        # x is the backbone output (270 channels), prev_outputs is the FCN output;
        # OCRNet combines them with its attention mechanism to produce a
        # [N, 19, 128, 256] output and computes the loss against the ground truth.
        loss_decode = self.decode_head[i].forward_train(x, prev_outputs, img_metas,
                                                        gt_semantic_seg, self.train_cfg)
        losses.update(add_prefix(loss_decode, f'decode_{i}'))
    # {'decode_0.loss_seg': tensor(1.1506, ...), 'decode_0.acc_seg': tensor([1.5568], ...),
    #  'decode_1.loss_seg': tensor(2.8385, ...), 'decode_1.acc_seg': tensor([1.2970], ...)}
    return losses


# mmseg/models/decode_heads/decode_head.py, line 170
# loss computation for decode_head[0]
def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
    # inputs.shape [2, 19, 128, 256]
    seg_logits = self.forward(inputs)
    losses = self.losses(seg_logits, gt_semantic_seg)
    return losses


# mmseg/models/decode_heads/cascade_decode_head.py, line 18
# loss computation for decode_head[1] and later heads
def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg, train_cfg):
    seg_logits = self.forward(inputs, prev_output)
    losses = self.losses(seg_logits, gt_semantic_seg)
    return losses


# mmseg/models/decode_heads/decode_head.py
@force_fp32(apply_to=('seg_logit', ))
def losses(self, seg_logit, seg_label):
    """Compute segmentation loss."""
    loss = dict()
    # upsample the predicted 128x256 result to 512x1024, the ground-truth size
    seg_logit = resize(input=seg_logit, size=seg_label.shape[2:],
                       mode='bilinear', align_corners=self.align_corners)
    if self.sampler is not None:
        seg_weight = self.sampler.sample(seg_logit, seg_label)
    else:
        seg_weight = None
    seg_label = seg_label.squeeze(1)
    # enters cross_entropy_loss, mmseg/models/losses/cross_entropy_loss.py
    loss['loss_seg'] = self.loss_decode(seg_logit, seg_label,
                                        weight=seg_weight,
                                        ignore_index=self.ignore_index)
    loss['acc_seg'] = accuracy(seg_logit, seg_label)
    return loss  # yields 'loss_seg' and 'acc_seg'
```

2. Computing the distillation loss:
```python
def forward(self, preds_S, preds_T):
    """Forward function."""
    assert preds_S.shape[-2:] == preds_T.shape[-2:], \
        'the output dim of teacher and student differ'
    N, C, W, H = preds_S.shape
    if self.align is not None:
        preds_S = self.align(preds_S)
    # this normalization is the only place the "channel-wise" idea shows up:
    # every channel's elements are normalized together, and the student
    # learns the normalized per-channel feature
    softmax_pred_T = F.softmax(preds_T.view(-1, W * H) / self.tau, dim=1)  # [N*C, 32768]
    logsoftmax = torch.nn.LogSoftmax(dim=1)
    loss = torch.sum(-softmax_pred_T *
                     logsoftmax(preds_S.view(-1, W * H) / self.tau)) * (self.tau ** 2)
    return self.loss_weight * loss / (C * N)
```

The final loss:
```python
{'decode_0.loss_seg': tensor(1.1506, device='cuda:0', grad_fn=<MulBackward0>),
 'decode_0.acc_seg': tensor([1.5568], device='cuda:0'),
 'decode_1.loss_seg': tensor(2.8385, device='cuda:0', grad_fn=<MulBackward0>),
 'decode_1.acc_seg': tensor([1.2970], device='cuda:0'),
 'loss_cwd': tensor(52.1290, device='cuda:0', grad_fn=<DivBackward0>)}
```

Then mmseg/models/segmentors/base.py sums the loss terms:
```python
loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)
log_vars['loss'] = loss
# {'loss': tensor(55.8550, device='cuda:0', grad_fn=<AddBackward0>),
#  'log_vars': OrderedDict([('decode_0.loss_seg', 1.0829237699508667),
#                           ('decode_0.acc_seg', 10.901641845703125),
#                           ('decode_1.loss_seg', 2.7209525108337402),
#                           ('decode_1.acc_seg', 2.446269989013672),
#                           ('loss_cwd', 52.051116943359375),
#                           ('loss', 55.8549919128418)]),
#  'num_samples': 2}
```

A brief introduction to Register:
The mmseg framework uses registration in many places. A registration module is essentially a dictionary that maps names to class references; the most important piece is the Registry class.

First, self._module_dict = dict() serves as the storage for registered classes, so they can be looked up later.
```python
@SEGMENTORS.register_module()
class EncoderDecoder(BaseSegmentor):
    def __init__(self,
                 backbone,
                 decode_head,
                 neck=None,
                 auxiliary_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(EncoderDecoder, self).__init__()
        self.backbone = builder.build_backbone(backbone)
        if neck is not None:
            self.neck = builder.build_neck(neck)
        self._init_decode_head(decode_head)
        self._init_auxiliary_head(auxiliary_head)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.init_weights(pretrained=pretrained)
        assert self.with_decode_head
```

The register.py file looks like this:
```python
import inspect

import six


def is_str(x):
    """Whether the input is a string instance."""
    return isinstance(x, six.string_types)


class Registry(object):

    def __init__(self, name):
        self._name = name           # e.g. 'detector': the registry's own name
        self._module_dict = dict()  # maps class names to class objects

    @property
    def name(self):  # expose name as a read-only property
        return self._name

    @property
    def module_dict(self):
        return self._module_dict

    def get(self, key):
        return self._module_dict.get(key, None)

    def _register_module(self, module_class):
        """The key method: register a module.

        Classes under the model folder carry @DETECTORS.register_module,
        which passes the class itself into register_module().

        Args:
            module (:obj:`nn.Module`): Module to be registered.
        """
        # if not inspect.isclass(module_class):  # could verify it is a class
        #     raise TypeError('module must be a class, but got {}'.format(type(module_class)))
        module_name = module_class.__name__     # the class name
        if module_name in self._module_dict:    # refuse duplicate registration
            raise KeyError('{} is already registered in {}'.format(module_name, self.name))
        self._module_dict[module_name] = module_class  # key: class name, value: class object

    def register_module(self, cls):  # wraps the method above and returns the class itself
        self._register_module(cls)
        return cls


def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        obj: The constructed object.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None
    args = cfg.copy()
    obj_type = args.pop('type')
    if is_str(obj_type):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError('{} is not in the {} registry'.format(obj_type, registry.name))
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError('type must be a str or valid type, but got {}'.format(type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_cls(**args)
```
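A toy usage sketch of this registry pattern (the registry and class names here are made up), using the Registry and build_from_cfg defined above:

```python
MODELS = Registry('models')

@MODELS.register_module  # this version takes the class directly, so no parentheses
class TinyNet(object):
    def __init__(self, depth=18):
        self.depth = depth

cfg = dict(type='TinyNet', depth=50)
model = build_from_cfg(cfg, MODELS)  # looks up 'TinyNet' and instantiates it
print(type(model).__name__, model.depth)  # TinyNet 50
```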