Python CNN风格迁移
生活随笔收集整理的這篇文章主要介紹了 Python CNN风格迁移，小編覺得挺不錯的，現在分享給大家，幫大家做個參考。
| 本文僅供學習交流使用，如侵立刪！demo 下載見文末 |
環境：
python: 3.6.7
tensorflow==2.2.0
torch==1.5.1+cpu
torchvision==0.6.0+cpu
Pillow==8.1.2
效果：
輸入原圖和風格圖，生成風格遷移後的圖片。
# -*- coding: utf-8 -*-
import cv2
import time


def style_transfer(pathIn='', pathOut='', model='', width=None, jpg_quality=80):
    """Stylize one image with a pre-trained network through OpenCV's DNN module.

    pathIn: path of the original image.
    pathOut: path where the stylized image is saved.
    model: path of the pre-trained network, loaded with
        ``cv2.dnn.readNetFromTorch`` (Torch7 format).
    width: width of the stylized image in pixels; the height is scaled to
        keep the aspect ratio. ``None`` (default) keeps the original size.
    jpg_quality: 0-100, JPEG quality of the output (default 80; higher is
        better quality).

    Returns the stylized image as an 8-bit array of shape (H, W, 3) in the
    same BGR channel order ``cv2.imread`` produced.

    Raises:
        FileNotFoundError: if ``pathIn`` cannot be read as an image.
        ValueError: if ``width`` is given but is not positive.
        OSError: if the result cannot be written to ``pathOut``.
    """
    # Read the original image, resize it if requested, then take its size.
    img = cv2.imread(pathIn)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f'cannot read image: {pathIn!r}')
    h, w = img.shape[:2]
    if width is not None:
        if width <= 0:
            raise ValueError(f'width must be positive, got {width!r}')
        new_height = max(1, round(width * h / w))
        img = cv2.resize(img, (width, new_height), interpolation=cv2.INTER_CUBIC)
        h, w = img.shape[:2]

    # Load the pre-trained model from disk.
    print('加載預訓練模型......')
    net = cv2.dnn.readNetFromTorch(model)

    # Per-channel means to subtract (e.g. ImageNet training-set statistics),
    # in B, G, R order: the image is BGR and swapRB=False keeps it that way.
    mean = (103.939, 116.779, 123.680)
    blob = cv2.dnn.blobFromImage(img, 1.0, (w, h), mean, swapRB=False, crop=False)
    net.setInput(blob)
    start = time.time()
    output = net.forward()
    end = time.time()
    print("風格遷移花費:{:.2f}秒".format(end - start))

    # Drop the batch axis, add the subtracted means back, then reorder the
    # axes from (C, H, W) to (H, W, C). Channels are NOT swapped: the result
    # stays BGR, which is what cv2.imwrite expects.
    output = output.reshape((3, output.shape[2], output.shape[3]))
    for channel, channel_mean in enumerate(mean):
        output[channel] += channel_mean
    output = output.transpose(1, 2, 0)
    # The network emits floats that can leave [0, 255]; clip and convert
    # explicitly rather than relying on imwrite's implicit conversion, which
    # depends on the output format.
    output = output.clip(0, 255).round().astype('uint8')

    # Save the stylized image; imwrite reports failure via its return value.
    if not cv2.imwrite(pathOut, output, [int(cv2.IMWRITE_JPEG_QUALITY), jpg_quality]):
        raise OSError(f'cannot write image: {pathOut!r}')
    return output


# ---- CNN model training ----
import __future__
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import time
import os
from PIL import Image

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Environment this script was written against:
#   python 3.6.7, tensorflow==2.2.0, torch==1.5.1+cpu,
#   torchvision==0.6.0+cpu, Pillow==8.1.2


class GramMatrix(torch.nn.Module):
    """Compute the normalized Gram matrix of a rank-4 feature map.

    The ``(b, n, h, w)`` input is viewed as a ``(b * n, h * w)`` matrix,
    multiplied by its own transpose and divided by ``b * n * h * w``.
    """

    def forward(self, input):
        b, n, h, w = input.size()
        features = input.view(b * n, h * w)
        gram = torch.mm(features, features.t())
        return gram.div(b * n * h * w)


