多分类 数据不平衡的处理 lightgbm
前言
數(shù)據(jù)不平衡問題在機(jī)器學(xué)習(xí)分類問題中很常見,尤其是涉及到“異常檢測"類型的分類。因?yàn)楫惓R话阒傅南鄬Σ怀R姷默F(xiàn)象,因此發(fā)生的機(jī)率必然要小很多。因此正常類的樣本量會遠(yuǎn)遠(yuǎn)高于異常類的樣本量,一般高達(dá)幾個數(shù)量級。比如:疾病相關(guān)的樣本,正常的樣本會遠(yuǎn)高于疾病的樣本,即便是當(dāng)下流行的COVID-19。比如kaggle 競賽的信用卡交易欺詐(credit card fraud),正常交易與欺詐類交易比例大于10000:1。再比如工業(yè)中常見的故障診斷數(shù)據(jù),正常運(yùn)行的時間段會遠(yuǎn)遠(yuǎn)高于停機(jī)(故障)時間。
開題
首先我們提出一個問題:為什么數(shù)據(jù)不平衡會對機(jī)器模型產(chǎn)生影響?原因很直觀,因?yàn)橛?xùn)練集中的數(shù)據(jù)如果不平衡,“機(jī)器” 會集中解決大多數(shù)的數(shù)據(jù)的問題,而會忽視了少數(shù)類的數(shù)據(jù)。就像少數(shù)民族會不占優(yōu)勢。既然是基于大樣本訓(xùn)練的機(jī)器模型,無法避免地被主要樣本帶偏。
關(guān)鍵問題來了:那我們?nèi)绾巫屔贁?shù)類獲得同等的地位,然后被模型同等對待呢?今天我們可以通過一個實(shí)戰(zhàn)樣本來看看有哪些技巧能降低數(shù)據(jù)不平衡帶來的影響。
數(shù)據(jù)源準(zhǔn)備
數(shù)據(jù)源是NSL-KDD 數(shù)據(jù)包。數(shù)據(jù)源來自:https://www.unb.ca/cic/datasets/nsl.html。簡單介紹一下數(shù)據(jù)源,NSL-KDD是為解決在中KDD'99數(shù)據(jù)集的某些固有問題而推薦的數(shù)據(jù)集。盡管該數(shù)據(jù)集可能無法完美地代表現(xiàn)有的現(xiàn)實(shí)網(wǎng)絡(luò)世界,但是很多論文依然可以用它作有效的基準(zhǔn)數(shù)據(jù)集,以幫助研究人員比較不同的入侵檢測方法。
本文數(shù)據(jù)集來源于github的整理半成品。https://github.com/arjbah/nsl-kdd.git (include the most attack types) 和https://github.com/defcom17/NSL_KDD.git。數(shù)據(jù)集比較分散,train_file 和test_file 只包含樣本特征和標(biāo)簽值,但是沒有表頭(header),表頭的信息包含在field_name_file 中,另外關(guān)于網(wǎng)絡(luò)攻擊類型,分為5個大類,40多個小類,但是我們該測試中只預(yù)測5個大類。數(shù)據(jù)源略點(diǎn)凌亂,所以我們需要在代碼中稍作歸類。代碼入場:
# import packages
import pandas as pd
"""
DATASET SOURCE is from https://github.com/arjbah/nsl-kdd.git (include the most attack types)
https://github.com/defcom17/NSL_KDD.git
"""
train_file = 'https://raw.githubusercontent.com/arjbah/nsl-kdd/master/nsl-kdd/KDDTrain%2B.txt'
test_file = 'https://raw.githubusercontent.com/arjbah/nsl-kdd/master/nsl-kdd/KDDTest%2B.txt'
field_name_file = 'https://raw.githubusercontent.com/defcom17/NSL_KDD/master/Field%20Names.csv'
attack_type_file = 'https://raw.githubusercontent.com/arjbah/nsl-kdd/master/training_attack_types.txt'
這里就是常規(guī)的pandas 讀csv 或txt 操作,僅僅注意一下列表頭/列名稱的處理。
field_names_df = pd.read_csv(
    field_name_file, header=None, names=[
        'name', 'data_type']) # 定義dataframe ,并給個column name,方便索引
    field_names = field_names_df['name'].tolist()
