[Kaggle] Spam/Ham Email Classification 垃圾邮件分类(RNN/GRU/LSTM)
生活随笔
收集整理的這篇文章主要介紹了
[Kaggle] Spam/Ham Email Classification 垃圾邮件分类(RNN/GRU/LSTM)
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
文章目錄
- 1. 讀入數據
- 2. 文本處理
- 3. 建模
- 4. 訓練
- 5. 測試
練習地址:https://www.kaggle.com/c/ds100fa19
相關博文
[Kaggle] Spam/Ham Email Classification 垃圾郵件分類(spacy)
[Kaggle] Spam/Ham Email Classification 垃圾郵件分類(BERT)
1. 讀入數據
- 讀取數據,test集沒有標簽
- 數據有無效的單元
存在 Na 單元格
[0 6 0 0] [0 1 0]- fillna 填充處理
填充完成,顯示 sum = 0
[0 0 0 0] [0 0 0]- y 標簽 只有 0 不是垃圾郵件, 1 是垃圾郵件
2. 文本處理
- 郵件內容和主題合并為一個特征
- 文本轉成 tokens ids 序列
- pad ids 序列,使之長度一樣
3. 建模
embeddings_dim = 30 # 詞嵌入向量維度 from keras.models import Model, Sequential from keras.layers import Embedding, LSTM, GRU, SimpleRNN, Dense model = Sequential() model.add(Embedding(input_dim=max_words, # Size of the vocabularyoutput_dim=embeddings_dim, # 詞嵌入的維度input_length=maxlen)) model.add(GRU(units=64)) # 可以改為 SimpleRNN , LSTM model.add(Dense(units=1, activation='sigmoid')) model.summary()模型結構:
Model: "sequential_5" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= embedding_2 (Embedding) (None, 100, 30) 9000 _________________________________________________________________ gru (GRU) (None, 64) 18432 _________________________________________________________________ dense_2 (Dense) (None, 1) 65 ================================================================= Total params: 27,497 Trainable params: 27,497 Non-trainable params: 0 _________________________________________________________________4. 訓練
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy']) # 配置模型 history = model.fit(X_train_tokens_pad, y_train,batch_size=128, epochs=10, validation_split=0.2) model.save("email_cat_lstm.h5") # 保存訓練好的模型- 繪制訓練曲線
5. 測試
pred_prob = model.predict(X_test_tokens_pad).squeeze() pred_class = np.asarray(pred_prob > 0.5).astype(np.int32) id = test['id'] output = pd.DataFrame({'id':id, 'Class': pred_class}) output.to_csv("submission_gru.csv", index=False)- 3種RNN模型對比:
總結
以上是生活随笔為你收集整理的[Kaggle] Spam/Ham Email Classification 垃圾邮件分类(RNN/GRU/LSTM)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: LeetCode 87. 扰乱字符串(记
- 下一篇: LeetCode 1197. 进击的骑士