class ContentLoss(torch.nn.Module):
    """Pass-through layer that records the content loss of its input.

    On every forward pass ``self.loss`` is set to the MSE between the
    weighted input and the weighted (detached) target feature map. The input
    is returned unchanged so the layer can sit inside an ``nn.Sequential``.
    """

    def __init__(self, content_feature, weight):
        super(ContentLoss, self).__init__()
        self.content_feature = content_feature.detach()
        self.weight = weight
        self.criterion = torch.nn.MSELoss()
        # The target never changes, so scale it once here.
        self.target = self.content_feature * weight

    def forward(self, combination):
        # clone() is defensive: it keeps what autograd saves separate from
        # the tensor a following (possibly in-place) ReLU will overwrite.
        self.loss = self.criterion(combination.clone() * self.weight, self.target)
        return combination


class StyleLoss(torch.nn.Module):
    """Pass-through layer that records the style loss of its input.

    ``self.loss`` is the MSE between the Gram matrix of the weighted input
    and the Gram matrix of the weighted (detached) style feature map.
    """

    def __init__(self, style_feature, weight):
        super(StyleLoss, self).__init__()
        self.style_feature = style_feature.detach()
        self.weight = weight
        self.gram = GramMatrix()
        self.criterion = torch.nn.MSELoss()
        # The target Gram matrix is constant; compute it once instead of on
        # every forward pass.
        self.target_gram = self.gram(self.style_feature * weight).detach()

    def forward(self, combination):
        combination_gram = self.gram(combination.clone() * self.weight)
        self.loss = self.criterion(combination_gram, self.target_gram)
        return combination