field_names += ['label', 'label_code'] # 源文件中沒有標(biāo)簽名稱,以及等級信息
df = pd.read_csv(train_file, header=None, names=field_names)
df_test = pd.read_csv(test_file, header=None, names=field_names)
attack_type_df = pd.read_csv(
    attack_type_file, sep=' ', header=None, names=[
        'name', 'attack_type'])
attack_type_dict = dict(
    zip(attack_type_df['name'].tolist(), attack_type_df['attack_type'].tolist())) # 定義5大類和小類的映射字典,方便替代
df.drop('label_code', axis=1, inplace=True) # 最后一列 既無法作為feature,也不是我們的label,刪掉
df_test.drop('label_code', axis=1, inplace=True)
df['label'].replace(attack_type_dict, inplace=True) # 替換label 為5 大類
df_test['label'].replace(attack_type_dict, inplace=True)
數(shù)據(jù)一覽(不平衡分布)
數(shù)據(jù)已經(jīng)準(zhǔn)備好,我們可以初步瀏覽一下數(shù)據(jù)結(jié)構(gòu)。
print(df.info())
結(jié)果如下:
Data columns (total 42 columns):
 #   Column                       Non-Null Count   Dtype
---  ------                       --------------   -----
 0   duration                     125973 non-null  int64
 1   protocol_type                125973 non-null  object
 2   service                      125973 non-null  object
 3   flag                         125973 non-null  object
 4   src_bytes                    125973 non-null  int64
 5   dst_bytes                    125973 non-null  int64
 6   land                         125973 non-null  int64
 7   wrong_fragment               125973 non-null  int64
 8   urgent                       125973 non-null  int64
 9   hot                          125973 non-null  int64
 10  num_failed_logins            125973 non-null  int64
 11  logged_in                    125973 non-null  int64
 12  num_compromised              125973 non-null  int64
 13  root_shell                   125973 non-null  int64
 14  su_attempted                 125973 non-null  int64
 15  num_root                     125973 non-null  int64
 16  num_file_creations           125973 non-null  int64
 17  num_shells                   125973 non-null  int64
 18  num_access_files             125973 non-null  int64
 19  num_outbound_cmds            125973 non-null  int64
 20  is_host_login                125973 non-null  int64
 21  is_guest_login               125973 non-null  int64
 22  count                        125973 non-null  int64
 23  srv_count                    125973 non-null  int64
 24  serror_rate                  125973 non-null  float64
 25  srv_serror_rate              125973 non-null  float64
 26  rerror_rate                  125973 non-null  float64
 27  srv_rerror_rate              125973 non-null  float64
 28  same_srv_rate                125973 non-null  float64
 29  diff_srv_rate                125973 non-null  float64
 30  srv_diff_host_rate           125973 non-null  float64
 31  dst_host_count               125973 non-null  int64
 32  dst_host_srv_count           125973 non-null  int64
 33  dst_host_same_srv_rate       125973 non-null  float64
 34  dst_host_diff_srv_rate       125973 non-null  float64
 35  dst_host_same_src_port_rate  125973 non-null  float64
 36  dst_host_srv_diff_host_rate  125973 non-null  float64
 37  dst_host_serror_rate         125973 non-null  float64
 38  dst_host_srv_serror_rate     125973 non-null  float64
 39  dst_host_rerror_rate         125973 non-null  float64
 40  dst_host_srv_rerror_rate     125973 non-null  float64
 41  label                        125973 non-null  object
dtypes: float64(15), int64(23), object(4)
首先我們來看label的分布:
from collections import Counter
# 簡單定義一個print 函數(shù)
def print_label_dist(label_col):
    c = Counter(label_col)
    print(f'label is {c}')
print_label_dist(df['label'])
print_label_dist(df_test['label'])
可以看到分布為:
label is Counter({'normal': 67343, 'dos': 45927, 'probe': 11656, 'r2l': 995, 'u2r': 52})
label is Counter({'normal': 9711, 'dos': 7636, 'r2l': 2574, 'probe': 2423, 'u2r': 200})
為了更直觀的對比,我們可以看一下countplot 的結(jié)果。
import seaborn as sns
train_label= df[['label']]
train_label['type'] = 'train'
test_label= df_test[['label']]
test_label['type'] = 'test'
label_all = pd.concat([train_label,test_label],axis=0)
print(label_all)
print(test_label)
sns.countplot(x='label',hue='type', data=label_all)
這是典型的不平衡數(shù)據(jù),正常的樣本量遠(yuǎn)大于其他類別的樣本量,尤其是u2r樣本類別。
“硬train一”發(fā)作為baseline
okay,首先我們來“硬train一發(fā)”。最后一列為標(biāo)簽,也就是我們要分類的對象,會被分離出特征矩陣。
    Y = df[‘label’]
    Y_test = df_test[‘label’]
    X = df.drop(‘label’, axis=1)
    X_test = df_test.drop(‘label’, axis=1)
