风格迁移(Style Transfer)首次学习总结
0、寫在前面
最近看了吳恩達老師風格遷移相關的講解視頻,深受啟發,于是想著做做總結。
1、主要思想
目的:把一張內容圖片(content image)的風格遷移成與另一張圖片(style image)風格一致。
(圖自論文:A Neural Algorithm of Artistic Style)
?方法:通過約束 Content Loss 和 Style Loss 來生成最終的圖片。
1.0 activation(representation)、kernel(filter)、channel 和 input
用一個已經 pretrained 好的網絡(如 Resnet-50)作為 backbone 來提取圖片每一層的特征。
每一個 filter 用來檢測該層 input 的某一種特征,如果有這種特征,那么輸出(activation)中對應 channel 就會被“點亮”(數值大)。
比如:假設這個 pretrained 好的網絡中第一層有一個 filter 用來檢測圖像中處于水平狀態的邊緣,那么,如果圖片(input)左上角有一些水平的邊緣,那這個圖片在該層的輸出(activation)中對應 channel 左上角的數值就會比較大。
?1.1 Content Loss
要保證 Style Transferred Image 和原 Content Image 的內容盡可能相似【即,原 Content Image 左上角有處于水平狀態的邊緣,那么 Style Transferred Image 左上角也要有處于水平狀態的邊緣】,就意味著?Content Image 和 Style?Transferred Image 經過同一個 pretrained 好的網絡后,其對應層的輸出(activation)要盡可能一致。
比如 ,Content Image 左上角有一些水平的邊緣,則 activation 中 channel i 的左上角數值就會比較大,那么,我們也希望?Style?Transferred Image 的 activation 中 channel i 的左上角數值也比較大(盡可能接近)。
所以,Content Loss 定義如下:
?(圖自論文:A Neural Algorithm of Artistic Style)
公式中的?representation 就是 activation。
?
1.2 Style Loss
論文中關于 Style 的定義如下:
we built a style representation that computes the correlations between the different filter responses, where the expectation is taken over the spatial extend of the input imag
一張圖片的 style 可以定義為某一層的 activation 里 channel 與 channel 之間的 correlation 矩陣。
比如:某張 Style 圖片里左上角部分全是紅色水平邊緣的元素。
那么檢測水平邊緣特征的 filter 得到的 channel i 和檢測紅色特征的 filter 得到的 channel j 高亮(數值大)的地方就都會在左上角,那么,這兩個 channel 對應位置相乘得到的數值就會比較大(10 * 10 = 100)。
假如此時還有一個檢測藍色特征的 filter,那么其對應得到的 channel k 左上角部分就不怎么會被點亮(數值小)【因為左上角部分全是紅色水平邊緣的元素】。那么,檢測水平邊緣特征的 filter 得到的 channel i 和?檢測藍色特征的 filter 得到的 channel k 對應位置的乘積就可能會比較小(10 * 0.5?= 5)。
那么,一張圖片的 style 矩陣定義如下:
?(圖自吳恩達老師的課程 ppt)
其中,k 和 k' 代表兩個不同的 channel;l 是指第 l 層。
那么要保證?Style?Transferred Image 和 Style Image 的風格相近,也就是讓兩張圖片的風格矩陣盡可能相似。所以 Style Loss 定義如下:
?(圖自論文:A Neural Algorithm of Artistic Style)
2、示例代碼
我在 Github 上找了一個能跑得通的示例代碼:https://github.com/Zhenye-Na/neural-style-pytorch
其中的核心代碼如下:
Content Loss & Style Loss & Style Matrix
class ContentLoss(nn.Module):"""Content Loss."""def __init__(self, target,):"""Initialize content loss"""super(ContentLoss, self).__init__()# we 'detach' the target content from the tree used# to dynamically compute the gradient: this is a stated value,# not a variable. Otherwise the forward method of the criterion# will throw an error.self.target = target.detach()def forward(self, inputs):"""Forward pass."""self.loss = F.mse_loss(inputs, self.target)return inputsclass StyleLoss(nn.Module):"""Style Loss."""def __init__(self, target_feature):"""Initialize style loss."""super(StyleLoss, self).__init__()self.target = gram_matrix(target_feature).detach()def forward(self, inputs):"""Forward pass."""G = gram_matrix(inputs)self.loss = F.mse_loss(G, self.target)return inputsdef gram_matrix(inputs):"""Gram matrix."""a, b, c, d = inputs.size()# resise F_XL into \hat F_XLfeatures = inputs.view(a * b, c * d)# compute the gram productG = torch.mm(features, features.t())return G.div(a * b * c * d)train
for epoch in range(0, self.args.epochs):def closure():# correct the values of updated input imageinput_img.data.clamp_(0, 1)self.optimizer.zero_grad()model(input_img)style_score = 0content_score = 0for sl in style_losses:style_score += sl.lossfor cl in content_losses:content_score += cl.lossstyle_score *= self.style_weightcontent_score *= self.content_weightloss = style_score + content_scoreloss.backward()if epoch % 5 == 0:print("Epoch {}: Style Loss : {:4f} Content Loss: {:4f}".format(epoch, style_score.item(), content_score.item()))return style_score + content_scoreself.optimizer.step(closure)優化器
def _init_optimizer(self, input_img):"""Initialize LBFGS optimizer."""self.optimizer = optim.LBFGS([input_img.requires_grad_()])注意!這里只更新 input image,網絡是不進行學習的!
還有,這是一個我之前沒用過的優化器:LBFGS,其中有一個參數:
max_iter (int): maximal number of iterations per optimization step (default: 20)
這也就是為什么訓練結果里每一個 epoch 更新會有二十次打印信息(iteration)了,之前一直想不通,我還找半天代碼里哪里有 20 這個數字。。。
運行結果
?
總結
以上是生活随笔為你收集整理的风格迁移(Style Transfer)首次学习总结的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 【文献阅读】ChangeNet——变化检
- 下一篇: Nature子刊:中大骆观正组在RNA修