Grad-CAM在语义分割中的pytorch实现
生活随笔
收集整理的這篇文章主要介紹了
Grad-CAM在语义分割中的pytorch实现
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
利用pytorch,實現在語義分割任務中得到某個類別的Grad-CAM
主要參考
github:pytorch-grad-cam
首先將pytorch-grad-cam中提供的 pytorch-grad-cam/ 路徑下所有的代碼文件下載至本地,將在腳本中作為庫調用。
pytorch-grad-cam本身提供了在語義分割任務中得到Grad-CAM的腳本和教程:Tutorial: Class Activation Maps for Semantic Segmentation。建議閱讀此腳本和教程,筆者也是閱讀之后參考這個腳本進行修改得到的自己的腳本。
python腳本
import warnings from torchvision.models.segmentation import deeplabv3_resnet50 import torch import torch.functional as F import numpy as np import requests import torchvision from PIL import Image from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image from pytorch_grad_cam import GradCAM warnings.filterwarnings('ignore') warnings.simplefilter('ignore') import logging from utils.my_data_parallel import MyDataParallel import network from optimizer import restore_snapshot import argparse from datasets import dataset_XXX from config import assert_and_infer_cfg from torch.backends import cudnnparser = argparse.ArgumentParser(description='evaluation') parser.add_argument('--snapshot', required=True, type=str, default='') parser.add_argument('--arch', type=str, default='', required=True) parser.add_argument('--dataset_cls', type=str, default='cityscapes', help='cityscapes') args = parser.parse_args()# 此處讀取數據集的描述文件,如沒有也可以自己寫一個,添加本代碼中需要的參數就可以了 args.dataset_cls = dataset_XXX assert_and_infer_cfg(args, train_mode=False) args.apex = False # No support for apex eval cudnn.benchmark = False # 此處添加需要得到CAM的文件名 img_name_list = ["XXX", "XXX", "XXX", "XXX"]def get_net():"""Get Network for evaluation"""logging.info('Load model file: %s', args.snapshot)net = network.get_net(args, criterion=None)net = torch.nn.DataParallel(net).cuda()net, _ = restore_snapshot(net, optimizer=None,snapshot=args.snapshot, restore_optimizer_bool=False)net.eval()return netfor img_name in img_name_list:# 因為筆者的網絡是RGB+MASK四通道的,所以需要分別讀取圖片和掩膜并進行合并img_path = "./tmp/grad_cam/pick/raw/" + img_name + ".png"mask_path = "./tmp/grad_cam/pick/mask/" + img_name + ".png"image = np.array(Image.open(img_path))mask = np.array(Image.open(mask_path))rgb_img = np.float32(image) / 255rgb_mask = np.float32(mask) / 255# 此處添加數據集歸一化的均值和方差tensor_img = preprocess_image(rgb_img,mean=[0.000, 0.000, 0.000],std=[0.000, 0.000, 0.000])tensor_mask = preprocess_image(rgb_mask,mean=[0.000],std=[0.000])input_tensor = torch.cat((tensor_img, tensor_mask), dim=0).unsqueeze(0)# Taken from the torchvision tutorial# https://pytorch.org/vision/stable/auto_examples/plot_visualization_utils.htmlmodel = get_net()if torch.cuda.is_available():model = model.cuda()input_tensor = input_tensor.cuda()class SegmentationModelOutputWrapper(torch.nn.Module):def __init__(self, model):super(SegmentationModelOutputWrapper, self).__init__()self.model = modeldef forward(self, x):return self.model(x)model = SegmentationModelOutputWrapper(model)output = model(input_tensor)normalized_masks = torch.nn.functional.softmax(output, dim=1).cpu()# 此處添加類名sem_classes = ['background', 'XXX', 'YYY']sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)}# 將需要進行CAM的類名寫至此處plaque_category = sem_class_to_idx["YYY"]plaque_mask = normalized_masks[0, :, :, :].argmax(axis=0).detach().cpu().numpy()plaque_mask_uint8 = 255 * np.uint8(plaque_mask == plaque_category)plaque_mask_float = np.float32(plaque_mask == plaque_category)both_images = np.hstack((image, np.repeat(plaque_mask_uint8[:, :, None], 3, axis=-1)))Image.fromarray(both_images)class SemanticSegmentationTarget:def __init__(self, category, mask):self.category = categoryself.mask = torch.from_numpy(mask)if torch.cuda.is_available():self.mask = self.mask.cuda()def __call__(self, model_output):return (model_output[self.category, :, :] * self.mask).sum()# 此處修改希望得到特征圖所在的網絡層target_layers = [model.model.backbone.layer4]targets = [SemanticSegmentationTarget(plaque_category, plaque_mask_float)]with GradCAM(model=model,target_layers=target_layers,use_cuda=torch.cuda.is_available()) as cam:grayscale_cam = cam(input_tensor=input_tensor,targets=targets)[0, :]cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)img = Image.fromarray(cam_image)# 保存位置img.save("./tmp/grad_cam/final/" + img_name + ".png")bash腳本
運行此腳本需要添加網絡名稱和模型名稱的參數,分別是–arch和–snapshot,于是撰寫bash腳本添加參數:
#!/usr/bin/env bashpython3 grad_cam.py \--arch network.deepv3.DeepR50V3PlusD_m1_deeply \--snapshot last_epoch_XXX_mean-iu_0.XXXXX.pth運行腳本就可以得到語義分割的Grad-CAM圖啦!
運行效果
總結
以上是生活随笔為你收集整理的Grad-CAM在语义分割中的pytorch实现的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: android ip冲突检测工具,and
- 下一篇: 软件工程专业英语