對于決策樹類型的機(jī)器學(xué)習(xí)模型,單個特征的單調(diào)變化不會對最終結(jié)果產(chǎn)生影響,因?yàn)槲覀儫o需log或者歸一化處理。
本文我們不進(jìn)行過多的特征工程,因?yàn)槲覀兇舜螌?shí)驗(yàn)中不會對特征進(jìn)行EDA分析。我們只進(jìn)行最基本的預(yù)處理,有三個feature為object 類型,也就是離散數(shù)據(jù),這個需要我們預(yù)處理,我們會采用one-hot 進(jìn)行處理。為了方便,我們寫兩個小函數(shù),方便重復(fù)調(diào)用。
# 分離離散變量
def split_category(data, columns):
    cat_data = data[columns]
    rest_data = data.drop(columns, axis=1)
    return rest_data, cat_data
#  轉(zhuǎn)所有離散變量為one-hot
def one_hot_cat(data):
    if isinstance(data, pd.Series):
        data = pd.DataFrame(data, columns=[data.name])
    out = pd.DataFrame([])
    for col in data.columns:
        one_hot_cols = pd.get_dummies(data[col], prefix=col)
        out = pd.concat([out, one_hot_cols], axis=1)
    out.set_index(data.index)
    return out
# categorical_columns
categorical_mask = (X.dtypes == object)
categorical_columns = X.columns[categorical_mask].tolist()
X, X_cat = split_category(X, categorical_columns)
X_test, X_test_cat = split_category(X_test, categorical_columns)
# convert to one-hot
X_cat_one_hot_cols = one_hot_cat(X_cat)
X_test_cat_one_hot_cols = one_hot_cat(X_test_cat)
# align train to test
X_cat_one_hot_cols, X_test_cat_one_hot_cols = X_cat_one_hot_cols.align(
    X_test_cat_one_hot_cols, join=‘inner’, axis=1)
X_cat_one_hot_cols.fillna(0, inplace=True)
X_test_cat_one_hot_cols.fillna(0, inplace=True)
X = pd.concat([X, X_cat_one_hot_cols], axis=1)
X_test = pd.concat([X_test, X_test_cat_one_hot_cols],
                   axis=1)
print(f’add one-hot features’)
print(f’x shape is {X.shape}’)
x shape is (125973, 116)
準(zhǔn)備lightgbm 模型.
import lightgbm as lgb
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix,classification_report,accuracy_score,roc_auc_score,f1_score
feature_name = list(X.columns) # 特征名稱后續(xù)會用到
Y_encode = LabelEncoder().fit_transform(Y)
Y_test_encode = LabelEncoder().fit_transform(Y_test)
dtrain = lgb.Dataset(X.values, label=Y_encode)
dtest = lgb.Dataset(X_test.values, label=Y_test_encode)
param = {
    ‘eta’: 0.1,
    ‘objective’: ‘multiclass’,
    ‘num_class’: 5,
    ‘verbose’: 0,
        ‘metric’:‘multi_error’
} # 參數(shù)幾乎都是默認(rèn)值,僅僅修改一些多分類必須的參數(shù)
evals_result = {}
valid_sets = [dtrain, dtest]
valid_name = [‘train’, ‘eval’]
model = lgb.train(param, dtrain, num_boost_round=500, feature_name=feature_name,
                  valid_sets=valid_sets, valid_names=valid_name, evals_result=evals_result)
