MoCo论文中的Algorithm 1伪代码解读
生活随笔
收集整理的這篇文章主要介紹了
MoCo论文中的Algorithm 1伪代码解读
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
具體解讀了什么東西
論文中提供的偽代碼大約如下:
下面我將分步驟介紹這個代碼干什么
1.query encoder和key encoder的參數初始化
其實也沒表達什么就是一開始大家的參數是一樣的:
f_k.params = f_q.params2.之后就是loader當中取數據
這個也沒啥的就是取出來數據的問題:
for x in loader: # load a minibatch x with N samples3.數據增強
就是代碼不是直接將內容輸入其中,也會通過數據增強取出內容
x_q = aug(x) # a randomly augmented version x_k = aug(x) # another randomly augmented version4.核心操作
首先我們先理解一下這個N和C是什么?
q = f_q.forward(x_q) # queries: NxC k = f_k.forward(x_k) # keys: NxCN其實是一個batch_size
C是一個輸入數據的特征數,每個輸入數據是一個1×C的張量
這個其實就是文章的主要創新點了,因為優化key_encoder是來自于query_encoder的優化。所以自然就不需要前傳梯度,也能剩下個內存。
這里是矩陣乘法,理解一下這里的矩陣乘法:
# positive logits: Nx1 l_pos = bmm(q.view(N,1,C), k.view(N,C,1)) # negative logits: NxK l_neg = mm(q.view(N,C), queue.view(C,K)) # logits: Nx(1+K) logits = cat([l_pos, l_neg], dim=1)- 1.首先我們應當理解一下這個q和k到底是什么東西,可以看到q和k分別來自于x_q和x_k,我們注意這兩個東西其實都來自于x只是作了不同的數據增強罷了。
好了,現在我們應該能判斷出來,這里的x和k我們應該認為同一個類別。 - 2.l_pos 現在我們就知道這個東西應該是一個N*1的一組接近1的數值
- 3.我們注意queue是我們存儲的之前的batch的內容,所以這個東西和我們當前這個batch的內容應該是沒有任何交集的,也就是他們來自于不同的內容,按照對比學習的思想,來自不同事物的內容應該完全不相交。所以他們的相似度應該盡量的低。
- 4.l_neg應當得到一個N*K的一組接近0的數值。
- 5.logits的內容就自然而然出現了,應該為一個N*(K+1)的內容,這些內容應該具有下面的特點:K+1的向量除了第一位接近1之外其他都應該接近0。
- 6.在現在的情況下我們自然而然可以得出一個內容就是,每個(K+1)的張量經過softmax之后,模型都應該判別其為正確。也就是所有的N個張量都是0號分類。
5.交叉熵loss
這里其實不能完全的算成交叉熵損失函數,這個是一個帶有熱度的交叉熵損失函數。但是其實我們可以將其想成交叉熵函數來簡化理解:
# contrastive loss, Eqn.(1) labels = zeros(N) # positives are the 0-th loss = CrossEntropyLoss(logits/t, labels)之前我們談過了,這里的所有內容都應該是第0個分類,所以我們這里直接讓所有的分類都是第0分類就完事了。
模型更新
下面是很正常的backward
# SGD update: query network loss.backward() update(f_q.params)然后就是本文核心的動量優化
# momentum update: key network f_k.params = m*f_k.params+(1-m)*f_q.params其實就是讓keyencoder也和queryencoder做相同方向的優化
7.更新字典
首先理解什么是字典,就是和什么比較的問題,這個字典就是我們用來和學習的內容比較的內容。這里其實就是實現了將這個batchsize的內容出隊將新的這個batchsize進隊。
# update dictionary enqueue(queue, k) # enqueue the current minibatch dequeue(queue) # dequeue the earliest minibatch總結
以上是生活随笔為你收集整理的MoCo论文中的Algorithm 1伪代码解读的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 暴力解决:InvocationExcep
- 下一篇: Batch Normalization的