60分钟吃掉三杀模型FiBiNET
神經(jīng)網(wǎng)絡的結構設計有3個主流的高級技巧:
1,高低融合 (將高層次特征與低層次特征融合,提升特征維度的豐富性和多樣性,像人一樣同時考慮整體和細節(jié))
2,權值共享 (一個權值矩陣參與多個不同的計算,降低參數(shù)規(guī)模并同時緩解樣本稀疏性,像人一樣一條知識多處運用)
3,動態(tài)適應 (不同的輸入樣本使用不同的權值矩陣,動態(tài)地進行特征選擇并賦予特征重要度解釋性,像人一樣聚焦重要信息排除干擾信息)
技巧應用范例:
1,高低融合 (DeepWide,UNet,特征金字塔FPN...)
2,權值共享 (CNN,RNN,FM,DeepFM,BlinearFFM...)
3,動態(tài)適應 (各種Attention機制...)
新浪微博廣告推薦技術團隊2019年發(fā)布的CTR預估模型FiBiNET同時巧妙地運用了以上3種技巧,是神經(jīng)網(wǎng)絡結構設計的教科書級的范例。
在此介紹給大家。
參考資料:
FiBiNET論文:https://arxiv.org/pdf/1905.09433.pdf
FiBiNET-結合特征重要性和雙線性特征交互進行CTR預估:https://zhuanlan.zhihu.com/p/72931811
代碼實現(xiàn):https://github.com/xue-pai/FuxiCTR/blob/main/fuxictr/pytorch/models/FiBiNET.py
SENet原理:https://zhuanlan.zhihu.com/p/65459972
公眾號后臺回復關鍵詞: FiBiNET?獲取本文全部源碼和數(shù)據(jù)集下載鏈接。
本文目錄:
一,FiBiNET原理解析
二,FiBiNET的pytorch實現(xiàn)
三,Criteo數(shù)據(jù)集完整范例
1,準備數(shù)據(jù)
2,定義模型
3,訓練模型
4,評估模型
5,使用模型
6,保存模型
一,FiBiNET原理解析
FiBiNET全稱為Feature Importance and Bilinear Interaction Network.
顧名思義,其主要的創(chuàng)意有2個。
第一個是Feature Importance,通過借鑒SENet(Squeeze-and-Excitation)Attention機制實現(xiàn)特征選擇和重要度解釋。
第二個是Bilinear Interaction Network,這是應用權值共享技巧對 FFM(Field-Aware FM)結構進行改進的一種結構。
同時,FiBiNET保留了DeepWide的高低融合的網(wǎng)絡架構。
所以它綜合使用了 高低融合、權值共享、動態(tài)適應 這3種神經(jīng)網(wǎng)絡結構設計的高級技巧。一個不落,Triple kill!
我們重點介紹一下 SENet Attention 和 Bilinear Interaction.
1, SENet Attention
SENet 全稱為 Squeeze-and-Excitation Network,是一種通過注意力機制計算特征重要度的網(wǎng)絡模塊。
最早是在CV領域引入,通過在ResNet結構上添加SENet Attention模塊,贏得了ImageNet 2017競賽分類任務的冠軍。
如何計算各個Feature Map(通道)的特征重要度(注意力權重)呢?
SENet的思想非常簡潔。
step1: 通過全局池化將各個Feature Map由一個一個的矩陣匯總成一個一個的標量。此即Squeeze操作。
step2:通過一個2層MLP將匯總成得到的一個一個的標量所構成的向量進行變換,得到注意力權重。此即Excitation操作。細節(jié)一點地說,這個2層的MLP的第1層將通道數(shù)量縮減成原來的1/3, 第2層再將通道數(shù)恢復。并且每層后面都接入了激活函數(shù)。
step3:用注意力權重乘以原始的Feature Map。這個是Re-Weight操作。
圖片示意如下。
pytorch代碼實現(xiàn)如下,可能比圖片更加好懂。
import?torch? from?torch?import?nn? class?SENetAttention(nn.Module):"""Squeeze-and-Excitation?Attention輸入shape:?[batch_size,?num_fields,?d_embed]???#num_fields即num_features輸出shape:?[batch_size,?num_fields,?d_embed]"""def?__init__(self,?num_fields,?reduction_ratio=3):super().__init__()reduced_size?=?max(1,?int(num_fields?/?reduction_ratio))self.excitation?=?nn.Sequential(nn.Linear(num_fields,?reduced_size,?bias=False),nn.ReLU(),nn.Linear(reduced_size,?num_fields,?bias=False),nn.ReLU())def?forward(self,?x):Z?=?torch.mean(x,?dim=-1,?out=None)?#1,SequeezeA?=?self.excitation(Z)?#2,ExcitationV?=?x?*?A.unsqueeze(-1)?#3,Re-Weightreturn?V2, Bilinear Interaction
Bilinear Interaction實際上是FFM在權值共享思想下的一種改進,也可以稱之為Bilinear FFM。
我們先說說FFM(Field-Aware FM),再看看這個Bilinear FFM 怎么改進的。
FM用隱向量之間的點積來計算特征之間的交叉,并且一個特征用一個隱向量來表示。
FFM認為一個特征用一個隱向量來表達太粗糙了,如果這個特征和不同分組(Field)的特征來做交叉,應該用不同的隱向量。
舉例來說,考慮一個廣告點擊預測的場景,廣告類別 和 用戶所在城市、用戶職業(yè)之間的交叉。
在FM中 一個確定的廣告類別 比如游戲廣告 不論是和用戶所在城市,還是用戶職業(yè)交叉,都用同一個隱向量。
但是FFM認為,用戶所在城市和用戶職業(yè)是兩類完全不同的特征(不同F(xiàn)ield),描述它們的向量空間應該是完全不相關的,FM用一個相同的隱向量來和它們做點積不合理。
所以,FFM引入了Field(域)的概念,和不同F(xiàn)ield的特征做交叉,要使用不同的隱向量。
實踐表明,FFM這個思路是有效的, FFM的作者阮毓欽正是憑借這個方案贏得了2015年kaggle舉辦的Criteo比賽的冠軍。
但是FFM有個很大的缺點,就是參數(shù)量太多了。
對于FM來說,每個特征只有一個隱向量,假設有n個特征,每個隱向量維度為k,全部隱向量參數(shù)矩陣的大小 size = n k.
但是對于FFM,有過有f個不同的field,每個特征都將有f-1個隱向量,全部隱向量的參數(shù)矩陣的大小增大為 size = (f-1) n k.
通常的應用場景中,Field的數(shù)量有幾十幾百維,而Feature的數(shù)量有數(shù)萬數(shù)百萬維。
很顯然,FFM將隱向量的參數(shù)規(guī)模擴大了幾十幾百倍。
FFM的本質(zhì)思想是在做特征交叉的時候要區(qū)分不同的Field,其實現(xiàn)方式是和不同的Field做交叉時用不同的隱向量。
有沒有辦法保留FFM中區(qū)分不同F(xiàn)ield的特性,并降低參數(shù)規(guī)模呢?
BilinearFFM說,我有辦法,權重共享走起來!
BilinearFFM不直接針對不同F(xiàn)ield設計不同的隱向量,而是引入了Field變換矩陣來區(qū)分不同的Field。
每個特征還是一個隱向量,但是和不同的Field的特征做交叉時,先乘上這個特征所在Field的變換矩陣,然后再做后面的點積。
因此,同屬一個Field的特征共享一個Field變換矩陣。這種bilinear_type叫做 field_each.
Field變換矩陣的大小是k^2, 這種方式下,全部隱向量的參數(shù)大小加上共享變換矩陣的參數(shù)大小一共是 size = n k + f k^2
由于k和f遠小于n,這種Bilinear方式相比FM增加的參數(shù)量可以忽略不計。
除了 同屬一個Field的特征共享一個Field變換矩陣外,我們還可以更加簡單粗暴一點,所有特征共享一個變換矩陣.
這種bilinear_type叫做 field_all.這種方式下,size = n k + k^2
我們也可以更加精細一點,相同的Field組合之間的交互共享一個變換矩陣,這種bilinear_type叫做field_interaction.
總共有f(f-1)/2種組合,這種方式下, size = n k + k^2 f(f-1)/2
以上就是BilinearFFM的基本思想。
FiBiNET中用到的Bilinear Interaction相比BilinearFFM, 還有一處小改動,將點積改成了哈達瑪積,如下圖所示。
pytorch代碼實現(xiàn)如下,整體不難理解。作2點說明。
1,Field概念說明
在FFM相關的文章中,引入了Field的概念,以和Feature區(qū)分,一個Field中可以包括多個Feature.
實際上Field就是我們通常理解的特征,包括數(shù)值特征和類別特征,但是Feature是數(shù)值特征或者類別特征onehot后的特征。一個類別特征對應一個Field,但是對應多個Feature。
2,combinations函數(shù)說明
組合函數(shù)combinations從num_fields中任取2種作為組合,共有 num_fields*(num_fields-1)中組合方式。
所以輸出的Field數(shù)量變成了 num_fields*(num_fields-1)/2。
import?torch? from?torch?import?nn? from?itertools?import?combinations class?BilinearInteraction(nn.Module):"""雙線性FFM輸入shape:?[batch_size,?num_fields,?d_embed]?#num_fields即num_features輸出shape:?[batch_size,?num_fields*(num_fields-1)/2,?d_embed]"""def?__init__(self,?num_fields,?d_embed,?bilinear_type="field_interaction"):super().__init__()self.bilinear_type?=?bilinear_typeif?self.bilinear_type?==?"field_all":self.bilinear_layer?=?nn.Linear(d_embed,?d_embed,?bias=False)elif?self.bilinear_type?==?"field_each":self.bilinear_layer?=?nn.ModuleList([nn.Linear(d_embed,?d_embed,?bias=False)for?i?in?range(num_fields)])elif?self.bilinear_type?==?"field_interaction":self.bilinear_layer?=?nn.ModuleList([nn.Linear(d_embed,?d_embed,?bias=False)for?i,?j?in?combinations(range(num_fields),?2)])else:raise?NotImplementedError()def?forward(self,?feature_emb):feature_emb_list?=?torch.split(feature_emb,?1,?dim=1)if?self.bilinear_type?==?"field_all":bilinear_list?=?[self.bilinear_layer(v_i)?*?v_jfor?v_i,?v_j?in?combinations(feature_emb_list,?2)]elif?self.bilinear_type?==?"field_each":bilinear_list?=?[self.bilinear_layer[i](feature_emb_list[i])?*?feature_emb_list[j]for?i,?j?in?combinations(range(len(feature_emb_list)),?2)]elif?self.bilinear_type?==?"field_interaction":bilinear_list?=?[self.bilinear_layer[i](v[0])?*?v[1]for?i,?v?in?enumerate(combinations(feature_emb_list,?2))]return?torch.cat(bilinear_list,?dim=1)二,FiBiNET的pytorch實現(xiàn)
下面是FiBiNET的一個pytorch實現(xiàn)。
核心代碼是SENetAttention模塊和BilinearInteraction模塊的實現(xiàn)。
import?torch? from?torch?import?nn? from?itertools?import?combinationsclass?NumEmbedding(nn.Module):"""連續(xù)特征用linear層編碼輸入shape:?[batch_size,num_features,?d_in],?#?d_in?通常是1輸出shape:?[batch_size,num_features,?d_out]"""def?__init__(self,?n:?int,?d_in:?int,?d_out:?int,?bias:?bool?=?False)?->?None:super().__init__()self.weight?=?nn.Parameter(torch.Tensor(n,?d_in,?d_out))self.bias?=?nn.Parameter(torch.Tensor(n,?d_out))?if?bias?else?Nonewith?torch.no_grad():for?i?in?range(n):layer?=?nn.Linear(d_in,?d_out)self.weight[i]?=?layer.weight.Tif?self.bias?is?not?None:self.bias[i]?=?layer.biasdef?forward(self,?x_num):assert?x_num.ndim?==?3#x?=?x_num[...,?None]?*?self.weight[None]#x?=?x.sum(-2)x?=?torch.einsum("bfi,fij->bfj",x_num,self.weight)if?self.bias?is?not?None:x?=?x?+?self.bias[None]return?xclass?CatEmbedding(nn.Module):"""離散特征用Embedding層編碼輸入shape:?[batch_size,?num_features],?輸出shape:?[batch_size,?num_features,?d_embed]"""def?__init__(self,?categories,?d_embed):super().__init__()self.embedding?=?nn.Embedding(sum(categories),?d_embed)self.offsets?=?nn.Parameter(torch.tensor([0]?+?categories[:-1]).cumsum(0),requires_grad=False)nn.init.xavier_uniform_(self.embedding.weight.data)def?forward(self,?x_cat):"""x_cat:?Long?tensor?of?size?``(batch_size,?features_num)``"""x?=?x_cat?+?self.offsets[None]return?self.embedding(x)?class?CatLinear(nn.Module):"""離散特征用Embedding實現(xiàn)線性層(等價于先F.onehot再nn.Linear())輸入shape:?[batch_size,?num_features?],?輸出shape:?[batch_size,?d_out]"""def?__init__(self,?categories,?d_out=1):super().__init__()self.fc?=?nn.Embedding(sum(categories),?d_out)self.bias?=?nn.Parameter(torch.zeros((d_out,)))self.offsets?=?nn.Parameter(torch.tensor([0]?+?categories[:-1]).cumsum(0),requires_grad=False)nn.init.xavier_uniform_(self.fc.weight.data)def?forward(self,?x_cat):"""Long?tensor?of?size?``(batch_size,?num_features)``"""x?=?x_cat?+?self.offsets[None]return?torch.sum(self.fc(x),?dim=1)?+?self.bias?class?SENetAttention(nn.Module):"""Squeeze-and-Excitation?Attention輸入shape:?[batch_size,?num_fields,?d_embed]???#num_fields即num_features輸出shape:?[batch_size,?num_fields,?d_embed]"""def?__init__(self,?num_fields,?reduction_ratio=3):super().__init__()reduced_size?=?max(1,?int(num_fields?/?reduction_ratio))self.excitation?=?nn.Sequential(nn.Linear(num_fields,?reduced_size,?bias=False),nn.ReLU(),nn.Linear(reduced_size,?num_fields,?bias=False),nn.ReLU())def?forward(self,?x):Z?=?torch.mean(x,?dim=-1,?out=None)?#1,SequeezeA?=?self.excitation(Z)?#2,ExcitationV?=?x?*?A.unsqueeze(-1)?#3,Re-Weightreturn?Vclass?BilinearInteraction(nn.Module):"""雙線性FFM輸入shape:?[batch_size,?num_fields,?d_embed]?#num_fields即num_features輸出shape:?[batch_size,?num_fields*(num_fields-1)/2,?d_embed]"""def?__init__(self,?num_fields,?d_embed,?bilinear_type="field_interaction"):super().__init__()self.bilinear_type?=?bilinear_typeif?self.bilinear_type?==?"field_all":self.bilinear_layer?=?nn.Linear(d_embed,?d_embed,?bias=False)elif?self.bilinear_type?==?"field_each":self.bilinear_layer?=?nn.ModuleList([nn.Linear(d_embed,?d_embed,?bias=False)for?i?in?range(num_fields)])elif?self.bilinear_type?==?"field_interaction":self.bilinear_layer?=?nn.ModuleList([nn.Linear(d_embed,?d_embed,?bias=False)for?i,?j?in?combinations(range(num_fields),?2)])else:raise?NotImplementedError()def?forward(self,?feature_emb):feature_emb_list?=?torch.split(feature_emb,?1,?dim=1)if?self.bilinear_type?==?"field_all":bilinear_list?=?[self.bilinear_layer(v_i)?*?v_jfor?v_i,?v_j?in?combinations(feature_emb_list,?2)]elif?self.bilinear_type?==?"field_each":bilinear_list?=?[self.bilinear_layer[i](feature_emb_list[i])?*?feature_emb_list[j]for?i,?j?in?combinations(range(len(feature_emb_list)),?2)]elif?self.bilinear_type?==?"field_interaction":bilinear_list?=?[self.bilinear_layer[i](v[0])?*?v[1]for?i,?v?in?enumerate(combinations(feature_emb_list,?2))]return?torch.cat(bilinear_list,?dim=1)#mlp class?MultiLayerPerceptron(nn.Module):def?__init__(self,?d_in,?d_layers,?dropout,?d_out?=?1):super().__init__()layers?=?[]for?d?in?d_layers:layers.append(nn.Linear(d_in,?d))layers.append(nn.BatchNorm1d(d))layers.append(nn.ReLU())layers.append(nn.Dropout(p=dropout))d_in?=?dlayers.append(nn.Linear(d_layers[-1],?d_out))self.mlp?=?nn.Sequential(*layers)def?forward(self,?x):"""float?tensor?of?size?``(batch_size,?d_in)``"""return?self.mlp(x)#fibinet? class?FiBiNET(nn.Module):def?__init__(self,d_numerical,?categories,?d_embed,mlp_layers,?mlp_dropout,reduction_ratio?=?3,bilinear_type?=?"field_interaction",n_classes?=?1):super().__init__()if?d_numerical?is?None:d_numerical?=?0if?categories?is?None:categories?=?[]self.categories?=?categoriesself.n_classes?=?n_classesself.num_linear?=?nn.Linear(d_numerical,n_classes)?if?d_numerical?else?Noneself.cat_linear?=?CatLinear(categories,n_classes)?if?categories?else?Noneself.num_embedding?=?NumEmbedding(d_numerical,1,d_embed)?if?d_numerical?else?Noneself.cat_embedding?=?CatEmbedding(categories,?d_embed)?if?categories?else?Nonenum_fields?=?d_numerical+len(categories)self.se_attention?=?SENetAttention(num_fields,?reduction_ratio)self.bilinear?=?BilinearInteraction(num_fields,?d_embed,?bilinear_type)mlp_in?=?num_fields?*?(num_fields?-?1)?*?d_embedself.mlp?=?MultiLayerPerceptron(d_in=?mlp_in,d_layers?=?mlp_layers,dropout?=?mlp_dropout,d_out?=?n_classes)def?forward(self,?x):"""x_num:?numerical?featuresx_cat:?category?features"""x_num,x_cat?=?x#一,wide部分x_linear?=?0.0if?self.num_linear:x_linear?=?x_linear?+?self.num_linear(x_num)?if?self.cat_linear:x_linear?=?x_linear?+?self.cat_linear(x_cat)#二,deep部分?#1,embeddingx_embedding?=?[]if?self.num_embedding:x_embedding.append(self.num_embedding(x_num[...,None]))if?self.cat_embedding:x_embedding.append(self.cat_embedding(x_cat))x_embedding?=?torch.cat(x_embedding,dim=1)#2,interactionse_embedding?=?self.se_attention(x_embedding)ffm_out?=?self.bilinear(x_embedding)se_ffm_out?=?self.bilinear(se_embedding)x_interaction?=?torch.flatten(torch.cat([ffm_out,?se_ffm_out],?dim=1),?start_dim=1)#3,mlpx_deep?=?self.mlp(x_interaction)#三,高低融合x_out?=?x_linear+x_deepif?self.n_classes==1:x_out?=?x_out.squeeze(-1)return?x_out ##測試?FiBiNETmodel?=?FiBiNET(d_numerical?=?3,?categories?=?[4,3,2],d_embed?=?4,?mlp_layers?=?[20,20],?mlp_dropout=0.25,reduction_ratio?=?3,bilinear_type?=?"field_interaction",n_classes?=?1)x_num?=?torch.randn(2,3) x_cat?=?torch.randint(0,2,(2,3)) print(model((x_num,x_cat))) tensor([-0.8621,??0.6743],?grad_fn=<SqueezeBackward1>)三,criteo數(shù)據(jù)集完整范例
Criteo數(shù)據(jù)集是一個經(jīng)典的廣告點擊率CTR預測數(shù)據(jù)集。
這個數(shù)據(jù)集的目標是通過用戶特征和廣告特征來預測某條廣告是否會為用戶點擊。
數(shù)據(jù)集有13維數(shù)值特征(I1-I13)和26維類別特征(C14-C39), 共39維特征, 特征中包含著許多缺失值。
訓練集4000萬個樣本,測試集600萬個樣本。數(shù)據(jù)集大小超過100G.
此處使用的是采樣100萬個樣本后的cretio_small數(shù)據(jù)集。
!pip?install?-U?torchkeras?-i?https://pypi.org/simple/ import?numpy?as?np? import?pandas?as?pd? import?datetime?from?sklearn.model_selection?import?train_test_split?import?torch? from?torch?import?nn? from?torch.utils.data?import?Dataset,DataLoader?? import?torch.nn.functional?as?F? import?torchkeras?def?printlog(info):nowtime?=?datetime.datetime.now().strftime('%Y-%m-%d?%H:%M:%S')print("\n"+"=========="*8?+?"%s"%nowtime)print(info+'...\n\n')1,準備數(shù)據(jù)
from?sklearn.preprocessing?import?LabelEncoder,QuantileTransformer from?sklearn.pipeline?import?Pipeline? from?sklearn.impute?import?SimpleImputer?dfdata?=?pd.read_csv("../input/criteo-small/train_1m.txt",sep="\t",header=None) dfdata.columns?=?["label"]?+?["I"+str(x)?for?x?in?range(1,14)]?+?["C"+str(x)?for?x?in?range(14,40)]cat_cols?=?[x?for?x?in?dfdata.columns?if?x.startswith('C')] num_cols?=?[x?for?x?in?dfdata.columns?if?x.startswith('I')] num_pipe?=?Pipeline(steps?=?[('impute',SimpleImputer()),('quantile',QuantileTransformer())])for?col?in?cat_cols:dfdata[col]??=?LabelEncoder().fit_transform(dfdata[col])dfdata[num_cols]?=?num_pipe.fit_transform(dfdata[num_cols])categories?=?[dfdata[col].max()+1?for?col?in?cat_cols] import?torch? from?torch.utils.data?import?Dataset,DataLoader?#DataFrame轉換成torch數(shù)據(jù)集Dataset,?特征分割成X_num,X_cat方式 class?DfDataset(Dataset):def?__init__(self,df,label_col,num_features,cat_features,categories,is_training=True):self.X_num?=?torch.tensor(df[num_features].values).float()?if?num_features?else?Noneself.X_cat?=?torch.tensor(df[cat_features].values).long()?if?cat_features?else?Noneself.Y?=?torch.tensor(df[label_col].values).float()?self.categories?=?categoriesself.is_training?=?is_trainingdef?__len__(self):return?len(self.Y)def?__getitem__(self,index):if?self.is_training:return?((self.X_num[index],self.X_cat[index]),self.Y[index])else:return?(self.X_num[index],self.X_cat[index])def?get_categories(self):return?self.categories dftrain_val,dftest?=?train_test_split(dfdata,test_size=0.2) dftrain,dfval?=?train_test_split(dftrain_val,test_size=0.2)ds_train?=?DfDataset(dftrain,label_col?=?"label",num_features?=?num_cols,cat_features?=?cat_cols,categories?=?categories,?is_training=True)ds_val?=?DfDataset(dfval,label_col?=?"label",num_features?=?num_cols,cat_features?=?cat_cols,categories?=?categories,?is_training=True)ds_test?=?DfDataset(dftest,label_col?=?"label",num_features?=?num_cols,cat_features?=?cat_cols,categories?=?categories,?is_training=True) dl_train?=?DataLoader(ds_train,batch_size?=?2048,shuffle=True) dl_val?=?DataLoader(ds_val,batch_size?=?2048,shuffle=False) dl_test?=?DataLoader(ds_test,batch_size?=?2048,shuffle=False)for?features,labels?in?dl_train:break2,定義模型
def?create_net():net?=?FiBiNET(d_numerical=?ds_train.X_num.shape[1],categories=?ds_train.get_categories(),d_embed?=?8,?mlp_layers?=?[128,64,32],?mlp_dropout=0.25,reduction_ratio?=?3,bilinear_type?=?"field_all",n_classes?=?1)return?net?from?torchkeras?import?summarynet?=?create_net()3,訓練模型
import?os,sys,time import?numpy?as?np import?pandas?as?pd import?datetime? from?tqdm?import?tqdm?import?torch from?torch?import?nn? from?accelerate?import?Accelerator from?copy?import?deepcopydef?printlog(info):nowtime?=?datetime.datetime.now().strftime('%Y-%m-%d?%H:%M:%S')print("\n"+"=========="*8?+?"%s"%nowtime)print(str(info)+"\n")class?StepRunner:def?__init__(self,?net,?loss_fn,stage?=?"train",?metrics_dict?=?None,?optimizer?=?None,?lr_scheduler?=?None,accelerator?=?None):self.net,self.loss_fn,self.metrics_dict,self.stage?=?net,loss_fn,metrics_dict,stageself.optimizer,self.lr_scheduler?=?optimizer,lr_schedulerself.accelerator?=?acceleratordef?__call__(self,?features,?labels):#losspreds?=?self.net(features)loss?=?self.loss_fn(preds,labels)#backward()if?self.optimizer?is?not?None?and?self.stage=="train":if?self.accelerator?is??None:loss.backward()else:self.accelerator.backward(loss)self.optimizer.step()if?self.lr_scheduler?is?not?None:self.lr_scheduler.step()self.optimizer.zero_grad()#metricsstep_metrics?=?{self.stage+"_"+name:metric_fn(preds,?labels).item()?for?name,metric_fn?in?self.metrics_dict.items()}return?loss.item(),step_metricsclass?EpochRunner:def?__init__(self,steprunner):self.steprunner?=?steprunnerself.stage?=?steprunner.stageself.steprunner.net.train()?if?self.stage=="train"?else?self.steprunner.net.eval()def?__call__(self,dataloader):total_loss,step?=?0,0loop?=?tqdm(enumerate(dataloader),?total?=len(dataloader))for?i,?batch?in?loop:features,labels?=?batchif?self.stage=="train":loss,?step_metrics?=?self.steprunner(features,labels)else:with?torch.no_grad():loss,?step_metrics?=?self.steprunner(features,labels)step_log?=?dict({self.stage+"_loss":loss},**step_metrics)total_loss?+=?lossstep+=1if?i!=len(dataloader)-1:loop.set_postfix(**step_log)else:epoch_loss?=?total_loss/stepepoch_metrics?=?{self.stage+"_"+name:metric_fn.compute().item()?for?name,metric_fn?in?self.steprunner.metrics_dict.items()}epoch_log?=?dict({self.stage+"_loss":epoch_loss},**epoch_metrics)loop.set_postfix(**epoch_log)for?name,metric_fn?in?self.steprunner.metrics_dict.items():metric_fn.reset()return?epoch_logclass?KerasModel(torch.nn.Module):def?__init__(self,net,loss_fn,metrics_dict=None,optimizer=None,lr_scheduler?=?None):super().__init__()self.accelerator?=?Accelerator()self.history?=?{}self.net?=?netself.loss_fn?=?loss_fnself.metrics_dict?=?nn.ModuleDict(metrics_dict)?self.optimizer?=?optimizer?if?optimizer?is?not?None?else?torch.optim.Adam(self.parameters(),?lr=1e-2)self.lr_scheduler?=?lr_schedulerself.net,self.loss_fn,self.metrics_dict,self.optimizer?=?self.accelerator.prepare(self.net,self.loss_fn,self.metrics_dict,self.optimizer)def?forward(self,?x):if?self.net:return?self.net.forward(x)else:raise?NotImplementedErrordef?fit(self,?train_data,?val_data=None,?epochs=10,?ckpt_path='checkpoint.pt',?patience=5,?monitor="val_loss",?mode="min"):train_data?=?self.accelerator.prepare(train_data)val_data?=?self.accelerator.prepare(val_data)?if?val_data?else?[]for?epoch?in?range(1,?epochs+1):printlog("Epoch?{0}?/?{1}".format(epoch,?epochs))#?1,train?-------------------------------------------------??train_step_runner?=?StepRunner(net?=?self.net,stage="train",loss_fn?=?self.loss_fn,metrics_dict=deepcopy(self.metrics_dict),optimizer?=?self.optimizer,?lr_scheduler?=?self.lr_scheduler,accelerator?=?self.accelerator)train_epoch_runner?=?EpochRunner(train_step_runner)train_metrics?=?train_epoch_runner(train_data)for?name,?metric?in?train_metrics.items():self.history[name]?=?self.history.get(name,?[])?+?[metric]#?2,validate?-------------------------------------------------if?val_data:val_step_runner?=?StepRunner(net?=?self.net,stage="val",loss_fn?=?self.loss_fn,metrics_dict=deepcopy(self.metrics_dict),accelerator?=?self.accelerator)val_epoch_runner?=?EpochRunner(val_step_runner)with?torch.no_grad():val_metrics?=?val_epoch_runner(val_data)val_metrics["epoch"]?=?epochfor?name,?metric?in?val_metrics.items():self.history[name]?=?self.history.get(name,?[])?+?[metric]#?3,early-stopping?-------------------------------------------------arr_scores?=?self.history[monitor]best_score_idx?=?np.argmax(arr_scores)?if?mode=="max"?else?np.argmin(arr_scores)if?best_score_idx==len(arr_scores)-1:torch.save(self.net.state_dict(),ckpt_path)print("<<<<<<?reach?best?{0}?:?{1}?>>>>>>".format(monitor,arr_scores[best_score_idx]),file=sys.stderr)if?len(arr_scores)-best_score_idx>patience:print("<<<<<<?{}?without?improvement?in?{}?epoch,?early?stopping?>>>>>>".format(monitor,patience),file=sys.stderr)self.net.load_state_dict(torch.load(ckpt_path))break?return?pd.DataFrame(self.history)@torch.no_grad()def?evaluate(self,?val_data):val_data?=?self.accelerator.prepare(val_data)val_step_runner?=?StepRunner(net?=?self.net,stage="val",loss_fn?=?self.loss_fn,metrics_dict=deepcopy(self.metrics_dict),accelerator?=?self.accelerator)val_epoch_runner?=?EpochRunner(val_step_runner)val_metrics?=?val_epoch_runner(val_data)return?val_metrics@torch.no_grad()def?predict(self,?dataloader):dataloader?=?self.accelerator.prepare(dataloader)result?=?torch.cat([self.forward(t[0])?for?t?in?dataloader])return?result.data from?torchkeras.metrics?import?AUCloss_fn?=?nn.BCEWithLogitsLoss()metrics_dict?=?{"auc":AUC()}optimizer?=?torch.optim.Adam(net.parameters(),?lr=0.002,?weight_decay=0.001)?model?=?KerasModel(net,loss_fn?=?loss_fn,metrics_dict=?metrics_dict,optimizer?=?optimizer) dfhistory?=?model.fit(train_data=dl_train,val_data=dl_val,epochs=100,?patience=5,monitor?=?"val_auc",mode="max",ckpt_path='checkpoint.pt')4,評估模型
%matplotlib?inline %config?InlineBackend.figure_format?=?'svg'import?matplotlib.pyplot?as?pltdef?plot_metric(dfhistory,?metric):train_metrics?=?dfhistory["train_"+metric]val_metrics?=?dfhistory['val_'+metric]epochs?=?range(1,?len(train_metrics)?+?1)plt.plot(epochs,?train_metrics,?'bo--')plt.plot(epochs,?val_metrics,?'ro-')plt.title('Training?and?validation?'+?metric)plt.xlabel("Epochs")plt.ylabel(metric)plt.legend(["train_"+metric,?'val_'+metric])plt.show() plot_metric(dfhistory,"loss") plot_metric(dfhistory,"auc")5,使用模型
from?sklearn.metrics?import?roc_auc_score preds?=?torch.sigmoid(model.predict(dl_val)) labels?=?torch.cat([x[-1]?for?x?in?dl_val])val_auc?=?roc_auc_score(labels.cpu().numpy(),preds.cpu().numpy()) print(val_auc)0.7806176567186112
6,保存模型
torch.save(model.net.state_dict(),"best_fibinet.pt") net_clone?=?create_net() net_clone.load_state_dict(torch.load("best_fibinet.pt")) from?sklearn.metrics?import?roc_auc_score net_clone.eval() preds?=?torch.cat([torch.sigmoid(net_clone(x[0])).data?for?x?in?dl_val])? labels?=?torch.cat([x[-1]?for?x?in?dl_val])val_auc?=?roc_auc_score(labels.cpu().numpy(),preds.cpu().numpy()) print(val_auc)0.7806176567186112
可以看到FiBiNET在驗證集的AUC得分為0.7806,相比之下DeepFM的驗證集AUC為0.7803。
不能說紋絲不動, 只能說了漲了個蚊子腿大小肉的點。
并且這是以較大地犧牲模型訓練預測效率為代價的。
DeepFM訓練一個Epoch大約需要20s, 而FiBiNET訓練一個Epoch需要大約2min.
盡管如此, FiBiNET的結構設計依然是值得我們學習和借鑒的, 集神經(jīng)網(wǎng)絡結構設計三大主流高級技巧于一體, 閃爍著穿越時空的才華與智慧光芒。
以上。
萬水千山總是情,點個在看行不行?😋😋?
公眾號后臺回復關鍵詞:?FiBiNET?獲取本文全部源碼和數(shù)據(jù)集下載鏈接。
總結
以上是生活随笔為你收集整理的60分钟吃掉三杀模型FiBiNET的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 日期到天数转换
- 下一篇: css中aspect,css扩展之asp