簡介
近年來,研究社群在自然語言處理、計算機視覺和其他領域的大型模型方面取得了許多成功。其中許多成功都得益於 Cloud TPU,它們是用於分散式訓練的強大硬體。為了支援 PyTorch 中的 TPU,PyTorch/XLA 庫為 XLA 裝置(最著名的是 TPU)提供了後端,併為在 TPU 上擴充套件大型 PyTorch 模型奠定了基礎。
然而,PyTorch 生態系統中現有的大多數模型擴充套件工具都假定使用 GPU(或 CPU)裝置,通常依賴於 CUDA 中的特定功能,並且不能直接在 TPU 上工作。缺乏擴充套件工具使得構建無法適應單個 TPU 晶片記憶體的大型模型變得具有挑戰性。
為了支援在 TPU 上進行模型擴充套件,我們作為 PyTorch/XLA 1.12 釋出的一部分,為 XLA 裝置實現了廣泛採用的 完全分片資料並行 (FSDP) 演算法。我們提供了一個 FSDP 介面,其高階設計與基於 CUDA 的 PyTorch FSDP 類相似,同時還處理了 XLA 中的一些限制(有關更多詳細資訊,請參閱下面的設計說明)。這個 FSDP 介面使我們能夠輕鬆地在 TPU 上構建具有 10B+ 引數的模型,並支援了許多研究探索。
在 PyTorch/XLA 中使用完全分片資料並行 (FSDP)
我們提供了一個包裝類 XlaFullyShardedDataParallel,用於給定 PyTorch 模型,以將其引數分片到資料並行工作器中。示例如下:
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
model = FSDP(my_module)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()
使用 XlaFullyShardedDataParallel 包裝 nn.Module 例項可在其上啟用 ZeRO-2 演算法,在該演算法中,其梯度和最佳化器狀態在整個訓練過程中都被分片。在其前向和後向傳播過程中,包裝模組的完整引數首先從其相應的分片中重建以進行計算。
可以使用巢狀 FSDP 包裝來進一步節省記憶體。這允許模型在任何給定時間僅儲存一個單獨層的完整引數。對於巢狀 FSDP,應首先使用內部 FSDP 包裝其單獨的子模組,然後再使用外部 FSDP 包裝基礎模型。這允許模型在任何給定時間僅儲存一個單獨層的完整引數。並且擁有一個外部包裝器可確保處理任何剩餘引數,對應於 ZeRO-3 演算法。巢狀 FSDP 包裝可以應用於子模組的任何深度,並且可以有超過 2 層的巢狀。
模型和最佳化器的模型檢查點儲存和載入可以像以前一樣透過儲存和載入它們的 .state_dict() 來完成。同時,每個訓練過程應儲存其自己的分片模型引數和最佳化器狀態的檢查點檔案,並在恢復時載入相應等級的檢查點檔案(無論 ZeRO-2 還是 ZeRO-3,即巢狀包裝與否)。提供了一個命令列工具和一個 Python 介面,用於將分片模型檢查點檔案合併為一個完整/未分片的模型檢查點檔案。
梯度檢查點(也稱為“啟用檢查點”或“重新例項化”)是另一種用於模型擴充套件的常見技術,可以與 FSDP 結合使用。我們提供了 checkpoint_module,這是一個包裝給定 nn.Module 例項以進行梯度檢查點(基於 torch_xla.utils.checkpoint.checkpoint)的函式。
下面的 MNIST 和 ImageNet 示例提供了(普通或巢狀)FSDP、模型檢查點的儲存和合並以及梯度檢查點的說明性用法。
PyTorch/XLA 中 FSDP 的起始示例
使用 FSDP 訓練 MNIST 和 ImageNet
MNIST 和 ImageNet 分類通常可以作為構建更復雜的深度學習模型的起點。我們提供了以下關於這兩個資料集的 FSDP 示例:
- MNIST:test/test_train_mp_mnist_fsdp_with_ckpt.py(它還說明了檢查點儲存和合並)
- ImageNet:test/test_train_mp_imagenet_fsdp.py
將它們與 MNIST 和 ImageNet 的普通資料並行示例進行比較,說明了如何調整訓練指令碼以使用 FSDP。需要記住的一個主要區別是,當在 FSDP 包裝模型上最佳化器進行步進時,應直接呼叫 optimizer.step() 而不是 xm.optimizer_step(optimizer)。後者會減少跨等級的梯度,這在 FSDP 中不是我們需要的,因為梯度已經減少並分片(來自其後向傳播中的 reduce-scatter 操作)。
安裝
FSDP 可在 PyTorch/XLA 1.12 及更新的每夜版本中獲得。請參閱 https://github.com/pytorch/xla#-available-images-and-wheels 以獲取安裝指南以及 Cloud TPU 分配。然後,在 TPU VM 上克隆 PyTorch/XLA 儲存庫,如下所示:
mkdir -p ~/pytorch && cd ~/pytorch
git clone --recursive https://github.com/pytorch/xla.git
cd ~/
在 v3-8 TPU 上訓練 MNIST
它在 2 個 epoch 中獲得了大約 98.9 的準確率
python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \
--batch_size 16 --drop_last --num_epochs 2 \
--use_nested_fsdp
上面的指令碼會在最後自動測試分片模型檢查點的合併。您也可以透過以下方式手動合併分片檢查點檔案:
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
--ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
--ckpt_suffix "_rank-*-of-*.pth"
在 v3-8 TPU 上使用 ResNet-50 訓練 ImageNet
它在 100 個 epoch 中獲得了大約 75.9 的準確率,與不使用 FSDP 獲得的結果相同;將 ImageNet-1k 資料集下載並預處理到 /datasets/imagenet-1k
python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
--datadir /datasets/imagenet-1k --drop_last \
--model resnet50 --test_set_batch_size 64 --eval_interval 10 \
--lr 0.4 --batch_size 128 --num_warmup_epochs 5 \
--lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 \
--num_epochs 100 \
--use_nested_fsdp
您還可以探索這兩個示例中的其他選項,例如 --use_gradient_checkpointing 以在 ResNet 塊上應用梯度檢查點(即啟用檢查點),或 --compute_dtype bfloat16 以 bfloat16 精度執行前向和後向傳播。
大型模型示例
在 TPU 上構建大型模型時,我們通常需要注意記憶體限制(例如,TPU v3 中每個核心 16 GB,TPU v4 中每個晶片 32 GB)。對於無法適應單個 TPU 記憶體或主機 CPU 記憶體的大型模型,應使用巢狀 FSDP 來實現 ZeRO-3 演算法,將子模組構建與內部 FSDP 包裝交錯,以便在構建過程中永遠不需要將完整模型儲存在記憶體中。
我們在 https://github.com/ronghanghu/ptxla_scaling_examples 中說明了這些情況,其中提供了在 TPU v3 Pod(帶 128 個核心)上訓練具有 10B+ 引數的 Vision Transformer (ViT) 模型以及其他情況的示例。
設計說明
有人可能想知道為什麼我們需要在 PyTorch/XLA 中開發一個單獨的 FSDP 類,而不是直接重用 PyTorch 的 FSDP 類 或將其擴充套件到 XLA 後端。在 PyTorch/XLA 中使用單獨 FSDP 類的主要動機是,原生 PyTorch 的 FSDP 類嚴重依賴 XLA 裝置不支援的 CUDA 功能,而 XLA 也具有一些需要特殊處理的獨特特性。這些差異需要 FSDP 的不同實現,在單獨的類中構建會容易得多。
API 呼叫更改
一個顯著的區別是,原生 PyTorch FSDP 基於獨立的 CUDA 流在 eager 模式下進行非同步執行,而 PyTorch/XLA 在 lazy 模式下執行,並且不支援流。此外,TPU 要求所有裝置同構地執行相同的程式。因此,在 PyTorch/XLA FSDP 實現中,CUDA 呼叫和每程序異構性需要被 XLA API 和替代的同構實現所取代。
張量儲存處理
另一個顯著的區別是如何釋放張量的儲存,這在 XLA 中比在 CUDA 中要困難得多。為了實現 ZeRO-3,需要在模組的前向傳播之後釋放完整引數的儲存,以便下一個模組可以重用此記憶體緩衝區進行後續計算。PyTorch 的 FSPD 透過 p.data.storage().resize_(0) 釋放參數 p 的實際儲存來實現這一點。然而,XLA 張量沒有這個 .storage() 控制代碼,因為 XLA HLO IR 是完全函式式的,不提供任何操作來釋放張量或調整其儲存大小。在 PyTorch 介面之下,只有 XLA 編譯器可以決定何時釋放與 XLA 張量對應的 TPU 裝置記憶體,並且一個先決條件是隻有當張量物件在 Python 中被釋放時才能釋放記憶體——這在 FSDP 中不可能發生,因為這些引數張量被引用為模組屬性,並且也被 PyTorch autograd 儲存用於反向傳播。
我們解決此問題的方法是將張量的值屬性與其自動梯度變數屬性分離,並透過將其 .data 屬性設定為大小為 1 的虛擬標量來釋放 nn.Parameter 張量。這樣,完整引數的實際資料張量在 Python 中被解除引用,以便 XLA 可以回收其記憶體用於其他計算,而自動梯度仍然可以將基礎 nn.Parameter 跟蹤為引數資料的弱引用。為了使其工作,還需要處理引數的檢視,因為 PyTorch 中的檢視也保留對其實際資料的引用(這需要在 PyTorch/XLA 中修復與檢視相關的形狀問題)。
與 XLA 編譯器協作
如果 XLA 編譯器忠實地保留了我們 PyTorch 程式中的操作及其執行順序,上述解決方案應該足以釋放完整引數。但還有一個問題——XLA 試圖透過對 HLO IR 應用公共子表示式消除 (CSE) 來最佳化程式以加速其執行。在 FSDP 的樸素實現中,XLA 編譯器通常會在反向傳播中消除第二個 all-gather 以重建完整引數,當它看到這是來自前向傳播的重複計算時,並直接保留和重用我們想要在前向傳播之後釋放的完整引數。為了防止這種不希望的編譯器行為,我們在 PyTorch/XLA 中引入了 最佳化屏障操作,並用它來阻止消除第二個 all-gather。此最佳化屏障也應用於梯度檢查點的類似情況,以防止前向和反向傳播之間的 CSE 消除重新例項化。
未來,如果 CUDA 和 XLA 之間的區別不再像上面提到的那樣突出,那麼將 PyTorch/XLA FSDP 與原生 PyTorch FSDP 合併以實現統一介面可能是值得考慮的。
致謝
感謝 AWS 的 Junmin Hao 審閱 PyTorch/XLA FSDP 拉取請求。感謝 Meta PyTorch 團隊的 Brian Hirsh 對 PyTorch 核心問題的支援。感謝 Google 的 Isaack Karanja、Will Cromar 和 Blake Hechtman 對 GCP、XLA 和 TPU 問題的支援。
感謝 Meta FAIR 的 Piotr Dollar、Wan-Yen Lo、Alex Berg、Ryan Mark、Kaiming He、Xinlei Chen、Saining Xie、Shoubhik Debnath、Min Xu 和 Vaibhav Aggarwal 進行的各種 TPU 相關討論。