8--风格迁移
????????使用卷積神經(jīng)網(wǎng)絡,自動將一個圖像中的風格應用在另一圖像之上,即風格遷移(style transfer)。需要輸入兩張圖片,一張是原圖,另一張是想要應用的風格圖像,如下圖所示,最后輸出風格遷移后的合成圖像。
8.1 方法?
????????首先,初始化最后的合成圖像,例如將其初始化為內(nèi)容圖像,該合成圖像是風格遷移過程中唯一需要更新的變量。然后,選擇一個預訓練的卷積神經(jīng)網(wǎng)絡來抽取圖像的特征,其中的模型參數(shù)在訓練中無須更新。 這個深度卷積神經(jīng)網(wǎng)絡憑借多個層逐級抽取圖像的特征,我們可以選擇其中某些層的輸出作為內(nèi)容特征或風格特征。
? ? ? ? 這里的三個卷積層都是同一個卷積網(wǎng)絡中的,如圖所示,將第二層的輸出作為內(nèi)容特征(越靠近底層的風格特征越接近原圖),第一層和第三層的輸出作為風格特征。
? ? ? ? 這里通過控制合成圖像與內(nèi)容圖像和樣式圖像之間的損失,來得到最終的合成圖像。減少合成圖像中的噪點還引入了總變差損失。
8.2 預處理和后處理函數(shù)
? ? ? ? 預處理函數(shù)用來對輸入圖像在RGB三個通道分別做標準化,并將結(jié)果變換成卷積神經(jīng)網(wǎng)絡接受的輸入格式。后處理函數(shù)用來對輸出圖像中的像素值還原回標準化之前的值。
#均值和方差是先驗經(jīng)驗 rgb_mean = torch.tensor([0.485, 0.456, 0.406]) rgb_std = torch.tensor([0.229, 0.224, 0.225]) #預處理函數(shù)實現(xiàn)對圖片進行resize并對三個通道進行標準化 def preprocess(img,image_shape):transforms = torchvision.transforms.Compose([torchvision.transforms.Resize(image_shape),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize(mean=rgb_mean,std=rgb_std)])return transforms(img).unsqueeze(0) #后處理函數(shù) def postprocess(img):img = img[0].to(rgb_std.device)#Tensor將輸入input張量每個元素的夾緊到區(qū)間 [min,max][min,max],并返回結(jié)果到一個新張量。#這里就是把img返回到標準化前的值并固定到0-1之間便于打印圖片img = torch.clamp(img.permute(1,2,0)* rgb_std + rgb_mean,0,1)return torchvision.transforms.ToPILImage()(img.permute(2,0,1))8.3 抽取圖像的特征?
?????????使用基于ImageNet數(shù)據(jù)集預訓練的VGG-19模型來抽取圖像特征。這里將0,5,10,19,28層的輸出作為風格特征(選擇不同層的輸出來匹配局部和全局的風格),將25層的輸出作為內(nèi)容特征(為了避免合成圖像過多保留內(nèi)容圖像的細節(jié))。
pretrained_net = torchvision.models.vgg19(pretrained=True)style_layers, content_layers = [0, 5, 10, 19, 28], [25]net = nn.Sequential(*[pretrained_net.features[i] for i inrange(max(content_layers + style_layers) + 1)])????????給定輸入X,如果我們簡單地調(diào)用前向傳播net(X),只能獲得最后一層的輸出。 由于我們還需要中間層的輸出,因此這里我們逐層計算,并保留內(nèi)容層和風格層的輸出。?
def extract_features(X, content_layers, style_layers):contents = []styles = []for i in range(len(net)):X = net[i](X)if i in style_layers:styles.append(X)if i in content_layers:contents.append(X)return contents, styles????????get_contents函數(shù)對內(nèi)容圖像抽取內(nèi)容特征;?get_styles函數(shù)對風格圖像抽取風格特征。 因為在訓練時無須改變預訓練的VGG的模型參數(shù),所以可以在訓練開始之前就提取出內(nèi)容特征和風格特征。?
def get_contents(image_shape, device):content_X = preprocess(content_img, image_shape).to(device)contents_Y, _ = extract_features(content_X, content_layers, style_layers)return content_X, contents_Ydef get_styles(image_shape, device):style_X = preprocess(style_img, image_shape).to(device)_, styles_Y = extract_features(style_X, content_layers, style_layers)return style_X, styles_Y8.3 定義損失函數(shù)?
????????由內(nèi)容損失、風格損失和全變分損失3部分組成。這里全變分損失是由于學到的合成圖像里面有大量高頻噪點,即有特別亮或者特別暗的顆粒像素。 能夠盡可能使鄰近的像素值相似。假設(shè)xi,j表示坐標(i,j)處的像素值。全變分損失表示為:
#內(nèi)容損失 def content_loss(Y_hat, Y):# 我們從動態(tài)計算梯度的樹中分離目標:# 這是一個規(guī)定的值,而不是一個變量。return torch.square(Y_hat - Y.detach()).mean()#風格損失 def gram(X):num_channels, n = X.shape[1], X.numel() // X.shape[1]X = X.reshape((num_channels, n))return torch.matmul(X, X.T) / (num_channels * n) def style_loss(Y_hat, gram_Y):return torch.square(gram(Y_hat) - gram_Y.detach()).mean() #全變分損失 def tv_loss(Y_hat):return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())????????損失函數(shù)是內(nèi)容損失、風格損失和總變化損失的加權(quán)和,調(diào)節(jié)這些權(quán)重超參數(shù),我們可以權(quán)衡合成圖像在保留內(nèi)容、遷移風格以及去噪三方面的相對重要性。?
content_weight,style_weight,tv_weight = 1,1e4,10def compute_loss(X,contents_Y_hat,styles_Y_hat,contents_Y,styles_Y_gram):contents_l = [content_loss(Y_hat,Y)*content_weight for Y_hat, Y in zip(contents_Y_hat, contents_Y)]styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(styles_Y_hat, styles_Y_gram)]tv_l = tv_loss(X) * tv_weightl = sum(10*styles_l + contents_l + [tv_l])return contents_l, styles_l, tv_l, l8.4 初始化合成圖像?
? ? ? ? 訓練期間唯一需要更新的變量就是合成的圖像,定義一個簡單的模型SynthesizedImage,并將合成的圖像視為模型參數(shù)。模型的前向傳播只需返回模型參數(shù)即可。
class SynthesizedImage(nn.Module):def __init__(self, img_shape, **kwargs):super(SynthesizedImage, self).__init__(**kwargs)self.weight = nn.Parameter(torch.rand(*img_shape))def forward(self):return self.weight????????定義get_inits函數(shù)創(chuàng)建了合成圖像的模型實例,并將其初始化為圖像X,并提前計算好風格圖像在各個風格層的格拉姆矩陣。
def get_inits(X, device, lr, styles_Y):gen_img = SynthesizedImage(X.shape).to(device)gen_img.weight.data.copy_(X.data)trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)#風格圖像在各個風格層的格拉姆矩陣styles_Y_gram = [gram(Y) for Y in styles_Y]return gen_img(), styles_Y_gram, trainer8.5 訓練?
def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)animator = d2l.Animator(xlabel='epoch', ylabel='loss',xlim=[10, num_epochs],legend=['content', 'style', 'TV'],ncols=2, figsize=(7, 2.5))for epoch in range(num_epochs):trainer.zero_grad()contents_Y_hat, styles_Y_hat = extract_features(X, content_layers, style_layers)contents_l, styles_l, tv_l, l = compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)l.backward()trainer.step()scheduler.step()if (epoch + 1) % 10 == 0:animator.axes[1].imshow(postprocess(X))animator.add(epoch + 1, [float(sum(contents_l)),float(sum(styles_l)), float(tv_l)])return Xdevice, image_shape = d2l.try_gpu(), (300, 450) net = net.to(device) content_X, contents_Y = get_contents(image_shape, device) _, styles_Y = get_styles(image_shape, device) output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)? ? ? ? 運行結(jié)果如下,前面兩張圖分別是輸入的內(nèi)容圖片和風格樣式圖片,最后一張圖為格遷移后的合成圖像。
? ? ? ? ?
????????這里我將第一次的輸出結(jié)果又作為輸入再次進行了風格遷移,運行結(jié)果如下所示,這里可以看出輸出的圖像跟style_image更加接近。?
總結(jié)
- 上一篇: Java中CAS操作本身怎么保证原子性及
- 下一篇: windows 7系统安装虚拟机及在虚拟