y_pred_1 = model.predict(X_test.values)
y_pred = pd.DataFrame(y_pred_1).idxmax(axis=1) #預(yù)測概率值轉(zhuǎn)為預(yù)測標(biāo)簽
#
# 我們用了多種metric 來衡量結(jié)果,其中有些是明顯不適合的,比如accuracy,因?yàn)樗鼤徊黄胶獾臄?shù)據(jù)分布帶到陰溝里(誤導(dǎo))。
print(f’auc score is {accuracy_score(Y_test_encode, y_pred)}’)
print(confusion_matrix(Y_test_encode, y_pred))
print(classification_report(Y_test_encode, y_pred, digits=3))
auc = roc_auc_score(Y_test_encode, y_pred_1, multi_class=“ovo”, average=“macro”) # 選用macro 很重要。參考sklearn。
#Calculate metrics for each label, and find their unweighted mean. #This does not take label imbalance into account.
print(f’roc_auc_score  is {auc}’)
f1 = f1_score(y_pred, Y_test_encode, average=‘macro’)
print(f’f1_score  is {f1}’)
硬train的結(jié)果如下:acc 指標(biāo)已經(jīng)提到,會有誤導(dǎo)性,這里列出就是為了參考。report中3 和4 precision 和recall 較低,這也很正常,因?yàn)閿?shù)據(jù)不平衡嘛。
acc score is 0.6652767920511
  precision    recall  f1-score   support
0      0.840     0.645     0.730      7636
1      0.619     0.899     0.734      9711
2      0.570     0.547     0.558      2423
3      0.312     0.002     0.004      2574
4      0.026     0.030     0.028       200
accuracy                          0.665     22544
macro avg      0.473     0.425     0.411     22544
roc_auc_score  is 0.6405673646606284
f1_score  is 0.41066470104083724
我們稍作改善
改進(jìn)的方向,我認(rèn)為會有一下幾個方面:
- 采用更多的數(shù)據(jù)集,很顯然臣妾做不到 
- 換其他的模型,比如異常診斷(半監(jiān)督或者無監(jiān)督),不在我們討論范圍 
- 小心謹(jǐn)慎的特征工程,需要一定的先驗(yàn)知識 
- 調(diào)參。 
- 重采樣 
- 其他 
我們在此也硬“tune”一發(fā),看當(dāng)前的模型是否可以調(diào)整參數(shù),進(jìn)行一定程度改善。至于特征工程,如果是作為一個項(xiàng)目,還是可以深究,本文不涉及。
我們分析一下模型訓(xùn)練的歷史曲線。
曲線慘不忍睹,但是還是可以看到train 和test 最后都已經(jīng)趨近水平,也就是num_boost_round 參數(shù)已經(jīng)讓目前的模型找到較理想的值了。
重采樣
我們對模型加上重采樣,重采樣的思路很簡單,就是重新采樣讓不同類別的樣本量趨于平等。升采樣和降采樣,也是最常用的方法。對于本案例的數(shù)據(jù),如果我們采用降采樣,會損失太多的信息。而且可控(樣本量)的降采樣,一般也就是隨機(jī)降采樣,對于隨機(jī)的結(jié)果無法有太多的說服力。不可控的降采樣,最終會導(dǎo)致樣本量接近于最小類別的樣本量,也就是本案例中的20多。這樣會大大丟失樣本信息。
因此本文中采用升采樣的方法,常見的升采樣有多種。我們采用的imbalanced-learn (https://imbalanced-learn.readthedocs.io/en/stable/ )的包,里面包含多種升采樣方法,網(wǎng)上似乎一提 升采樣,就是SMOTE。本文中采用的ADASYN(對本案例來說,效果更好,各位可以自行對比)。
# 代碼需要放置在one-hot 之前
from imblearn.over_sampling import SMOTE, ADASYN
from sklearn.preprocessing import LabelEncoder
def label_encoder(data):
    labelencoder = LabelEncoder()
    for col in data.columns:
        data.loc[:,col] = labelencoder.fit_transform(data[col])
    return data
# first label_encoder to allow resampling
X[categorical_columns] = label_encoder(X[categorical_columns])
X_test[categorical_columns] = label_encoder(X_test[categorical_columns])
oversample = ADASYN()
X, Y = oversample.fit_resample(X, Y)
# 之后的代碼為
#X, X_cat = split_category(X, categorical_columns)
#X_test, X_test_cat = split_category(X_test, categorical_columns)
先不進(jìn)行l(wèi)ightbgm調(diào)參,我們看一下結(jié)果:
acc score is 0.7869943222143364
[[6258 1126  251    1    0]
 [  61 9364  276    6    4]
 [ 164  403 1856    0    0]
 [   0 2299   21  246    8]
 [   0  152   22    8   18]]
              precision    recall  f1-score   support
           0      0.965     0.820     0.886      7636
           1      0.702     0.964     0.812      9711
           2      0.765     0.766     0.766      2423
           3      0.943     0.096     0.174      2574
           4      0.600     0.090     0.157       200
    accuracy                          0.787     22544
   macro avg      0.795     0.547     0.559     22544
weighted avg      0.824     0.787     0.754     22544
roc_auc_score  is 0.9097110919608917
f1_score  is 0.5588737585068755
各項(xiàng)指標(biāo)都有提升,同樣的回顧一下我們的訓(xùn)練曲線。尾巴依然光滑,說明不算欠擬合。train 和test 的間距有些大,可能有過擬合之嫌。
我們試試是否為過擬合,對于數(shù)模型,最好控制的就是tree max depth,一般推薦為3-10,我們采用的默認(rèn)6. 我們可以將為3 試試。
acc score is 0.7916962384669979
[[6277 1163  196    0    0]
 [  90 9319  248   25   29]
 [ 166  356 1901    0    0]
 [   4 2174   45  329   22]
 [   0  104   54   20   22]]
              precision    recall  f1-score   support
           0      0.960     0.822     0.886      7636
           1      0.711     0.960     0.816      9711
           2      0.778     0.785     0.781      2423
           3      0.880     0.128     0.223      2574
           4      0.301     0.110     0.161       200
    accuracy                          0.792     22544
   macro avg      0.726     0.561     0.574     22544
weighted avg      0.818     0.792     0.763     22544
roc_auc_score  is 0.8931058881203062
f1_score  is 0.5735623327532393
結(jié)果略有變化,好像更側(cè)重于f1_score的分?jǐn)?shù)。
偏向少數(shù)類
對于不平衡的數(shù)據(jù),如果有需要,我們還可以通過分配權(quán)重,來讓模型偏向少數(shù)類。通過這樣的方法,我們又可以一定程度的平衡模型。lightgbm 支持樣本權(quán)重,我們可以調(diào)整權(quán)重來重新訓(xùn)練。上代碼:
class_w = {
    ‘normal’: 0.1,  # 0.1
    ‘dos’: 0.6,
    ‘probe’: 0.6,
    ‘r2l’: 2,
    ‘u2r’: 1.2} #以上數(shù)據(jù)需要微調(diào),調(diào)整一般從normal開始,因?yàn)樗臋?quán)重大
