使用 PyTorch FSDP 微调 Llama 2 70B
引言
通過本文,你將了解如何使用 PyTorch FSDP 及相關(guān)最佳實踐微調(diào) Llama 2 70B。在此過程中,我們主要會用到 Hugging Face Transformers、Accelerate 和 TRL 庫。我們還將展示如何在 SLURM 中使用 Accelerate。
完全分片數(shù)據(jù)并行 (Fully Sharded Data Parallelism,F(xiàn)SDP) 是一種訓(xùn)練范式,在該范式中優(yōu)化器狀態(tài)、梯度和模型參數(shù)都會被跨設(shè)備分片。前向傳播時,每個 FSDP 單元執(zhí)行 all gather 以獲取完整的權(quán)重,然后用它們進(jìn)行計算并在計算后丟棄掉其他設(shè)備的分片。隨后是反向傳播,然后就是損失計算。反向傳播時,每個 FSDP 單元執(zhí)行 all gather 操作以獲取完整的權(quán)重,并執(zhí)行計算以獲得本地 batch 的梯度。這些梯度通過 reduce scatter 在設(shè)備上進(jìn)行均值計算并分片,這樣每個設(shè)備都可以更新其對應(yīng)分片的參數(shù)。有關(guān) PyTorch FSDP 的更多信息,請參閱此博文: 使用 PyTorch 完全分片數(shù)據(jù)并行技術(shù)加速大模型訓(xùn)練。
(圖源: 鏈接)
使用的硬件
節(jié)點數(shù): 2,至少 1 個節(jié)點
每節(jié)點 GPU 數(shù): 8
GPU 類型: A100
GPU 顯存: 80GB
節(jié)點內(nèi)互聯(lián): NVLink
每節(jié)點內(nèi)存: 1TB
每節(jié)點 CPU 核數(shù): 96
節(jié)點間互聯(lián): AWS 的 Elastic Fabric Adapter (EFA)
微調(diào) LLaMa 2 70B 面臨的挑戰(zhàn)
在嘗試使用 FSDP 微調(diào) LLaMa 2 70B 時,我們主要遇到了三個挑戰(zhàn):
- FSDP 會先加載整個預(yù)訓(xùn)練模型,然后再對模型進(jìn)行分片。這樣就意味著節(jié)點內(nèi)的每個進(jìn)程 (即 rank) 都會加載整個 Llama-70B 模型,因此需要 7048 GB ~ 2TB 的 CPU 內(nèi)存,這個算式中 4 是每個參數(shù)所需字節(jié)數(shù),8 是每個節(jié)點的 GPU 數(shù)。這會導(dǎo)致 CPU 內(nèi)存不足,進(jìn)而導(dǎo)致進(jìn)程終止。
- 使用
FULL_STATE_DICT來保存完整中間檢查點并將其卸載至 rank 0 的 CPU 內(nèi)存中需要花費大量時間,且由于在此期間通信庫需要無限期掛起等待保存完成,因此經(jīng)常會導(dǎo)致 NCCL 超時錯誤。然而,完全關(guān)掉這個選項也不好,因為在訓(xùn)練結(jié)束時我們需要保存完整的模型狀態(tài)字典,而不是 FSDP 式分片的狀態(tài)字典。 - 我們需要提高速度并減少顯存使用,以加快訓(xùn)練并節(jié)約計算成本。
下文,我們主要討論如何一一解決上述挑戰(zhàn),最終微調(diào)出一個 70B 的模型!
先列出重現(xiàn)結(jié)果所需的所有資源:
- 代碼庫: https://github.com/pacman100/DHS-LLM-Workshop/tree/main/chat_assistant/training,代碼中包含了使能 flash 注意力 V2 的熱補(bǔ)丁
- FSDP 配置文件: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/configs/fsdp_config.yaml
- SLURM 啟動腳本 -
launch.slurm: https://gist.github.com/pacman100/1cb1f17b2f1b3139a63b764263e70b25 - 模型:
meta-llama/Llama-2-70b-chat-hf - 數(shù)據(jù)集: smangrul/code-chat-assistant-v1 (混合了 LIMA 和 GUANACO 數(shù)據(jù)集,且已轉(zhuǎn)換為訓(xùn)練所需的格式)
準(zhǔn)備工作
首先按照 此步驟 安裝 Flash Attention V2。然后,安裝最新的 PyTorch nightly (CUDA ≥11.8)。接著,根據(jù) 此文件 安裝其余依賴軟件。在本文中,我們是從主分支安裝 ?? Accelerate 和 ?? Transformers 的。
微調(diào)
應(yīng)對挑戰(zhàn) 1
PR 25107 和 PR 1777 解決了第一個挑戰(zhàn),且無需用戶側(cè)更改任何代碼。主要做的事情如下:
- 在所有 rank 上創(chuàng)建無權(quán)重的空模型 (使用
meta設(shè)備) - 僅在 rank 0 上將狀態(tài)字典加載至模型
- 其他 rank 僅對
meta設(shè)備上的參數(shù)執(zhí)行torch.empty(*param.size(), dtype=dtype) - 因此,只有 rank 0 上加載了完整的模型及權(quán)重,而所有其他 rank 上的權(quán)重是空的
- 設(shè)置
sync_module_states=True,以便 FSDP 實例在訓(xùn)練開始之前將權(quán)重廣播到各 rank
下面是在 2 個 GPU 上加載 7B 模型的輸出日志片段,它測量了各個階段內(nèi)存的消耗及其加載的模型參數(shù)量。我們可以觀察到,在加載預(yù)訓(xùn)練模型時,rank 0 和 rank 1 的 CPU 峰值內(nèi)存分別為 32744 MB 和 1506 MB 。因此可知,僅有 rank 0 加載了預(yù)訓(xùn)練模型,這就實現(xiàn)了 CPU 內(nèi)存的有效利用。你可在 此處 找到完整日志。
accelerator.process_index=0 GPU Memory before entering the loading : 0
accelerator.process_index=0 GPU Memory consumed at the end of the loading (end-begin): 0
accelerator.process_index=0 GPU Peak Memory consumed during the loading (max-begin): 0
accelerator.process_index=0 GPU Total Peak Memory consumed during the loading (max): 0
accelerator.process_index=0 CPU Memory before entering the loading : 926
accelerator.process_index=0 CPU Memory consumed at the end of the loading (end-begin): 26415
accelerator.process_index=0 CPU Peak Memory consumed during the loading (max-begin): 31818
accelerator.process_index=0 CPU Total Peak Memory consumed during the loading (max): 32744
accelerator.process_index=1 GPU Memory before entering the loading : 0
accelerator.process_index=1 GPU Memory consumed at the end of the loading (end-begin): 0
accelerator.process_index=1 GPU Peak Memory consumed during the loading (max-begin): 0
accelerator.process_index=1 GPU Total Peak Memory consumed during the loading (max): 0
accelerator.process_index=1 CPU Memory before entering the loading : 933
accelerator.process_index=1 CPU Memory consumed at the end of the loading (end-begin): 10
accelerator.process_index=1 CPU Peak Memory consumed during the loading (max-begin): 573
accelerator.process_index=1 CPU Total Peak Memory consumed during the loading (max): 1506
應(yīng)對挑戰(zhàn) 2
該挑戰(zhàn)可以通過在配置 FSDP 時將狀態(tài)字典類型設(shè)為 SHARDED_STATE_DICT 來解決。設(shè)為 SHARDED_STATE_DICT 后,每個 rank 各自保存各自 GPU 所需要的分片,這使得用戶可以快速保存中間檢查點并快速從其恢復(fù)訓(xùn)練。而當(dāng)使用 FULL_STATE_DICT 時,第一個進(jìn)程 (rank 0) 會用 CPU 收集整個模型,然后將其保存為標(biāo)準(zhǔn)格式。
我們可以用以下命令創(chuàng)建相應(yīng)的 accelerte 配置文件:
accelerate config --config_file "fsdp_config.yaml"
你可以從此處獲取生成的配置文件: fsdp_config.yaml。在該配置文件中,分片策略是 FULL_SHARD 。我們使用 TRANSFORMER_BASED_WRAP 作為自動模型包裝策略,它使用 _no_split_module 來搜索 transformer 塊名并自動進(jìn)行嵌套 FSDP 包裝。我們使用 SHAARDED_STATE_DICT 把中間檢查點和優(yōu)化器狀態(tài)保存為 PyTorch 官方推薦的格式。同時,如上一節(jié) 應(yīng)對挑戰(zhàn) 1 中所述,我們還需要確保訓(xùn)練開始時用 rank 0 來廣播參數(shù)。從配置文件中你還可以看到我們用的是 bf16 混合精度訓(xùn)練。
那么,在保存最終檢查點時,如果將其保存成單個文件呢?我們使用的是以下代碼段:
if trainer.is_fsdp_enabled:
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
trainer.save_model(script_args.output_dir) # 或者 , 如果整個模型小于 50 GB (即 LFS 單文件的最大尺寸),你還可以使用 trainer.push_to_hub() 把模型推到 hub 上去。
應(yīng)對挑戰(zhàn) 3
為了加快訓(xùn)練速度并減少顯存占用,我們可以使用 flash 注意力并開啟梯度檢查點優(yōu)化,從而在微調(diào)的同時節(jié)省計算成本。當(dāng)前,我們用了一個熱補(bǔ)丁來實現(xiàn) flash 注意力,具體代碼可見 這兒。
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 一文基于對底層硬件 (即 GPU) 的內(nèi)存層次結(jié)構(gòu)的深刻理解而引入了一種更快、更節(jié)省內(nèi)存的無損注意力加速算法。底層硬件在設(shè)計內(nèi)存層次結(jié)構(gòu)時,遵循的實踐原則是: 帶寬/速度越高的內(nèi)存,其容量越小,因為它更貴。
根據(jù)博文 根據(jù)第一性原理讓深度學(xué)習(xí)性能起飛,我們可以發(fā)現(xiàn),當(dāng)前硬件上的注意力模塊是 內(nèi)存帶寬受限 的。原因是注意力機(jī)制 主要由逐元素操作 組成,如下左圖所示。我們可以觀察到,掩碼、softmax 和 dropout 操作占用了大部分時間,而非需要大量 FLOP 的矩陣乘法。
(圖源: 鏈接)
這正是 flash 注意力解決的問題,其想法是 去除冗余的 HBM 讀/寫操作。該算法通過將所有內(nèi)容保留在 SRAM 中,待執(zhí)行完所有中間步驟后再將最終結(jié)果寫回到 HBM,即 算子融合 來實現(xiàn)這一目的。下圖簡要描述了算子融合是如何克服內(nèi)存瓶頸的。
(圖源: 鏈接)
在前向和反向傳播過程中我們還使用了 平鋪 (Tiling) 優(yōu)化技巧,將 NxN 大小的 softmax 分?jǐn)?shù)計算切成塊,以克服 SRAM 內(nèi)存大小的限制。在使用平鋪技巧時,我們會使用在線 softmax 算法。同時,我們還在反向傳播中使用了 重計算 技巧,以大大降低在前向傳播過程中存儲整個 NxN softmax 分?jǐn)?shù)矩陣所帶來的內(nèi)存消耗。
如欲深入理解 flash 注意力,請參考博文 ELI5: FlashAttention、根據(jù)第一性原理讓深度學(xué)習(xí)性能起飛 以及原始論文 FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness。
綜合運用所有手段
你可參考 此腳本,以在 SLURM 中用 Accelerate 啟動器運行訓(xùn)練。下面還給出了一個等效命令,展示了如何使用 Accelerate 啟動器來運行訓(xùn)練。請注意,該命令會覆蓋 fsdp_config.yaml 中的 main_process_ip 、 main_process_port 、 machine_rank 、 num_processes 以及 num_machines 配置。另一個需要重點注意的是,這里的存儲是所有節(jié)點共享的。
accelerate launch \
--config_file configs/fsdp_config.yaml \
--main_process_ip $MASTER_ADDR \
--main_process_port $MASTER_PORT \
--machine_rank \$MACHINE_RANK \
--num_processes 16 \
--num_machines 2 \
train.py \
--model_name "meta-llama/Llama-2-70b-chat-hf" \
--dataset_name "smangrul/code-chat-assistant-v1" \
--max_seq_len 2048 \
--max_steps 500 \
--logging_steps 25 \
--eval_steps 100 \
--save_steps 250 \
--bf16 True \
--packing True \
--output_dir "/shared_storage/sourab/experiments/full-finetune-llama-chat-asst" \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--dataset_text_field "content" \
--use_gradient_checkpointing True \
--learning_rate 5e-5 \
--lr_scheduler_type "cosine" \
--weight_decay 0.01 \
--warmup_ratio 0.03 \
--use_flash_attn True
整個微調(diào)過程需要約 13.5 小時,下圖給出了訓(xùn)練損失曲線。
下例給出了使用上述模型完成的一段對話:
System Prompt: You are a helpful, respectful and honest assistant. Always answer as helpfully \
as possible, while being safe. Your answers should not include any harmful, \
unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that \
your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why \
instead of answering something not correct. If you don’t know the answer to a \
question, please don’t share false information.
- Human: What is Deep Learning? Explain like a Pirate.
+ Assistant: Arrr Grumete! Are ye lookin' fer a proper explanation o' Deep Learnin'?
+ Alright then, listen up me hearty! Deep Learnin' be a type o' machine learnin' that uses artificial neural networks to analyze data.
+ Aye, ye heard that right, mateys! Artificial neural networks be models o' the human brain, but instead o' using neurons, they use interconnected nodes or "artificial neurons" to process information.
+ These nodes work together to form layers o' processing units that can learn from large amounts o' data.
+ The more data these systems receive, the better they get at recognizing patterns and makin' decisions on their own.
+ So hoist the sails and set course for adventure with deep learnin', me hearties! It's a treasure trove o' knowledge just waitin' to be discovered.
+ Fair winds and following seas!
- Human: Now explain it like a chef.
+ Assistant: Certainly! Here's an explanation of deep learning from a culinary perspective:
+ Think of a recipe as a sequence of steps used to transform raw ingredients into a delicious dish.
+ Similarly, in deep learning, there are multiple layers of "ingredients" (or features) that are combined and transformed through various operations to produce a final output or prediction.
+ Each layer builds upon the previous one, refining and enhancing the representation of the input data until it reaches the desired outcome.
+ Just as different spices enhance the flavors of food, each layer adds depth and complexity to the model's understanding of the input, allowing it to make more accurate predictions.
+ And much like how chefs must carefully balance the amount of salt, sugar, and other seasonings in a recipe, the training process in deep learning involves adjusting the weights and biases of the network to achieve the perfect blend of accuracy and efficiency.
+ With enough practice and experimentation, anyone can become a master chef... I mean, a master of deep learning! Bon appétit!
整個對話使用的格式如下:
<|system|> system message <|endoftext|> <|prompter|> Q1 <|endoftext|> <|assistant|> A1 <|endoftext|> ...
總結(jié)
我們在多節(jié)點多 GPU 上使用 PyTorch FSDP 成功微調(diào)了一個 70B Llama 模型,并在此過程中解決了各種挑戰(zhàn)。我們看到了當(dāng)前在 ?? Transformers 和 ?? Accelerates 中應(yīng)如何初始化大模型從而有效克服 CPU 內(nèi)存不足的問題。我們還給出了如何高效地保存/加載中間檢查點,同時又能以易于使用的方式保存最終模型的最佳實踐。為了加速訓(xùn)練并減少 GPU 顯存使用,我們還強(qiáng)調(diào)了 flash 注意力和梯度檢查點機(jī)制的重要性。最后,我們向大家展示了在 ?? Accelerate 上僅需要簡單的配置就可以在多節(jié)點多 GPU 上微調(diào)大模型。
英文原文: https://hf.co/blog/ram-efficient-pytorch-fsdp
原文作者: Sourab Mangrulkar,Sylvain Gugger,Lewis Tunstall,Philipp Schmid
譯者: Matrix Yao (姚偉峰),英特爾深度學(xué)習(xí)工程師,工作方向為 transformer-family 模型在各模態(tài)數(shù)據(jù)上的應(yīng)用及大規(guī)模模型的訓(xùn)練推理。
FSDP MFU (Model FLOPS Utilization) 相關(guān)討論: https://github.com/huggingface/blog/issues/1649
總結(jié)
以上是生活随笔為你收集整理的使用 PyTorch FSDP 微调 Llama 2 70B的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 自己的花是让别人看的教案一等奖
- 下一篇: dnf二级密码有什么用(地下城与勇士)