【深度学习】PyTorch 数据集随机值的完美实践
? 作者 | Elvanth@知乎
來源 | https://zhuanlan.zhihu.com/p/377155682
編輯 | 極市平臺
本文僅作學術交流,版權歸原作者所有,如有侵權請聯系刪除。
導讀
?本文所分析的問題與解決方案將在最近發布的pytorch版本中解決;因此解決所有煩惱的根源是方法,更新pytorch~?>>
一個快捷的解決方案:
def worker_init_fn(worker_id):worker_seed = torch.initial_seed() % 2**32np.random.seed(worker_seed)random.seed(worker_seed)ds?=?DataLoader(ds,?10,?shuffle=False,?num_workers=4,?worker_init_fn=worker_init_fn)01 關于pytorch數據集隨機種子的基本認識
在pytorch中random、torch.random等隨機值產生方法一般沒有問題,只有少數工人運行也可以保障其不同的最終值.
np.random.seed 會出現問題的原因是,當多處理采用 fork 方式產生子進程時,numpy 不會對不同的子進程產生不同的隨機值.
換言之,當沒有多處理使用時,numpy 不會出現隨機種子的不同的問題;實驗代碼的可復現性要求一個是工人種子 ,即工人內包括numpy,random,torch.random所有的隨機表現;另一個是Base ,即程序運行后的初始隨機值,其可以通過以下兩種方式產生
torch.manual_seed(base_seed)
由特定的seed generator設置
使用spawn模式可以斬斷以上所有煩惱.
02 直接在網上搜這個問題會得到什么答案
參考很多的解決方案時,往往會提出以下功能:
def worker_init_fn(worker_id):np.random.seed(np.random.get_state()[1][0] + worker_id)讓我們看看它的輸出結果:
(第0,3列是索引,第1,4列是np.random的結果,第2,5列是random.randint的結果)
假設上述方案對一個時代內可以防止不同的工人出現隨機值相同的情況,但不同的時代之間,其最終的隨機種子仍然是不變的。
03 那應該如何解決
來自pytorch官方的解決方案:
https://github.com/pytorch/pytorch/pull/56488#issuecomment-825128350
def worker_init_fn(worker_id):worker_seed = torch.initial_seed() % 2**32np.random.seed(worker_seed)random.seed(worker_seed)ds?=?DataLoader(ds,?10,?shuffle=False,?num_workers=4,?worker_init_fn=worker_init_fn)來自numpy.random原作者的解決方案:
https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
def worker_init_fn(id):process_seed = torch.initial_seed()# Back out the base_seed so we can use all the bits.base_seed = process_seed - idss = np.random.SeedSequence([id, base_seed])# More than 128 bits (4 32-bit words) would be overkill.np.random.seed(ss.generate_state(4))ds?=?DataLoader(ds,?10,?shuffle=False,?num_workers=4,?worker_init_fn=worker_init_fn)一個更簡單但不保證正確性的解決方案:
def worker_init_fn(worker_id):np.random.seed((worker_id + torch.initial_seed()) % np.iinfo(np.int32).max)ds?=?DataLoader(ds,?10,?shuffle=False,?num_workers=4,?worker_init_fn=worker_init_fn)04 附上可運行的完整文件
import numpy as np import random import torch# np.random.seed(0)class Transform(object):def __init__(self):passdef __call__(self, item = None):return [np.random.randint(10000, 20000), random.randint(20000,30000)]class RandomDataset(object):def __init__(self):passdef __getitem__(self, ind):item = [ind, np.random.randint(1, 10000), random.randint(10000, 20000), 0]tsfm =Transform()(item)return np.array(item + tsfm)def __len__(self):return 20from torch.utils.data import DataLoaderdef worker_init_fn(worker_id):np.random.seed(np.random.get_state()[1][0] + worker_id)ds = RandomDataset() ds = DataLoader(ds, 10, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)for epoch in range(2):print("epoch {}".format(epoch))np.random.seed()for batch in ds:print(batch)如果覺得有用,就請分享到朋友圈吧!
往期精彩回顧適合初學者入門人工智能的路線及資料下載機器學習及深度學習筆記等資料打印機器學習在線手冊深度學習筆記專輯《統計學習方法》的代碼復現專輯 AI基礎下載機器學習的數學基礎專輯黃海廣老師《機器學習課程》視頻課 本站qq群851320808,加入微信群請掃碼:總結
以上是生活随笔為你收集整理的【深度学习】PyTorch 数据集随机值的完美实践的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Ubuntu/环境变量:修改/etc/e
- 下一篇: jeecg标签属性exp 用法