facenet 中心损失函数(center loss)详解(代码分析)含tf.gather() 和 tf.scatter_sub()函数
生活随笔
收集整理的這篇文章主要介紹了
facenet 中心损失函数(center loss)详解(代码分析)含tf.gather() 和 tf.scatter_sub()函数
小編覺得挺不錯的,現(xiàn)在分享給大家,幫大家做個參考.
我們來解讀一下,中心損失,再來看代碼。
鏈接:https://www.cnblogs.com/carlber/p/10811396.html
我們的重點(diǎn)是分析代碼,所以定義部分,大家詳情參見上面的博客。
代碼:
#coding=gbk ''' Created on 2020年4月20日@author: DELL ''' import tensorflow as tf import numpy as npdata = [[1,1,1,1,1],[1,1,2,1,1],[1,1,3,1,1],[1,1,4,1,1],[2,2,2,1,2],[2,2,2,2,2],[2,2,2,3,2],[3,3,3,3,1],[3,3,3,3,2]]label = [0,0,0,0,1,1,1,2,2]data = np.array(data,dtype = 'float32') label = np.array(label)data = tf.convert_to_tensor(data) label = tf.convert_to_tensor(label)def center_loss(features, label, alfa, nrof_classes):"""Center loss based on the paper "A Discriminative Feature Learning Approach for Deep Face Recognition"(http://ydwen.github.io/papers/WenECCV16.pdf)"""nrof_features = features.get_shape()[1]centers = tf.get_variable('centers', [nrof_classes, nrof_features], dtype=tf.float32,initializer=tf.constant_initializer(0), trainable=False)#定義一個全零的centers, [nrof_classes, nrof_features]->(類別數(shù),特征維度)#print(sess.run(centers))label = tf.reshape(label, [-1]) #一維向量centers_batch = tf.gather(centers, label) #[batch_size,nrof_features] #按照label將centers歸類,形成的新矩陣維度為 [label_size,nrof_features]diff = (1 - alfa) * (centers_batch - features) #乘上我們的因子alfa [label_size,nrof_features]centers = tf.scatter_sub(centers, label, diff) #按照label用centers - diff,產(chǎn)生本次的centerswith tf.control_dependencies([centers]):#注意這個函數(shù)的作用,是限制計算順序的,即先計算centers,在利用計算好的centers去計算centers_batch以求lossloss = tf.reduce_mean(tf.square(features - centers_batch))return loss, centers,features,centers_batch,features - centers_batchloss, cen, fea, cen_bat,a = center_loss(data,label,0.5,3)sess = tf.Session() init = tf.global_variables_initializer() sess.run(init)print(sess.run(cen)) #print(sess.run(loss)) print(sess.run(fea)) #print(sess.run(cen_bat)) print(sess.run(a)) print(sess.run(fea - cen_bat)) print(sess.run(tf.square(fea - cen_bat))) print(sess.run(loss))'''驗(yàn)證tf.scatter_sub函數(shù) sess = tf.Session() ref = tf.Variable([1, 2, 3],dtype = tf.int32) indices = tf.constant([0, 0, 1, 1],dtype = tf.int32) updates = tf.constant([9, 10, 11, 12],dtype = tf.int32) sub = tf.scatter_sub(ref, indices, updates) with tf.Session() as sess:sess.run(tf.global_variables_initializer())print (sess.run(sub)) '''結(jié)果:
1.centers: [[2. 2. 5. 2. 2. ][3. 3. 3. 3. 3. ][3. 3. 3. 3. 1.5]] 2.features: [[1. 1. 1. 1. 1.][1. 1. 2. 1. 1.][1. 1. 3. 1. 1.][1. 1. 4. 1. 1.][2. 2. 2. 1. 2.][2. 2. 2. 2. 2.][2. 2. 2. 3. 2.][3. 3. 3. 3. 1.][3. 3. 3. 3. 2.]] 3.centers_batch [[2. 2. 5. 2. 2. ][2. 2. 5. 2. 2. ][2. 2. 5. 2. 2. ][2. 2. 5. 2. 2. ][3. 3. 3. 3. 3. ][3. 3. 3. 3. 3. ][3. 3. 3. 3. 3. ][3. 3. 3. 3. 1.5][3. 3. 3. 3. 1.5]] 4.features - centers_batch [[-1. -1. -4. -1. -1. ][-1. -1. -3. -1. -1. ][-1. -1. -2. -1. -1. ][-1. -1. -1. -1. -1. ][-1. -1. -1. -2. -1. ][-1. -1. -1. -1. -1. ][-1. -1. -1. 0. -1. ][ 0. 0. 0. 0. -0.5][ 0. 0. 0. 0. 0.5]] 5.loss 1.4111111主要用到的函數(shù):1.tf.gather(data,labels),將data按labels擴(kuò)充
? ? ? ? ? ? ? ? ? ? ? ? ? ? ?2.tf.scatter_sub(data,label,data_1),按label用data - data_
? ? ? ? ? ? ? ? ? ? ? ? ? ? ?3.with tf.control_dependencies(): ,限制運(yùn)算順序
在實(shí)驗(yàn)驗(yàn)證時注意的點(diǎn)是:不要多次sess.run()某個張量涉及到帶有依賴關(guān)系的張量,比如這里的loss,計算loss時 會 主動更新一次值,導(dǎo)致運(yùn)算結(jié)果出錯。原理我還沒搞清,日后補(bǔ)上
總結(jié)
以上是生活随笔為你收集整理的facenet 中心损失函数(center loss)详解(代码分析)含tf.gather() 和 tf.scatter_sub()函数的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: tensorflow 之 ValuErr
- 下一篇: python 查看 .npy文件 和 .