class StyleTransfer:
    """Optimize a copy of the content image toward the style image's style.

    Both images are resized to ``(img_nrows, img_ncols)``. Losses are taken
    at fixed VGG19 ``features`` layers, and the pixels of a copy of the
    content image are optimized with L-BFGS.

    NOTE(review): images are fed to VGG as [0, 1] tensors without ImageNet
    mean/std normalization. Confirm this is intended for these weights.
    """

    def __init__(self, content_image, style_image, style_weight=5, content_weight=0.025,
                 img_nrows=720, img_ncols=1280, vgg_weights='vgg19-dcbb9e9d.pth'):
        """
        content_image / style_image: paths of the content and style images.
        style_weight / content_weight: weights of the two loss components.
        img_nrows / img_ncols: height / width both images are resized to.
        vgg_weights: path of the VGG19 state dict to load.
        """
        self.vgg19 = models.vgg19()
        # map_location lets weights saved from a GPU load on a CPU-only host.
        self.vgg19.load_state_dict(torch.load(vgg_weights, map_location=device))
        self.img_ncols = img_ncols
        self.img_nrows = img_nrows
        self.style_weight = style_weight
        self.content_weight = content_weight
        # Load the content and style images.
        self.content_tensor, self.content_name = self.process_img(content_image)
        self.style_tensor, self.style_name = self.process_img(style_image)
        # The optimization starts from a copy of the content image.
        # (The attribute name's spelling is kept for compatibility.)
        self.conbination_tensor = self.content_tensor.clone()

    def process_img(self, img_path):
        """Load an image as a ``(1, 3, img_nrows, img_ncols)`` float tensor.

        Returns ``(tensor, name)``. The tensor holds values in [0, 1] and is
        on ``device``. ``name`` is the file name without directory or
        extension.
        """
        with Image.open(img_path) as raw:
            # Drop any alpha channel / expand grayscale so VGG gets 3 channels.
            img = raw.convert('RGB')
        loader = transforms.Compose([
            transforms.Resize((self.img_nrows, self.img_ncols)),
            transforms.ToTensor(),
        ])
        img_tensor = loader(img).unsqueeze(0)
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        return img_tensor.to(device, torch.float), img_name

    def deprocess_img(self, x, index):
        """Save tensor ``x`` as ``<content>_and_<style>/result_<index>.png``.

        Returns the path of the written file.
        """
        unloader = transforms.ToPILImage()
        img_tensor = x.detach().cpu().clone().squeeze(0)
        img = unloader(img_tensor)
        result_folder = f'{self.content_name}_and_{self.style_name}'
        os.makedirs(result_folder, exist_ok=True)
        filename = f'{result_folder}/result_{index}.png'
        img.save(filename)
        print(f'save {filename} successfully!')
        print()
        return filename

    def get_loss_and_model(self, vgg_model, content_image, style_image):
        """Build a truncated VGG ``features`` stack with loss layers inserted.

        Returns ``(content_losses, style_losses, model)``. Each loss module
        stores its latest value in ``.loss`` after ``model`` is run.
        """
        vgg_layers = vgg_model.features.to(device).eval()
        # Only the input image is optimized. Freezing VGG stops backward()
        # from computing and accumulating parameter gradients on every step.
        for param in vgg_layers.parameters():
            param.requires_grad_(False)

        style_layers = {
            '0': "style_loss_1",
            '5': "style_loss_2",
            '10': "style_loss_3",
            '19': "style_loss_4",
            '28': "style_loss_5",
        }
        content_layers = {'30': "content_loss"}
        pending = set(style_layers) | set(content_layers)

        style_losses = []
        content_losses = []
        model = torch.nn.Sequential()
        for name, module in vgg_layers.named_children():
            model.add_module(name, module)
            if name in content_layers:
                with torch.no_grad():
                    content_feature = model(content_image)
                content_loss = ContentLoss(content_feature, self.content_weight)
                model.add_module(content_layers[name], content_loss)
                content_losses.append(content_loss)
            if name in style_layers:
                with torch.no_grad():
                    style_feature = model(style_image)
                style_loss = StyleLoss(style_feature, self.style_weight)
                model.add_module(style_layers[name], style_loss)
                style_losses.append(style_loss)
            pending.discard(name)
            if not pending:
                # Layers after the last loss layer never affect any loss.
                break
        return content_losses, style_losses, model

    def get_input_param_optimizer(self, input_img):
        """Wrap ``input_img`` (sharing its storage) as the L-BFGS parameter."""
        input_param = torch.nn.Parameter(input_img.data)
        optimizer = torch.optim.LBFGS([input_param])
        return input_param, optimizer

    def main_train(self, epoch=10):
        """Run up to ``epoch`` L-BFGS steps and save the image after each one.

        Stops early when an epoch's loss is not lower than the previous
        epoch's. Returns the optimized image parameter, clamped to [0, 1].
        """
        print('Load model preprocessing')
        combination_param, optimizer = self.get_input_param_optimizer(self.conbination_tensor)
        content_losses, style_losses, model = self.get_loss_and_model(
            self.vgg19, self.content_tensor, self.style_tensor)

        def closure():
            combination_param.data.clamp_(0, 1)
            optimizer.zero_grad()
            model(combination_param)  # populates .loss on every loss module
            content_score = sum(cl.loss for cl in content_losses)
            style_score = sum(sl.loss for sl in style_losses)
            loss = content_score + style_score
            loss.backward()
            return loss

        # inf (not an arbitrary constant) so the first epoch can never trigger
        # early stopping.
        previous_loss = float('inf')
        for i in range(1, epoch + 1):
            start = time.time()
            loss = optimizer.step(closure).item()
            end = time.time()
            combination_param.data.clamp_(0, 1)
            self.deprocess_img(combination_param, i)
            print(f'|using:{int(end - start):2d}s |epoch:{i:2d} |loss:{loss}')
            if previous_loss <= loss:
                print('Early stopping!')
                break
            previous_loss = loss
        combination_param.data.clamp_(0, 1)
        return combination_param


# This article is for learning and exchange only; it will be removed on
# request in case of infringement.
總結
以上是生活随笔為你收集整理的 Python CNN风格迁移 的全部內容，希望文章能夠幫你解決所遇到的問題。
- 上一篇: torch torchvision 下载
- 下一篇: Python 搜狗站长平台批量添加域名+