Tensorlfow2.0 二分类和多分类focal loss实现和在文本分类任务效果评估
Tensorlfow2.0 二分類和多分類focal loss實現和在文本分類任務效果評估
- 前言
- 二分類 focal loss
- 多分類 focal loss
- 測試結果
- 二分類focal_loss結果
- 多分類focal_loss結果
- 總結
前言
最近看了focal loss的文章,正好在做文本分類的項目,一個是Sentence Bert句子匹配,一個是網易云音樂評論的情緒分類。本人用的框架是tensorflow2.0,所以想嘗試實踐一下focal loss,但是翻遍了網上的文章,不是代碼報錯就是錯誤實現。最后就自己根據focal loss的公式寫了一個,試跑了代碼確認無誤。
tensorflow :2.0.0(GPU上跑)
transformers :3.1
二分類 focal loss
from tensorflow.python.ops import array_ops def binary_focal_loss(target_tensor,prediction_tensor, alpha=0.25, gamma=2):zeros = array_ops.zeros_like(prediction_tensor, dtype=prediction_tensor.dtype)target_tensor = tf.cast(target_tensor,prediction_tensor.dtype)pos_p_sub = array_ops.where(target_tensor > zeros, target_tensor - prediction_tensor, zeros)neg_p_sub = array_ops.where(target_tensor > zeros, zeros, prediction_tensor)per_entry_cross_ent = - alpha * (pos_p_sub ** gamma) * tf.math.log(tf.clip_by_value(prediction_tensor, 1e-8, 1.0)) \- (1 - alpha) * (neg_p_sub ** gamma) * tf.math.log(tf.clip_by_value(1.0 - prediction_tensor, 1e-8, 1.0))return tf.math.reduce_sum(per_entry_cross_ent)使用方法:
model.compile(optimizer=optimizer,loss=binary_focal_loss,metrics=['acc'])幾個注意的點:
多分類 focal loss
def softmax_focal_loss(label,pred,class_num=6, gamma=2):label = tf.squeeze(tf.cast(tf.one_hot(tf.cast(label,tf.int32),class_num),pred.dtype)) pred = tf.clip_by_value(pred, 1e-8, 1.0)w1 = tf.math.pow((1.0-pred),gamma)L = - tf.math.reduce_sum(w1 * label * tf.math.log(pred))return L使用方法
bert_ner_model.compile(optimizer=optimizer, loss=softmax_focal_loss,metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])幾個注意的點:
測試結果
二分類focal_loss結果
此處是在Sentence Bert模型上的測試結果
binary_crossentropy 結果
可以看到用tf自帶的binary_crossentropy,訓練一輪就已經有過擬合的趨勢了,該用focal_loss可以很好的抑制模型過擬合且模型效果也有1個多點的提升。
多分類focal_loss結果
此處是在Roberta評論情緒多分類數據上的結果
發現loss會下降的越來越慢,這是正常的,需要訓練的輪次也變多,因為這里對loss乘了(1-pred)**gama的系數所以整體更新速度會變慢。
對比下使用sparse_cross_entropy結果:
發現并沒有提升,這可能與我的數據集類別分布比較平衡有關。所以focal_loss的使用場景還是要看自己的數據集情況。
總結
focal loss的使用還需要根據自己的數據集情況來判斷,當樣本不平衡性較強時使用focal loss會有較好的提升,在多分類上使用focal loss得到的效果目前無法很好的評估。
完整的模型代碼之后會專門寫一個博客來講,用 tf2.0.0 + transformers 搭一個Sentence Bert也借鑒了很多pytroch的代碼,tf實現比較少,也是自己慢慢摸索出來的。
總結
以上是生活随笔為你收集整理的Tensorlfow2.0 二分类和多分类focal loss实现和在文本分类任务效果评估的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 袋鼯麻麻——智能购物平台
- 下一篇: Tensorflow2.0 + Tra