top中的res只增不减_tensorflow中张量排序与accuracy计算
今天跟B站“奔跑的東兵衛(wèi)”大佬的教程,學(xué)習(xí)了在tensorflow中進(jìn)行張量排序和accuracy計(jì)算的內(nèi)容。
大概的想法是這樣。
模型計(jì)算導(dǎo)出的預(yù)測(cè)結(jié)果output的shape是[b,N],真實(shí)值target的shape是[b]
b是batch,就是本次運(yùn)算有多少sample,N是結(jié)果的可能性有幾種
先對(duì)output的每一行(即每個(gè)sample在各個(gè)類(lèi)的概率)進(jìn)行大小的比較,得到一個(gè)大小不變但是概率變成排序數(shù)字tensor,其shape還是[b, N]。
pred = tf.math.top_k(output, maxk).indices然后對(duì)這個(gè)tensor進(jìn)行轉(zhuǎn)置
pred = tf.transpose(pred, perm=[1,0])pred的shape就變成了[N, b]
為了能讓pred和真實(shí)值target能夠進(jìn)行比較,需要把target的shape也變成和pred一樣,這里用broadcast_to函數(shù)
target_ = tf.broadcast_to(target,pred.shape)然后就是比較pred和target,相同的數(shù)字變成True,不同的變成False
correct = tf.equal(pred,target_) #[N,b]這個(gè)的shape也是[N, b]
為了計(jì)算有幾個(gè)True,需要把True和False對(duì)應(yīng)變成1和0,其實(shí)本身T和F就是用1和0表示的,但是數(shù)據(jù)格式不一樣,因此用tf.cast函數(shù)改一下數(shù)據(jù)格式,另外我們選擇的top k的k不一樣的話后面對(duì)應(yīng)選擇correct的前k行進(jìn)行相加。這里的top k就是說(shuō)對(duì)于每個(gè)sample,預(yù)測(cè)結(jié)果的前k個(gè)包括真實(shí)值就算是正確。
correct_k = tf.cast(tf.reshape(correct[:k], [-1]), dtype=tf.float32)得到的結(jié)果加起來(lái)是幾就是有幾個(gè)預(yù)測(cè)對(duì)了。
而accuracy就是加起來(lái)的結(jié)果/b
下面記一下代碼。
import tensorflow as tf import osos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' tf.random.set_seed(2467)def accuracy(output, target, topk = (1,)):maxk = max(topk)batch_size = target.shape[0]pred = tf.math.top_k(output, maxk).indicespred = tf.transpose(pred, perm=[1,0])target_ = tf.broadcast_to(target,pred.shape)correct = tf.equal(pred,target_) #[k,b]res = []for k in topk:correct_k = tf.cast(tf.reshape(correct[:k], [-1]), dtype=tf.float32)correct_k = tf.reduce_sum(correct_k)acc = float(correct_k / batch_size)res.append(acc)return resoutput = tf.random.normal([10, 6]) output = tf.math.softmax(output, axis=1) target = tf.random.uniform([10], maxval=6, dtype=tf.int32) print('prob: ', output.numpy()) pred = tf.argmax(output, axis=1) print('pred: ', pred.numpy()) print('label: ', target.numpy())acc = accuracy(output, target, topk=(1,2,3,4,5,6)) print('top-1-6 acc:', acc)總結(jié)
以上是生活随笔為你收集整理的top中的res只增不减_tensorflow中张量排序与accuracy计算的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: python画五角星填充不同颜色_Pyt
- 下一篇: string获取 倒数 下标_Redis