深度学习小技巧(二):如何保存和恢复scikit-learn训练的模型
深度學習小技巧(一):如何保存和恢復TensorFlow訓練的模型
在許多情況下,在使用scikit學習庫的同時,你需要將預測模型保存到文件中,然后在使用它們的時候還原它們,以便重復使用以前的工作。比如在新數據上測試模型,比較多個模型的優劣。這種保存過程也稱為對象序列化——表示具有字節流的對象,以便將其存儲在磁盤上,它可以通過網絡發送或保存到數據庫,而其恢復的過程被稱為反序列化。在本文中,我們將在Python和scikit學習中看到三種可能的方法,而且每種都有其優點和缺點。
1.保存和恢復模型的工具
我們第一個介紹的工具是Pickle,用于對象(de)序列化的標準Python工具。之后,我們會介紹Joblib庫,它提供了容易(de)序列化方法,其中包含了大數據數組的對象,最后我們會介紹一種手動方法來保存和恢復JSON對象(JavaScript Object Notation)。這些方法都不能代表最佳解決方案,但是可以根據項目的需要選擇合適的方案。
2.模型初始化
首先,我們要創建一個scikit學習模型。在我們的例子中,我們將使用Logistic回歸模型和Iris數據集。我們導入所需的庫,并且加載數據,并將其拆分為訓練集和測試集。
現在讓我們用一些非默認參數來創建模型,并用訓練數據來“喂養”它。我們假設你先前已經找到了模型的最優參數,即產生最高估計精度的參數。
這是我們產生的模型:
使用該fit方法,模型已經學習了存儲在其中的系數model.coef_。目標是將模型的參數和系數保存到文件中,因此你不需要再次對新數據重復模型訓練和參數優化的步驟。
3.Pickle模塊
在以下幾行代碼中,我們將上一步中創建的模型保存到文件中,然后作為一個新對象加載pickled_model。然后使用加載的模型計算準確度分數,并對新的未見(測試)數據進行預測結果。
運行此代碼應該會產生你的預測分數,并通過Pickle保存模型:
使用Pickle來保存和恢復學習模型的好處在于它很快,并且你可以用兩行代碼完成。如果你已經對訓練數據上的模型參數進行了優化,那么這是非常有用的,因此你不需要重復此步驟。不管如何,它都不保存測試結果和任何數據。但仍然可以保存多個對象的元組或列表(并記住哪個對象在哪里),如下所示:
3.Joblib模塊
Joblib庫它的目的是替代Pickle,用于包含大數據的對象。我們將重復與Pickle一樣的保存和恢復過程。
從示例中可以看出,與Pickle相比,Joblib庫提供了一個簡單的工作流程。雖然Pickle要求將文件對象作為參數傳遞,但是Joblib可與文件對象和字符串文件名一起使用。如果你的模型包含大量數據,則每個數組將存儲在單獨的文件中,但整體的保存和恢復過程將保持不變。Joblib還允許使用不同的壓縮方法,如“zlib”,“gzip”,“bz2”和不同的壓縮級別。
4.手動保存并還原到JSON
根據你的項目,很多時候你會發現Pickle和Joblib都不是合適的解決方案。其中一些原因將在兼容性問題部分中稍后討論。無論何時要想完全控制保存和恢復過程,最好的方法是手動構建自己的功能。
以下顯示了使用JSON手動保存和恢復對象的示例。這種方法允許我們選擇需要保存的數據,例如模型參數,系數,訓練數據以及我們需要的任何其他數據。
由于我們想將所有這些數據保存在一個對象中,所以一個可能的方法是創建一個繼承我們的示例中的模型類的新類LogisticRegression。這個新類被MyLogReg調用,然后分別實現save_json和load_json的方法以保存和恢復JSON文件。
為簡單起見,我們將只保存三個模型參數和訓練數據。我們可以用這種方法存儲一些額外的數據,例如訓練集上的交叉驗證分數,測試數據,測試數據的準確度等等。
現在我們來試一試MyLogReg。首先我們創建一個對象mylogreg,將訓練數據傳遞給它,并將其保存到文件中。然后我們創建一個新對象json_mylogreg,并調用該load_json方法從文件加載數據。
打印出新的對象,我們可以根據需要來查看我們的參數和訓練數據。
由于使用JSON的數據序列化實際上是將對象保存為字符串格式,而不是字節流,所以'mylogreg.json'文件可以使用文本編輯器打開和修改。雖然這種方法對開發人員來說很方便,但是由于入侵者可以查看和修改JSON文件的內容,因此安全性較低。此外,這種方法更適合于具有少量實例變量的對象,例如scikit-learn模型,因為任何添加新變量都需要在保存和恢復方法中進行更改。
5.兼容性問題
盡管到目前為止,每個工具的優點和缺點已被介紹,但Pickle和Joblib工具的最大缺點可能是其與不同型號的Python版本的兼容性。
5.1:Python版本的兼容性——兩種工具的文檔都指出,不建議(de)在不同的Python版本之間對對象進行序列化,盡管它可能在低級的版本更改中起作用。
5.2:模型兼容性——最常見的錯誤之一是使用Pickle或Joblib保存模型,然后在嘗試從文件還原之前更改模型。模型的內部結構需要在保存和重新加載之間保持不變。
Pickle和Joblib的最后一個問題與安全性有關。這兩種工具都可能包含惡意代碼,因此不建議從不受信任或未經身份驗證的源代碼。
6.結論
在這篇文章中,我們描述了三種保存和恢復scikit學習模型的工具。Pickle和Joblib庫可以快速方便地使用,但是在不同的Python版本和學習模型的變化中存在兼容性問題。另一方面,手動方法更難實現,需要在模型結構發生任何變化中進行修改,但在另一方面,它可以輕松地適應各種需求,并且沒有任何兼容性問題。
作者信息
作者:Mihajlo Pavloski,數據科學與機器學習的愛好者,博士生。
本文由阿里云云棲社區組織翻譯。
文章原標題《TensorFlow : Save and Restore Models》
作者:Mihajlo Pavloski譯者:虎說八道
文章為簡譯,更為詳細的內容,請查看原文
總結
以上是生活随笔為你收集整理的深度学习小技巧(二):如何保存和恢复scikit-learn训练的模型的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: poi读取Excel内容数据
- 下一篇: BZOJ-2748: [HAOI2012