单样本学习与孪生网络
@miracle 在 單樣本學習(One shot learning)和孿生網絡(Siamese Network) 中說:
孿生網絡與偽孿生網絡
Siamese network就是“連體的神經網絡”,神經網絡的“連體”是通過共享權值來實現的,如下圖所示。共享權值意味著兩邊的網絡權重矩陣一模一樣,甚至可以是同一個網絡。
如果左右兩邊不共享權值,而是兩個不同的神經網絡,叫偽孿生網絡(pseudo-siamese network,偽孿生神經網絡),對于pseudo-siamese network,兩邊可以是不同的神經網絡(如一個是lstm,一個是cnn),也可以是相同類型的神經網絡。
原理
衡量兩個輸入的相似程度,輸出是一個[0,1]的浮點數,表示二者的相似程度。孿生神經網絡有兩個輸入(Input1 and Input2),將兩個輸入feed進入兩個神經網絡(Network1 and Network2),這兩個神經網絡分別將輸入映射到新的空間,**形成輸入在新的空間中的表示。**通過Loss的計算,評價兩個輸入的相似度。
孿生神經網絡和偽孿生神經網絡分別適用的場景
先上結論:孿生神經網絡用于處理兩個輸入**“比較類似"的情況。偽孿生神經網絡適用于處理兩個輸入"有一定差別”**的情況。比如,我們要計算兩個句子或者詞匯的語義相似度,使用siamese network比較適合;如果驗證標題與正文的描述是否一致(標題和正文長度差別很大),或者文字是否描述了一幅圖片(一個是圖片,一個是文字),就應該使用pseudo-siamese network。也就是說,要根據具體的應用,判斷應該使用哪一種結構,哪一種Loss。
用途
- 前面提到的詞匯的語義相似度分析,QA中question和answer的匹配,簽名/人臉驗證。
- 手寫體識別也可以用siamese network,網上已有github代碼。
- 還有kaggle上Quora的question pair的比賽,即判斷兩個提問是不是同一問題,冠軍隊伍用的就是n多特征+Siamese network,知乎團隊也可以拿這個模型去把玩一下。
- 在圖像上,基于Siamese網絡的視覺跟蹤算法也已經成為熱點《Fully-convolutional siamese networks for object tracking》。
- 單樣本學習
單樣本學習
定義問題
我們的模型只獲得了很少的標記的訓練樣本S,它有N個樣本,每個相同維度的向量有一個對應的標簽y
S={(x1,y1),…,(xN,yN)}S = \{(x_1,y_1), …, (x_N,y_N)\} S={(x1?,y1?),…,(xN?,yN?)}
再給出一個待分類的測試樣例 。因為樣本集中每個樣本都有一個正確的類別,我們的目標是正確的預測 中哪一個是 的正確標簽 。
一個單樣本學習的baseline–1近鄰
最簡單的分類方式是使用k近鄰方法,但是因為每個類別只有一個樣本,所以我們需要用1近鄰。這很簡單,只需要計算測試樣本與訓練集中每個樣本的的歐式距離,然后選擇最近的一個就可以了:
C(x^)=argmin?c∈S∣∣x^?xc∣∣C(\hat{x}) = \underset{c \in S}{\operatorname{argmin}} || \hat{x} - x_c || C(x^)=c∈Sargmin?∣∣x^?xc?∣∣
根據Koch等人的論文,在omniglot數據集中的20類上,單樣本分類,1-nn可以得到大約28%的精度,28%看起來很差,但是它已經是隨機猜測(5%)的6倍精度了。這是一個單樣本學習算法最好的baseline或者“合理性測試”了。
網絡架構
Koch等人使用卷積孿生網絡去分類成對的omniglot圖像,所以這兩個孿生網絡都是卷積神經網絡。這兩個孿生網絡每個的架構如下:64通道的10×10卷積核,relu->max pool->128通道的7×7卷積核,relu->max pool->128通道的4×4卷積核,relu->max pool->256通道的4×4卷積核。孿生網絡把輸入降低到越來越小的3d張量上,最終它們經過一個4096神經元的全連接層。兩個向量的絕對差作為線性分類器的輸入。這個網絡一共有38,951,745個參數–96%的參數屬于全連接層。這個參數量很大,所以網絡有很高的過擬合風風險,但是成對的訓練意味著數據集是很大的,所以過擬合問題不成出現。
輸出被歸一化到[0,1]之間,使用sigmoid函數讓它成為一個概率。當兩個圖像是相同類別的時候,我們使目標_t_=1,類別不相同的時候使_t_=0。它使用邏輯斯特回歸來訓練。這意味著損失函數應該是預測和目標之間的二分類交叉熵。損失函數中還有一個L2權重衰減項,以讓網絡可以學習更小的\更平滑的權重,從而提高泛化能力:
當網絡做單樣本學習的時候,孿生網絡簡單的分類一下測試圖像與訓練集中的圖像中哪個最相似就可以了:
這里使用argmax而不是近鄰方法中的argmin,因為類別越不同,L2度量的值越高,但是這個模型的輸出 ,所以我們要這個值最大。這個方法有一個明顯的缺陷:對于訓練集中的 ,概率 與訓練集中每個樣本都是獨立的!這意味著概率值的和不為1。言歸正傳,測試圖像與訓練圖像應該是相同類型的。。。
觀察一下:逐對訓練的有效的數據集大小
經過與以為UoA大學的博士討論后發現,我認為這個是過分夸大的,或者就是錯的。憑經驗來說,我的實現沒有過擬合,即使它沒有在每個可能的成對圖像上充分訓練,這與該節是沖突的。在有錯就說思想的指引下,我會保留這個問題。
我注意到,采用逐對訓練的話,將會有平方級別對的圖像對來訓練模型,這讓模型很難過擬合,好酷。假設我們有_E_類,每類有_C_個樣本。一共有 張圖片,總共可能的配方數量可以這樣計算:
對于omniglot中的964類(每類20個樣本),這會有185,849,560個可能的配對,這是巨大的!然而,孿生網絡需要相同類的和不同類的配對都有。每類_E_個訓練樣本,所以每個類別有 對,這意味著這里有 個相同類別的配對。–對于Omniglot有183,160對。及時183,160對已經很大了,但他只是所有可能配對的千分之一,因為相同類別的配對數量隨著E平方級的增大,但是隨著C是線性增加。這個問題非常重要,因為孿生網絡訓練的時候,同類別和不同類別的比例應該是1:1.–或許它表明逐對訓練在那種每個類別有更多樣本的數據集上更容易訓練。
代碼
如果你更喜歡用jupyter notebook?這里是傳送門
下面是模型定義,如果你見過keras,那很容易理解。我只用Sequential()來定義一次孿生網絡,然后使用兩個輸入層來調用它,這樣兩個輸入使用相同的參數。然后我們把它們使用絕對距離合并起來,添加一個輸出層,使用二分類交叉熵損失來編譯這個模型。
from keras.layers import Input, Conv2D, Lambda, merge, Dense, Flatten,MaxPooling2D from keras.models import Model, Sequential from keras.regularizers import l2 from keras import backend as K from keras.optimizers import SGD,Adam from keras.losses import binary_crossentropy import numpy.random as rng import numpy as np import os import dill as pickle import matplotlib.pyplot as plt from sklearn.utils import shuffledef W_init(shape,name=None):"""Initialize weights as in paper"""values = rng.normal(loc=0,scale=1e-2,size=shape)return K.variable(values,name=name) #//TODO: figure out how to initialize layer biases in keras. def b_init(shape,name=None):"""Initialize bias as in paper"""values=rng.normal(loc=0.5,scale=1e-2,size=shape)return K.variable(values,name=name)input_shape = (105, 105, 1) left_input = Input(input_shape) right_input = Input(input_shape) #build convnet to use in each siamese 'leg' convnet = Sequential() convnet.add(Conv2D(64,(10,10),activation='relu',input_shape=input_shape,kernel_initializer=W_init,kernel_regularizer=l2(2e-4))) convnet.add(MaxPooling2D()) convnet.add(Conv2D(128,(7,7),activation='relu',kernel_regularizer=l2(2e-4),kernel_initializer=W_init,bias_initializer=b_init)) convnet.add(MaxPooling2D()) convnet.add(Conv2D(128,(4,4),activation='relu',kernel_initializer=W_init,kernel_regularizer=l2(2e-4),bias_initializer=b_init)) convnet.add(MaxPooling2D()) convnet.add(Conv2D(256,(4,4),activation='relu',kernel_initializer=W_init,kernel_regularizer=l2(2e-4),bias_initializer=b_init)) convnet.add(Flatten()) convnet.add(Dense(4096,activation="sigmoid",kernel_regularizer=l2(1e-3),kernel_initializer=W_init,bias_initializer=b_init)) #encode each of the two inputs into a vector with the convnet encoded_l = convnet(left_input) encoded_r = convnet(right_input) #merge two encoded inputs with the l1 distance between them L1_distance = lambda x: K.abs(x[0]-x[1]) both = merge([encoded_l,encoded_r], mode = L1_distance, output_shape=lambda x: x[0]) prediction = Dense(1,activation='sigmoid',bias_initializer=b_init)(both) siamese_net = Model(input=[left_input,right_input],output=prediction) #optimizer = SGD(0.0004,momentum=0.6,nesterov=True,decay=0.0003)optimizer = Adam(0.00006) #//TODO: get layerwise learning rates and momentum annealing scheme described in paperworking siamese_net.compile(loss="binary_crossentropy",optimizer=optimizer)siamese_net.count_params()原論文中每個層的學習率和沖量都不相同–我跳過了這個步驟,因為使用keras來實現這個太麻煩了,并且超參數不是該論文的重點。Koch等人增加向訓練集中增加失真的圖像,使用150,000對樣本訓練模型。因為這個太大了,我的內存放不下,所以我決定使用隨機采樣的方法。載入圖像對或許是這個模型最難實現的部分。因為這里每個類別有20個樣本,我把數據重新調整為N_classes×20×105×105的數組,這樣可以很方便的來索引。
class Siamese_Loader:"""For loading batches and testing tasks to a siamese net"""def __init__(self,Xtrain,Xval):self.Xval = Xvalself.Xtrain = Xtrainself.n_classes,self.n_examples,self.w,self.h = Xtrain.shapeself.n_val,self.n_ex_val,_,_ = Xval.shapedef get_batch(self,n):"""Create batch of n pairs, half same class, half different class"""categories = rng.choice(self.n_classes,size=(n,),replace=False)pairs=[np.zeros((n, self.h, self.w,1)) for i in range(2)]targets=np.zeros((n,))targets[n//2:] = 1for i in range(n):category = categories[i]idx_1 = rng.randint(0,self.n_examples)pairs[0][i,:,:,:] = self.Xtrain[category,idx_1].reshape(self.w,self.h,1)idx_2 = rng.randint(0,self.n_examples)#pick images of same class for 1st half, different for 2ndcategory_2 = category if i >= n//2 else (category + rng.randint(1,self.n_classes)) % self.n_classespairs[1][i,:,:,:] = self.Xtrain[category_2,idx_2].reshape(self.w,self.h,1)return pairs, targetsdef make_oneshot_task(self,N):"""Create pairs of test image, support set for testing N way one-shot learning. """categories = rng.choice(self.n_val,size=(N,),replace=False)indices = rng.randint(0,self.n_ex_val,size=(N,))true_category = categories[0]ex1, ex2 = rng.choice(self.n_examples,replace=False,size=(2,))test_image = np.asarray([self.Xval[true_category,ex1,:,:]]*N).reshape(N,self.w,self.h,1)support_set = self.Xval[categories,indices,:,:]support_set[0,:,:] = self.Xval[true_category,ex2]support_set = support_set.reshape(N,self.w,self.h,1)pairs = [test_image,support_set]targets = np.zeros((N,))targets[0] = 1return pairs, targetsdef test_oneshot(self,model,N,k,verbose=0):"""Test average N way oneshot learning accuracy of a siamese neural net over k one-shot tasks"""passn_correct = 0if verbose:print("Evaluating model on {} unique {} way one-shot learning tasks ...".format(k,N))for i in range(k):inputs, targets = self.make_oneshot_task(N)probs = model.predict(inputs)if np.argmax(probs) == 0:n_correct+=1percent_correct = (100.0*n_correct / k)if verbose:print("Got an average of {}% {} way one-shot learning accuracy".format(percent_correct,N))return percent_correct下面是訓練過程了。沒什么特別的,除了我監測的是驗證機精度來測試性能,而不是驗證集上的損失。
evaluate_every = 7000 loss_every=300 batch_size = 32 N_way = 20 n_val = 550 siamese_net.load_weights("PATH") best = 76.0 for i in range(900000):(inputs,targets)=loader.get_batch(batch_size)loss=siamese_net.train_on_batch(inputs,targets)if i % evaluate_every == 0:val_acc = loader.test_oneshot(siamese_net,N_way,n_val,verbose=True)if val_acc >= best:print("saving")siamese_net.save('PATH')best=val_accif i % loss_every == 0:print("iteration {}, training loss: {:.2f},".format(i,loss))引用
總結
以上是生活随笔為你收集整理的单样本学习与孪生网络的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: WGAN-GP
- 下一篇: object detection