Reproducing RRU-Net with PyTorch
Paper:
https://openaccess.thecvf.com/content_CVPRW_2019/html/CV-COPS/Bi_RRU-Net_The_Ringed_Residual_U-Net_for_Image_Splicing_Forgery_Detection_CVPRW_2019_paper.html
Official code:
https://github.com/yelusaleng/RRU-Net
The reproduction code is given below. The dataset used is NIST.
Project structure
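The layout below is reconstructed from the file paths used in the scripts that follow; the tmp/, pic/ and model/ folders come from the save and load paths in the training and evaluation code, and the root directory name is arbitrary:

project_root/
├── boot.py                      # entry point that launches training
├── dataload/
│   └── NIST_data.py             # NIST dataset wrapper
├── network/
│   └── RRUNet.py                # RRU-Net model definition
├── train/
│   └── RRUNet_NIST_train.py     # training loop
├── eval/
│   └── RRUNet_NIST_eval.py      # evaluation and visualization
├── tmp/                         # intermediate checkpoints written during training
├── pic/                         # prediction visualizations written by the evaluation script
└── model/                       # final weights loaded by the evaluation script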
Detailed code
./dataload/NIST_data.py
import torch
import torchvision.transforms as tfs
import os
from PIL import Image
import time

LINUX = 0
if(LINUX==0):
    prefix = "G:/data/"
else:
    prefix = "/home/yyl/data/"

HEIGHT = 256
WIDTH = 256

class NIST_DATA(object):
    def __init__(self, mode="train"):
        super(NIST_DATA, self).__init__()
        self.nist_data_root = prefix + "NIST/"
        self.image_name = []
        self.gt_name = []
        self.readImage()
        # image normalization parameters
        self.mean = [0.40789654, 0.44719302, 0.47026115]
        self.std = [0.28863828, 0.27408164, 0.27809835]
        self.im_tfs = tfs.Compose([
            tfs.ToTensor(),
            tfs.Normalize(self.mean, self.std)
        ])
        self.im_tfs_gt = tfs.Compose([
            tfs.ToTensor()
        ])
        #self.filter()
        # the training set takes the first 80% of the images, the test set the last 20%
        if(mode == "train"):
            self.image_name = self.image_name[:round(len(self.image_name)*0.8)]
            self.gt_name = self.gt_name[:round(len(self.gt_name)*0.8)]
        else:
            self.image_name = self.image_name[round(len(self.image_name) * 0.8):]
            self.gt_name = self.gt_name[round(len(self.gt_name) * 0.8):]
        print("%s -> loaded %d images in total" % (mode, len(self.image_name)))

    # collect the image and ground-truth file lists
    def readImage(self):
        image_root = self.nist_data_root + "1/"
        gt_root = self.nist_data_root + "1_mask/"
        filename = os.listdir(image_root)
        for i in range(len(filename)):
            self.image_name.append(image_root + filename[i])
            self.gt_name.append(gt_root + filename[i].split(".")[0]+".png")
        image_root = self.nist_data_root + "2/"
        gt_root = self.nist_data_root + "2_mask/"
        filename = os.listdir(image_root)
        for i in range(len(filename)):
            self.image_name.append(image_root + filename[i])
            self.gt_name.append(gt_root + filename[i].split(".")[0] + ".png")

    # drop samples whose ground-truth mask is (almost) empty
    def filter(self, thresold=10):
        tmp_image_name = self.image_name.copy()
        tmp_gt_name = self.gt_name.copy()
        for idx in range(len(self.image_name)):
            img = Image.open(self.image_name[idx])
            img_gt = Image.open(self.gt_name[idx])
            img, img_gt = self.image_transform(img, img_gt)
            if(img_gt.sum() < thresold):
                #print("removing %d" % idx)
                tmp_image_name.remove(self.image_name[idx])
                tmp_gt_name.remove(self.gt_name[idx])
        self.image_name = tmp_image_name
        self.gt_name = tmp_gt_name

    # check that the ground truths are paired with the right images
    def check(self):
        for i in range(len(self.image_name)):
            print(self.image_name[i] + " " + self.gt_name[i])

    # crop the images, 256x256 by default
    def crop(self, img, img_gt, height=HEIGHT, width=WIDTH, offset=0):
        box = (offset, offset, width, height)
        img = img.crop(box)
        img_gt = img_gt.crop(box)
        return img, img_gt

    def image_transform(self, img, img_gt):
        #img, img_gt = self.crop(img, img_gt)
        img = self.im_tfs(img)
        img_gt = self.im_tfs_gt(img_gt)
        return img, img_gt

    def __getitem__(self, idx):
        img = Image.open(self.image_name[idx])
        img_gt = Image.open(self.gt_name[idx])
        img = img.resize((HEIGHT, WIDTH))
        img_gt = img_gt.resize((HEIGHT, WIDTH))
        img, img_gt = self.image_transform(img, img_gt)
        return img, img_gt

    def __len__(self):
        return len(self.image_name)


if __name__ == "__main__":
    st = time.time()
    nist_train = NIST_DATA("train")
    nist_val = NIST_DATA("val")
    nist_train.check()
    train_loader = torch.utils.data.DataLoader(nist_train, batch_size=8, shuffle=True)
    val_loader = torch.utils.data.DataLoader(nist_val, batch_size=8, shuffle=True)
    print("cost time: %ds" % (time.time() - st))
    for data in val_loader:
        img, img_gt = data
        print(img.shape)
        print(img_gt.shape)
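For reference, readImage() above simply pairs every file in an image folder with a mask of the same basename, so the loader assumes a directory layout along these lines. This is a sketch based only on the paths in the code: the 1/, 1_mask/, 2/, 2_mask/ folder names are whatever was prepared locally from the NIST data rather than an official NIST structure, the example file names are hypothetical, and the prefix is selected by the LINUX flag at the top of the file:

G:/data/NIST/        (or /home/yyl/data/NIST/ on Linux)
├── 1/               # tampered images, e.g. example_0001.jpg
├── 1_mask/          # ground-truth masks with the same basename, e.g. example_0001.png
├── 2/               # a second folder of tampered images
└── 2_mask/          # masks for the second folder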
./network/RRUNet.py
import torch
import torch.nn as nn
import torch.nn.functional as F

# ~~~~~~~~~~ RRU-Net basic building blocks ~~~~~~~~~~

class RRU_double_conv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(RRU_double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=2, dilation=2),
            nn.GroupNorm(32, out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=2, dilation=2),
            nn.GroupNorm(32, out_ch))

    def forward(self, x):
        x = self.conv(x)
        return x


class RRU_first_down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(RRU_first_down, self).__init__()
        self.conv = RRU_double_conv(in_ch, out_ch)
        self.relu = nn.ReLU(inplace=True)
        self.res_conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
            nn.GroupNorm(32, out_ch))
        self.res_conv_back = nn.Sequential(
            nn.Conv2d(out_ch, in_ch, kernel_size=1, bias=False))

    def forward(self, x):
        # the first ring conv
        ft1 = self.conv(x)
        r1 = self.relu(ft1 + self.res_conv(x))
        # the second ring conv
        ft2 = self.res_conv_back(r1)
        x = torch.mul(1 + torch.sigmoid(ft2), x)
        # the third ring conv
        ft3 = self.conv(x)
        r3 = self.relu(ft3 + self.res_conv(x))
        return r3


class RRU_down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(RRU_down, self).__init__()
        self.conv = RRU_double_conv(in_ch, out_ch)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.res_conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 1, bias=False),
            nn.GroupNorm(32, out_ch))
        self.res_conv_back = nn.Sequential(
            nn.Conv2d(out_ch, in_ch, kernel_size=1, bias=False))

    def forward(self, x):
        x = self.pool(x)
        # the first ring conv
        ft1 = self.conv(x)
        r1 = self.relu(ft1 + self.res_conv(x))
        # the second ring conv
        ft2 = self.res_conv_back(r1)
        x = torch.mul(1 + torch.sigmoid(ft2), x)
        # the third ring conv
        ft3 = self.conv(x)
        r3 = self.relu(ft3 + self.res_conv(x))
        return r3


class RRU_up(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=False):
        super(RRU_up, self).__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.Sequential(
                nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2),
                nn.GroupNorm(32, in_ch // 2))
        self.conv = RRU_double_conv(in_ch, out_ch)
        self.relu = nn.ReLU(inplace=True)
        self.res_conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
            nn.GroupNorm(32, out_ch))
        self.res_conv_back = nn.Sequential(
            nn.Conv2d(out_ch, in_ch, kernel_size=1, bias=False))

    def forward(self, x1, x2):
        x1 = self.up(x1)
        diffX = x2.size()[2] - x1.size()[2]
        diffY = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, (diffY, 0, diffX, 0))
        x = self.relu(torch.cat([x2, x1], dim=1))
        # the first ring conv
        ft1 = self.conv(x)
        r1 = self.relu(self.res_conv(x) + ft1)
        # the second ring conv
        ft2 = self.res_conv_back(r1)
        x = torch.mul(1 + torch.sigmoid(ft2), x)
        # the third ring conv
        ft3 = self.conv(x)
        r3 = self.relu(ft3 + self.res_conv(x))
        return r3


# !!!!!!!!!!!! Universal functions !!!!!!!!!!!!

class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)
        return x


############# RRU-Net backbone #############
class Ringed_Res_Unet(nn.Module):
    def __init__(self, n_channels=3, n_classes=1):
        super(Ringed_Res_Unet, self).__init__()
        self.down = RRU_first_down(n_channels, 32)
        self.down1 = RRU_down(32, 64)
        self.down2 = RRU_down(64, 128)
        self.down3 = RRU_down(128, 256)
        self.down4 = RRU_down(256, 256)
        self.up1 = RRU_up(512, 128)
        self.up2 = RRU_up(256, 64)
        self.up3 = RRU_up(128, 32)
        self.up4 = RRU_up(64, 32)
        self.out = outconv(32, n_classes)

    def forward(self, x):
        x1 = self.down(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.out(x)
        return x


if __name__ == "__main__":
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    img = torch.rand(1, 3, 256, 256).to(device)
    RRUNet = Ringed_Res_Unet().to(device)
    output = RRUNet(img)
    print(output.shape)
./train/RRUNet_NIST_train.py
import torch
import torch.nn as nn
import matplotlib as mat
mat.use('Agg')
import matplotlib.pyplot as plt
import time
import numpy as np
import dataload.NIST_data as Dataload
import network.RRUNet as Network

#########################
# Train RRU-Net on the NIST dataset
#########################

# compute precision, recall and F1 score
# output->[batch, 1, 256, 256]
# img_gt->[batch, 1, 256, 256]
def calprecise(output, img_gt):
    output = torch.sigmoid(output)
    mask = output > 0.5
    acc_mask = torch.mul(mask.float(), img_gt)
    acc_mask = acc_mask.sum()
    acc_fenmu = mask.sum()
    recall_fenmu = img_gt.sum()
    acc = acc_mask / (acc_fenmu + 0.0001)
    recall = acc_mask / (recall_fenmu + 0.0001)
    f1 = 2*acc*recall / (acc + recall + 0.0001)
    return acc, recall, f1

BATCH = 8
LR = 0.1
EPOCH = 50

# print_step=True prints the metrics of every step
# save_epoch: save a checkpoint every save_epoch epochs
def train(print_step=True, save_epoch=20):
    nist_train = Dataload.NIST_DATA("train")
    nist_val = Dataload.NIST_DATA("val")
    train_load = torch.utils.data.DataLoader(nist_train, batch_size=BATCH, shuffle=True)
    val_load = torch.utils.data.DataLoader(nist_val, batch_size=BATCH, shuffle=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = Network.Ringed_Res_Unet()
    net = net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=LR, momentum=0.9, weight_decay=0.0005)
    lossFunction = nn.BCELoss()
    print("Start training ヽ(°▽°)ノ")
    losses = []
    acces = []
    recalles = []
    f1es = []
    for epoch in range(EPOCH):
        total_loss = 0
        acc_ = 0
        recall_ = 0
        f1es_ = 0
        st = time.time()
        for step, data in enumerate(train_load):
            img, img_gt = data
            img = img.to(device)
            img_gt = 1 - img_gt  # invert the ground-truth mask values
            img_gt = img_gt.to(device)
            pred_mask = net(img)
            pred_mask_sigmoid = torch.sigmoid(pred_mask)
            pred_mask_flat = pred_mask_sigmoid.view(-1)
            true_masks_flat = img_gt.view(-1)
            loss = lossFunction(pred_mask_flat, true_masks_flat)
            # update the network
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            acc, recall, f1 = calprecise(pred_mask, img_gt)
            if(print_step == True):
                print("step:%d->loss:%.4f acc:%.4f recall:%.4f f1:%.4f cost time:%ds"
                      % (step, loss, acc, recall, f1, time.time() - st))
            # accumulate as Python floats so the curves can be plotted later
            total_loss = total_loss + loss.item()
            acc_ = acc_ + acc.item()
            recall_ = recall_ + recall.item()
            f1es_ = f1es_ + f1.item()
        # average metrics over this epoch
        loss_mean = total_loss / len(train_load)
        acc_mean = acc_ / len(train_load)
        recall_mean = recall_ / len(train_load)
        f1_mean = f1es_ / len(train_load)
        losses.append(loss_mean)
        acces.append(acc_mean)
        recalles.append(recall_mean)
        f1es.append(f1_mean)
        # print the summary of this epoch
        print("epoch:%d->loss:%.4f acc:%.4f recall:%.4f f1:%.4f cost time:%ds"
              % (epoch, loss_mean, acc_mean, recall_mean, f1_mean, time.time() - st))
        if(epoch != 0 and epoch % save_epoch == 0):
            torch.save(net.state_dict(), "./tmp/RRU_NIST_epoch_%d.pth" % (epoch))
    # training finished, save the final weights
    torch.save(net.state_dict(), "./RRUNet_NIST.pth")
    # plot the training curves
    x = np.arange(len(losses))
    plt.plot(x, losses, label="train")
    plt.title("train loss")
    plt.grid()
    plt.legend()
    plt.savefig("loss.jpg")
    plt.clf()
    plt.plot(x, acces, label="acc")
    plt.plot(x, recalles, label="recall")
    plt.plot(x, f1es, label="f1")
    plt.title("performance")
    plt.grid()
    plt.legend()
    plt.savefig("performance.jpg")
    plt.clf()
    print("Training finished ヽ(°▽°)ノ")


if __name__ == "__main__":
    train()
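A side note on the loss above: the script applies torch.sigmoid to the logits and then uses nn.BCELoss. This is not part of the original code, but a common, more numerically stable equivalent is to pass the raw logits straight to nn.BCEWithLogitsLoss, which fuses the sigmoid into the loss. A minimal sketch of the replacement inside train():

# Hypothetical drop-in replacement for the sigmoid + BCELoss pair above;
# BCEWithLogitsLoss applies the sigmoid internally in a numerically stable way.
lossFunction = nn.BCEWithLogitsLoss()

pred_mask = net(img)                     # raw logits, shape [batch, 1, 256, 256]
loss = lossFunction(pred_mask.view(-1),  # note: no explicit torch.sigmoid here
                    img_gt.view(-1))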
./boot.py
from train.RRUNet_NIST_train import *

if __name__ == "__main__":
    train()
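With this entry script at the project root, training can be launched directly from the root directory (this assumes the dataload/, network/ and train/ packages are importable from there, which is the case when running from the root):

python boot.py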
Model evaluation code
./eval/RRUNet_NIST_eval.py
import dataload.NIST_data as Dataload
import network.RRUNet as Network
import matplotlib.pyplot as plt
import torch

# compute precision, recall and F1 score
# output->[batch, 1, 256, 256]
# img_gt->[batch, 1, 256, 256]
def calprecise(output, img_gt):
    output = torch.sigmoid(output)
    mask = output > 0.5
    acc_mask = torch.mul(mask.float(), img_gt)
    acc_mask = acc_mask.sum()
    acc_fenmu = mask.sum()
    recall_fenmu = img_gt.sum()
    acc = acc_mask / (acc_fenmu + 0.0001)
    recall = acc_mask / (recall_fenmu + 0.0001)
    f1 = 2*acc*recall / (acc + recall + 0.0001)
    return acc, recall, f1


def eval(model):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = Network.Ringed_Res_Unet().to(device)
    net.load_state_dict(torch.load(model))
    casia_val = Dataload.NIST_DATA("val")
    val_loader = torch.utils.data.DataLoader(casia_val, batch_size=1, shuffle=True)
    acces = 0
    recalles = 0
    f1es = 0
    for step, data in enumerate(val_loader):
        img, img_gt = data
        img_gt = 1 - img_gt
        img = img.to(device)
        img_gt = img_gt.to(device)
        output = net(img)
        acc, recall, f1 = calprecise(output, img_gt)
        print("step%d-> acc:%.4f recall:%.4f f1:%.4f" % (step, acc, recall, f1))
        mask = torch.sigmoid(output) > 0.5
        acces = acces + acc
        recalles = recalles + recall
        f1es = f1es + f1
        # save input / ground truth / prediction side by side
        plt.subplot(1, 3, 1), plt.imshow(img[0].permute(1, 2, 0).cpu().numpy())
        plt.subplot(1, 3, 2), plt.imshow(img_gt[0][0].cpu().long().numpy())
        plt.subplot(1, 3, 3), plt.imshow(mask[0][0].cpu().long().numpy())
        plt.savefig("./pic/step_%d" % step)
        plt.clf()
    print("total outcome-> acc:%.4f recall:%.4f f1:%.4f"
          % (acces/len(val_loader), recalles/len(val_loader), f1es/len(val_loader)))


if __name__ == "__main__":
    model = "./model/RRUNet_NIST.pth"
    eval(model)
Evaluation results