pytorch DDP加速之gradient accumulation设置
                                                            生活随笔
收集整理的這篇文章主要介紹了
                                pytorch DDP加速之gradient accumulation设置
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.                        
                                pytorch DDP
參考:https://zhuanlan.zhihu.com/p/250471767
 GPU高效通信算法-Ring Allreduce: https://www.zhihu.com/question/57799212/answer/612786337
 梯度累積: https://www.zhihu.com/question/303070254/answer/573037166
gradient accumulation
在梯度累加的情況下,假設一次梯度累加循環有accumulation_steps個step,每次梯度累加循環會進行K次 all_reduce,但事實上,每次梯度累加循環只會有一次 optimizer.step(),即只應用一次參數更新,這意味著在每一次梯度累加循環中,我們其實只要進行一次gradient all_reduce即可滿足要求,有accumulation_steps - 1次all_reduce被浪費了。而每次 all_reduce的時間成本是比較高的。 解決問題的思路在于,對前accumulation_steps - 1次step取消其梯度同步。DDP給我們提供了一個暫時取消梯度同步的context函數 no_sync()(源代碼)。在這個context下,DDP不會進行梯度同步。
for epoch in range(epoches):for j, data in enumerate(train_loader):# 前accumulation_steps - 1個step,不進行梯度同步,累積梯度。if accumulation_count % accumulation_steps != 0:with model.no_sync():loss = model(data)loss = loss/accumulation_stepsloss.backward()else:loss = model(data)loss = loss / accumulation_stepsloss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 1)model_optimizer.step()if model_scheduler is not None:model_scheduler.step()model_optimizer.zero_grad()accumulation_count += 1優雅的寫法(兼容單卡和DDP模式):
from contextlib import nullcontext # 如果python版本小于3.7,則使用下面這個: # from contextlib import suppress as nullcontextif local_rank != -1:model = DDP(model)optimizer.zero_grad() for epoch in range(epoches):for i, data in enumerate(train_loader):# 只在DDP模式下,輪數不是accumulation_steps整數倍的時候使用no_syncmcontext = model.no_sync if local_rank != -1 and accumulation_count % accumulation_steps != 0 else nullcontextwith mcontext():loss = model(data)loss = loss / accumulation_stepsloss.backward()# 輪數為accumulation_steps整數倍的時候,傳播梯度,并更新參數if accumulation_count % accumulation_steps == 0:optimizer.step()if model_scheduler is not None:model_scheduler.step()optimizer.zero_grad()accumulation_count += 1總結
以上是生活随笔為你收集整理的pytorch DDP加速之gradient accumulation设置的全部內容,希望文章能夠幫你解決所遇到的問題。
 
                            
                        - 上一篇: 使用PowerShell配置Micros
- 下一篇: WP10回滚WP8.1详细教程,变砖也可
