Pytorch中torch.nn.DataParallel负载均衡问题
生活随笔
收集整理的這篇文章主要介紹了
Pytorch中torch.nn.DataParallel负载均衡问题
小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,幫大家做個(gè)參考.
1. 問題概述
現(xiàn)在Pytorc下進(jìn)行多卡訓(xùn)練主流的是采用torch.nn.parallel.DistributedDataParallel()(DDP)方法,但是在一些特殊的情況下這樣的方法就使用不了了,特別是在進(jìn)行與GAN相關(guān)的訓(xùn)練的時(shí)候,假如使用的損失函數(shù)是 WGAN-GP(LP),DRAGAN,那么其中會(huì)用到基于梯度的懲罰,其使用到的函數(shù)為torch.autograd.grad(),但是很不幸的是在實(shí)驗(yàn)的過程中該函數(shù)使用DDP會(huì)報(bào)錯(cuò):
File "/home/work/anaconda3/envs/xxxxx_py/lib/python3.6/site-packages/torch/autograd/__init__.py", line 93, in backwardallow_unreachable=True) # allow_unreachable flag RuntimeError: derivative for batch_norm_backward_elemt is not implemented那么需要并行(單機(jī)多卡)計(jì)算那么就只能使用torch.nn.DataParallel()了,但是也帶來另外一個(gè)問題那就是負(fù)載極其不均衡,使用這個(gè)并行計(jì)算方法會(huì)在主GPU上占據(jù)較多的現(xiàn)存,而其它的GPU顯存則只占用了一部分,這樣就使得無法再繼續(xù)增大batchsize了,下圖就是這種方式進(jìn)行計(jì)算,整個(gè)數(shù)據(jù)流的路線:
總結(jié)
以上是生活随笔為你收集整理的Pytorch中torch.nn.DataParallel负载均衡问题的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 高德地图地址解析经纬度以及经纬度解析地址
- 下一篇: sass使用指南