from sklearn.utils.class_weight import compute_sample_weight
sample_w = compute_sample_weight(class_weight=class_w, y=Y)
##!!然后傳入該權(quán)重到數(shù)據(jù)集中
dtrain = lgb.Dataset(X.values, label=Y_encode,weight=sample_w)
訓(xùn)練結(jié)果與效果:
acc score is 0.828069552874379
[[6448  684  366   63   75]
 [ 142 8551  271  434  313]
 [ 203    3 2185   12   20]
 [  10  895   28 1442  199]
 [   0    5  109   44   42]]
              precision    recall  f1-score   support
           0      0.948     0.844     0.893      7636
           1      0.843     0.881     0.862      9711
           2      0.738     0.902     0.812      2423
           3      0.723     0.560     0.631      2574
           4      0.065     0.210     0.099       200
    accuracy                          0.828     22544
   macro avg      0.663     0.679     0.659     22544
weighted avg      0.847     0.828     0.834     22544
roc_auc_score  is 0.8996899325820623
f1_score  is 0.6593715668480359
可以看到f1-score 有了很大的提升,當(dāng)然你可以繼續(xù)調(diào)整該class_w 去讓你的模型有所側(cè)重。multi_error 也降低了。
總結(jié)
對于不平衡的數(shù)據(jù)集,重新采樣和調(diào)整權(quán)重會對結(jié)果產(chǎn)生影響。當(dāng)然其他的超參可以gridsearch 來優(yōu)化,本文不做研究。推薦https://imbalanced-learn.readthedocs.io/en/stable/ 來深入了解不同采樣的影響。
后記
附上windows 中l(wèi)ightgbm 樹圖的plot以及特征重要性的plot代碼。
import os
graphviz_path = r’C:\Program Files (x86)\Graphviz2.38\bin’
os.environ[“PATH”] += os.pathsep + graphviz_path
lgb.plot_tree(model, tree_index=0)
lgb.plot_importance(model)
 
總結(jié)
以上是生活随笔為你收集整理的多分类 数据不平衡的处理 lightgbm的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
 
                            
                        - 上一篇: 文件流处理流式处理大数据处理
- 下一篇: 样本不平衡不均衡数据处理
