机器学习中绘制(多标签)PR曲线和F1-score【转载】
生活随笔
收集整理的這篇文章主要介紹了
机器学习中绘制(多标签)PR曲线和F1-score【转载】
小編覺得挺不錯的,現(xiàn)在分享給大家,幫大家做個參考.
參考:PR曲線和F1-score 評價指標(biāo)相關(guān) - 知乎
sklearn官方文檔:Precision-Recall — scikit-learn 1.0.2 documentation
?
多標(biāo)簽設(shè)置中的PR曲線
??????? 查了好多文檔,但依舊看得稀里糊涂。上面兩個鏈接講的非常清楚了,所以在這里記錄一下。具體原理不多說,看上面鏈接就行,下面直接上代碼:
一、數(shù)據(jù)說明:
????????n_classes,標(biāo)簽類別總數(shù)
??????? Y_test, 真實標(biāo)簽(one-hot形式,轉(zhuǎn)換為nudarry格式)
????????y_score,預(yù)測結(jié)果得分(轉(zhuǎn)換為nudarry格式)
二、導(dǎo)入:
import numpy as np from sklearn.metrics import precision_recall_curve,average_precision_score,PrecisionRecallDisplayimport matplotlib.pyplot as plt from itertools import cycle三、代碼:
def plot_pr_multi_label(n_classes, Y_test, y_score):# For each classprecision = dict()recall = dict()average_precision = dict()for i in range(n_classes):precision[i], recall[i], _ = precision_recall_curve(Y_test[:, i], y_score[:, i])average_precision[i] = average_precision_score(Y_test[:, i], y_score[:, i])# print(recall)# print(average_precision)# A "micro-average": quantifying score on all classes jointlyprecision["micro"], recall["micro"], _ = precision_recall_curve(Y_test.ravel(), y_score.ravel())average_precision["micro"] = average_precision_score(Y_test, y_score, average="micro")# print(precision)# print(average_precision)# 繪制平均PR曲線display = PrecisionRecallDisplay(recall=recall["micro"], precision=precision["micro"],average_precision=average_precision["micro"], )display.plot()_ = display.ax_.set_title("Micro-averaged over all classes")# 繪制每個類的PR曲線和 iso-f1 曲線# setup plot detailscolors = cycle(["navy", "turquoise", "darkorange", "cornflowerblue", "teal"])_, ax = plt.subplots(figsize=(7, 8))f_scores = np.linspace(0.2, 0.8, num=4)lines, labels = [], []for f_score in f_scores:x = np.linspace(0.01, 1)y = f_score * x / (2 * x - f_score)(l,) = plt.plot(x[y >= 0], y[y >= 0], color="gray", alpha=0.2)plt.annotate("f1={0:0.1f}".format(f_score), xy=(0.9, y[45] + 0.02))display = PrecisionRecallDisplay(recall=recall["micro"], precision=precision["micro"],average_precision=average_precision["micro"], )display.plot(ax=ax, name="Micro-average precision-recall", color="gold")for i, color in zip(range(n_classes), colors):display = PrecisionRecallDisplay(recall=recall[i], precision=precision[i],average_precision=average_precision[i], )display.plot(ax=ax, name=f"Precision-recall for class {i}", color=color)# add the legend for the iso-f1 curveshandles, labels = display.ax_.get_legend_handles_labels()handles.extend([l])labels.extend(["iso-f1 curves"])# print(l)# print(handles)# print(labels)# set the legend and the axesax.set_xlim([0.0, 1.0])ax.set_ylim([0.0, 1.05])ax.legend(handles=handles, labels=labels, loc="best")ax.set_title("Extension of Precision-Recall curve to multi-class")# plt.show()四、結(jié)果:
(1) 平均精度
(2)每個類的PR曲線和 iso-f1 曲線
總結(jié)
以上是生活随笔為你收集整理的机器学习中绘制(多标签)PR曲线和F1-score【转载】的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 焦绪录:大数据如何推动数字中国建设
- 下一篇: 杭电1874畅通工程绪