nn.Dataparallel pytorch 平行计算的两种方法
                                                            生活随笔
收集整理的這篇文章主要介紹了
                                nn.Dataparallel    pytorch 平行计算的两种方法
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.                        
                                1. nn.Dataparallel
多GPU加速訓練
原理:
 模型分別復制到每個卡中,然后把輸入切片,分別放入每個卡中計算,然后再用第一塊卡進行匯總求loss,反向傳播更新參數。
第一塊卡占用的內存多一點,因為output loss每次都會在第一塊GPU相加計算,這就造成了第一塊GPU的負載遠遠大于剩余其他的顯卡。
要求:
 batch_size > GPU 數量
第一種方法:
os.environment['CUDA_VISIBLE_DEVICES'] = '0,1,2,3' device_ids = [0,1,2,3] net = torch.nn. Dataparallel(net, device_ids =device_ids) net = net.cuda()第二種方法
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2" if torch.cuda.is_available():self.device = "cuda"if torch.cuda.device_count() > 1:self.G = nn.DataParallel(self.G)self.D_A = nn.DataParallel(self.D_A)self.D_B = nn.DataParallel(self.D_B)self.vgg = nn.DataParallel(self.vgg)self.criterionHis = nn.DataParallel(self.criterionHis)self.criterionGAN = nn.DataParallel(self.criterionGAN)self.criterionL1 = nn.DataParallel(self.criterionL1)self.criterionL2 = nn.DataParallel(self.criterionL2)self.criterionGAN = nn.DataParallel(self.criterionGAN)self.G.cuda()self.vgg.cuda()self.criterionHis.cuda()self.criterionGAN.cuda()self.criterionL1.cuda()self.criterionL2.cuda()self.D_A.cuda()self.D_B.cuda()2.模型分別單獨放入每個指定的GPU中
把模型分別放到指定的GPU中,然后在運算的過程中,需要把利用**.to(cuda:x)** 去轉移數據。這樣暫用的內存比平行計算小。但是配置復雜一點。
vgg_encoder = VGGEncoder().to('cuda:0')attn=CoAttention(channel=512).to('cuda:1')decoder = Decoder().to('cuda:2')optimizer_decoder = Adam(decoder.parameters(), lr=args.learning_rate)optimizer_attn = Adam(attn.parameters(), lr=args.learning_rate)content = content.cuda() # 默認的是cuda:0style = style.cuda()content_features = vgg_encoder(content, output_last_feature=True)style_features = vgg_encoder(style, output_last_feature=True)content_features, style_features=attn(content_features.to('cuda:1'),style_features.to('cuda:1')) # 因為attn在cuda:1中總結
以上是生活随笔為你收集整理的nn.Dataparallel pytorch 平行计算的两种方法的全部內容,希望文章能夠幫你解決所遇到的問題。
                            
                        - 上一篇: python缩进教学_Python缩进和
 - 下一篇: Pytorch RuntimeERROR