深度学习入门(五十二)计算机视觉——风格迁移
深度學習入門(五十二)計算機視覺——風格遷移
- 前言
- 計算機視覺——風格遷移
- 課件
- 樣式遷移
- 易于CNN的樣式遷移
- 教材
- 1 方法
- 2 閱讀內容和風格圖像
- 3 預處理和后處理
- 4 抽取圖像特征
- 5 定義損失函數
- 5.1 內容損失
- 5.2 風格損失
- 5.3 全變分損失
- 5.4 損失函數
- 6 初始化合成圖像
- 7 訓練模型
- 8 小結
- 參考文獻
前言
核心內容來自博客鏈接1博客連接2希望大家多多支持作者
本文記錄用,防止遺忘
計算機視覺——風格遷移
課件
樣式遷移
將樣式圖片中的樣式遷移到內容圖片上,得到合成圖片
易于CNN的樣式遷移
奠基性工作:
教材
如果你是一位攝影愛好者,你也許接觸過濾波器。它能改變照片的顏色風格,從而使風景照更加銳利或者令人像更加美白。但一個濾波器通常只能改變照片的某個方面。如果要照片達到理想中的風格,你可能需要嘗試大量不同的組合。這個過程的復雜程度不亞于模型調參。
在本節中,我們將介紹如何使用卷積神經網絡,自動將一個圖像中的風格應用在另一圖像之上,即風格遷移(style transfer) 。 這里我們需要兩張輸入圖像:一張是內容圖像,另一張是風格圖像。 我們將使用神經網絡修改內容圖像,使其在風格上接近風格圖像。
1 方法
下圖用簡單的例子闡述了基于卷積神經網絡的風格遷移方法。 首先,我們初始化合成圖像,例如將其初始化為內容圖像。 該合成圖像是風格遷移過程中唯一需要更新的變量,即風格遷移所需迭代的模型參數。 然后,我們選擇一個預訓練的卷積神經網絡來抽取圖像的特征,其中的模型參數在訓練中無須更新。 這個深度卷積神經網絡憑借多個層逐級抽取圖像的特征,我們可以選擇其中某些層的輸出作為內容特征或風格特征。 以下圖例,這里選取的預訓練的神經網絡含有3個卷積層,其中第二層輸出內容特征,第一層和第三層輸出風格特征。
接下來,我們通過前向傳播(實線箭頭方向)計算風格遷移的損失函數,并通過反向傳播(虛線箭頭方向)迭代模型參數,即不斷更新合成圖像。 風格遷移常用的損失函數由3部分組成:
(i)內容損失使合成圖像與內容圖像在內容特征上接近;
(ii)風格損失使合成圖像與風格圖像在風格特征上接近;
(iii)全變分損失則有助于減少合成圖像中的噪點。 最后,當模型訓練結束時,我們輸出風格遷移的模型參數,即得到最終的合成圖像。
在下面,我們將通過代碼來進一步了解風格遷移的技術細節。
2 閱讀內容和風格圖像
首先,我們讀取內容和風格圖像。 從打印出的圖像坐標軸可以看出,它們的尺寸并不一樣。
%matplotlib inline import torch import torchvision from torch import nn from d2l import torch as d2ld2l.set_figsize() content_img = d2l.Image.open('../img/rainier.jpg') d2l.plt.imshow(content_img);輸出:
輸出:
3 預處理和后處理
下面,定義圖像的預處理函數和后處理函數。 預處理函數preprocess對輸入圖像在RGB三個通道分別做標準化,并將結果變換成卷積神經網絡接受的輸入格式。 后處理函數postprocess則將輸出圖像中的像素值還原回標準化之前的值。 由于圖像打印函數要求每個像素的浮點數值在0到1之間,我們對小于0和大于1的值分別取0和1。
torchvision.transforms模塊有大量現成的轉換方法,不過需要注意的是有的方法輸入的是PIL圖像,如Resize;有的方法輸入的是tensor,如Normalize;而還有的是用于二者轉換,如ToTensor將PIL圖像轉換成tensor。一定要注意這點
rgb_mean = torch.tensor([0.485, 0.456, 0.406]) rgb_std = torch.tensor([0.229, 0.224, 0.225])def preprocess(img, image_shape):transforms = torchvision.transforms.Compose([torchvision.transforms.Resize(image_shape),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])return transforms(img).unsqueeze(0)def postprocess(img):img = img[0].to(rgb_std.device)img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1)return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))4 抽取圖像特征
我們使用基于ImageNet數據集預訓練的VGG-19模型來抽取圖像特征
PyTorch官方在torchvision.models模塊提供了一些常見的預訓練好的計算機視覺模型,包括圖片分類、語義分割、目標檢測、實例分割、人關鍵點檢測和視頻分類等等。
pretrained_net = torchvision.models.vgg19(pretrained=True)第一次執行上述代碼會把預訓練好的模型參數下載到環境變量TORCH_HOME指定的位置,如果沒有該環境變量的話默認位置是.cache/torch。
為了抽取圖像的內容特征和風格特征,我們可以選擇VGG網絡中某些層的輸出。 一般來說,越靠近輸入層,越容易抽取圖像的細節信息;反之,則越容易抽取圖像的全局信息。 為了避免合成圖像過多保留內容圖像的細節,我們選擇VGG較靠近輸出的層,即內容層,來輸出圖像的內容特征。 我們還從VGG中選擇不同層的輸出來匹配局部和全局的風格,這些圖層也稱為風格層。 正如VGG一節中所介紹的,VGG網絡使用了5個卷積塊。 實驗中,我們選擇第四卷積塊的最后一個卷積層作為內容層,選擇每個卷積塊的第一個卷積層作為風格層。 這些層的索引可以通過打印pretrained_net實例獲取。
style_layers, content_layers = [0, 5, 10, 19, 28], [25]使用VGG層抽取特征時,我們只需要用到從輸入層到最靠近輸出層的內容層或風格層之間的所有層。 下面構建一個新的網絡net,它只保留需要用到的VGG的所有層。
net = nn.Sequential(*[pretrained_net.features[i] for i inrange(max(content_layers + style_layers) + 1)])給定輸入X,如果我們簡單地調用前向傳播net(X),只能獲得最后一層的輸出。 由于我們還需要中間層的輸出,因此這里我們逐層計算,并保留內容層和風格層的輸出。
def extract_features(X, content_layers, style_layers):contents = []styles = []for i in range(len(net)):X = net[i](X)if i in style_layers:styles.append(X)if i in content_layers:contents.append(X)return contents, styles下面定義兩個函數:get_contents函數對內容圖像抽取內容特征; get_styles函數對風格圖像抽取風格特征。 因為在訓練時無須改變預訓練的VGG的模型參數,所以我們可以在訓練開始之前就提取出內容特征和風格特征。 由于合成圖像是風格遷移所需迭代的模型參數,我們只能在訓練過程中通過調用extract_features函數來抽取合成圖像的內容特征和風格特征。
def get_contents(image_shape, device):content_X = preprocess(content_img, image_shape).to(device)contents_Y, _ = extract_features(content_X, content_layers, style_layers)return content_X, contents_Ydef get_styles(image_shape, device):style_X = preprocess(style_img, image_shape).to(device)_, styles_Y = extract_features(style_X, content_layers, style_layers)return style_X, styles_Y5 定義損失函數
下面我們來描述風格遷移的損失函數。 它由內容損失、風格損失和全變分損失3部分組成。
5.1 內容損失
與線性回歸中的損失函數類似,內容損失通過平方誤差函數衡量合成圖像與內容圖像在內容特征上的差異。 平方誤差函數的兩個輸入均為extract_features函數計算所得到的內容層的輸出。
def content_loss(Y_hat, Y):# 我們從動態計算梯度的樹中分離目標:# 這是一個規定的值,而不是一個變量。return torch.square(Y_hat - Y.detach()).mean()5.2 風格損失
風格損失與內容損失類似,也通過平方誤差函數衡量合成圖像與風格圖像在風格上的差異。 為了表達風格層輸出的風格,我們先通過extract_features函數計算風格層的輸出。 假設該輸出的樣本數為1,通道數為ccc,高和寬分別為hhh和www,我們可以將此輸出轉換為矩陣X\mathbf{X}X,其有ccc行和hwhwhw列。 這個矩陣可以被看作是由ccc個長度為hwhwhw的向量x1,…,xc\mathbf{x}_1, \ldots, \mathbf{x}_cx1?,…,xc?組合而成的。其中向量xi\mathbf{x}_ixi?代表了通道iii上的風格特征。
在這些向量的格拉姆矩陣XX?∈Rc×c\mathbf{X}\mathbf{X}^\top \in \mathbb{R}^{c \times c}XX?∈Rc×c中,iii行jjj列的元素xijx_{ij}xij?即向量xi\mathbf{x}_ixi?和xj\mathbf{x}_jxj?的內積。它表達了通道iii和通道jjj上風格特征的相關性。我們用這樣的格拉姆矩陣來表達風格層輸出的風格。 需要注意的是,當hwhwhw的值較大時,格拉姆矩陣中的元素容易出現較大的值。 此外,格拉姆矩陣的高和寬皆為通道數ccc。 為了讓風格損失不受這些值的大小影響,下面定義的gram函數將格拉姆矩陣除以了矩陣中元素的個數,即chwchwchw。
def gram(X):num_channels, n = X.shape[1], X.numel() // X.shape[1]X = X.reshape((num_channels, n))return torch.matmul(X, X.T) / (num_channels * n)自然地,風格損失的平方誤差函數的兩個格拉姆矩陣輸入分別基于合成圖像與風格圖像的風格層輸出。這里假設基于風格圖像的格拉姆矩陣gram_Y已經預先計算好了。
def style_loss(Y_hat, gram_Y):return torch.square(gram(Y_hat) - gram_Y.detach()).mean()5.3 全變分損失
有時候,我們學到的合成圖像里面有大量高頻噪點,即有特別亮或者特別暗的顆粒像素。 一種常見的去噪方法是全變分去噪(total variation denoising): 假設xi,jx_{i, j}xi,j?表示坐標(i,j)(i, j)(i,j)處的像素值,降低全變分損失
∑i,j∣xi,j?xi+1,j∣+∣xi,j?xi,j+1∣\sum_{i, j} \left|x_{i, j} - x_{i+1, j}\right| + \left|x_{i, j} - x_{i, j+1}\right|i,j∑?∣xi,j??xi+1,j?∣+∣xi,j??xi,j+1?∣
能夠盡可能使鄰近的像素值相似。
def tv_loss(Y_hat):return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())5.4 損失函數
風格轉移的損失函數是內容損失、風格損失和總變化損失的加權和。 通過調節這些權重超參數,我們可以權衡合成圖像在保留內容、遷移風格以及去噪三方面的相對重要性。
content_weight, style_weight, tv_weight = 1, 1e3, 10def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):# 分別計算內容損失、風格損失和全變分損失contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(contents_Y_hat, contents_Y)]styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(styles_Y_hat, styles_Y_gram)]tv_l = tv_loss(X) * tv_weight# 對所有損失求和l = sum(10 * styles_l + contents_l + [tv_l])return contents_l, styles_l, tv_l, l6 初始化合成圖像
在風格遷移中,合成的圖像是訓練期間唯一需要更新的變量。因此,我們可以定義一個簡單的模型SynthesizedImage,并將合成的圖像視為模型參數。模型的前向傳播只需返回模型參數即可。
class SynthesizedImage(nn.Module):def __init__(self, img_shape, **kwargs):super(SynthesizedImage, self).__init__(**kwargs)self.weight = nn.Parameter(torch.rand(*img_shape))def forward(self):return self.weight下面,我們定義get_inits函數。該函數創建了合成圖像的模型實例,并將其初始化為圖像X。風格圖像在各個風格層的格拉姆矩陣styles_Y_gram將在訓練前預先計算好
def get_inits(X, device, lr, styles_Y):gen_img = SynthesizedImage(X.shape).to(device)gen_img.weight.data.copy_(X.data)trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)styles_Y_gram = [gram(Y) for Y in styles_Y]return gen_img(), styles_Y_gram, trainer7 訓練模型
在訓練模型進行風格遷移時,我們不斷抽取合成圖像的內容特征和風格特征,然后計算損失函數。下面定義了訓練循環
def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)animator = d2l.Animator(xlabel='epoch', ylabel='loss',xlim=[10, num_epochs],legend=['content', 'style', 'TV'],ncols=2, figsize=(7, 2.5))for epoch in range(num_epochs):trainer.zero_grad()contents_Y_hat, styles_Y_hat = extract_features(X, content_layers, style_layers)contents_l, styles_l, tv_l, l = compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)l.backward()trainer.step()scheduler.step()if (epoch + 1) % 10 == 0:animator.axes[1].imshow(postprocess(X))animator.add(epoch + 1, [float(sum(contents_l)),float(sum(styles_l)), float(tv_l)])return X現在我們訓練模型: 首先將內容圖像和風格圖像的高和寬分別調整為300和450像素,用內容圖像來初始化合成圖像。
device, image_shape = d2l.try_gpu(), (300, 450) net = net.to(device) content_X, contents_Y = get_contents(image_shape, device) _, styles_Y = get_styles(image_shape, device) output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)輸出:
我們可以看到,合成圖像保留了內容圖像的風景和物體,并同時遷移了風格圖像的色彩。例如,合成圖像具有與風格圖像中一樣的色彩塊,其中一些甚至具有畫筆筆觸的細微紋理。
8 小結
-
風格遷移常用的損失函數由3部分組成:(i)內容損失使合成圖像與內容圖像在內容特征上接近;(ii)風格損失令合成圖像與風格圖像在風格特征上接近;(iii)全變分損失則有助于減少合成圖像中的噪點。
-
我們可以通過預訓練的卷積神經網絡來抽取圖像的特征,并通過最小化損失函數來不斷更新合成圖像來作為模型參數。
-
我們使用格拉姆矩陣表達風格層輸出的風格。
參考文獻
[1] Gatys, L. A., Ecker, A. S., & Bethge, M. (2016). Image style transfer using convolutional neural networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 2414-2423).
總結
以上是生活随笔為你收集整理的深度学习入门(五十二)计算机视觉——风格迁移的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: LightningChart解决方案:X
- 下一篇: 聚N-乙烯基乙酰胺接枝丙烯腈/苯乙烯聚合