了解torch.nn.DataParallel
生活随笔
收集整理的這篇文章主要介紹了
了解torch.nn.DataParallel
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
CLASS torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)
- 在模塊級實現數據并行。
- 該容器通過在批處理維度中分組,將輸入分割到指定的設備上,從而并行化給定模塊的應用程序(其他對象將在每個設備上復制一次)。在前向傳播中,模塊被復制到每個設備上,每個副本處理輸入的一部分。在反向傳播過程中,來自每個副本的梯度被累加到原始模塊中。
- 批處理大小應該大于所使用的gpu數量。
- 允許將任意位置和關鍵字輸入傳遞到DataParallel中,但有些類型是專門處理的
- tensor將分散到指定dim(默認為0)。tuple, list以及dict類型將淺拷貝。其他類型將在不同的線程之間共享,如果在模型的前向傳播中寫入,則可能被破壞。
- 并行模塊必須在device_ids[0]上有它的parameters和buffers,然后才能運行這個DataParallel模塊。
Parameters
- module (Module) – module to be parallelized
- device_ids (list of python:int or torch.device) – CUDA devices (default: all devices)
- output_device (int or torch.device) – device location of output (default: device_ids[0])
Example
net = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) output = net(input_var) # input_var can be on any device, including CPUpytorch文檔
總結
以上是生活随笔為你收集整理的了解torch.nn.DataParallel的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: ME525 刷机历险记
- 下一篇: C++ 程序越过windows Defe