• 教程 >
  • 使用 torchrun 進行容錯分散式訓練
快捷方式

入門 || 什麼是 DDP || 單節點多 GPU 訓練 || 容錯 || 多節點訓練 || minGPT 訓練

使用 torchrun 進行容錯分散式訓練

建立日期:2022年9月27日 | 最後更新:2024年11月12日 | 最後驗證:2024年11月5日

作者:Suraj Subramanian

你將學到什麼
  • 使用 torchrun 啟動多 GPU 訓練任務

  • 儲存和載入訓練任務快照

  • 構建訓練指令碼以實現優雅重啟

GitHub 上檢視本教程使用的程式碼

前置條件
  • DDP 的高層級概覽

  • 熟悉 DDP 程式碼

  • 一臺配備多個 GPU 的機器(本教程使用 AWS p3.8xlarge 例項)

  • 安裝支援 CUDA 的 PyTorch

請觀看下方影片或在 YouTube 上觀看。

在分散式訓練中,單個程序故障可能會中斷整個訓練任務。由於這裡發生故障的可能性更高,因此使你的訓練指令碼具有魯棒性尤其重要。你可能還希望訓練任務具有彈性,例如,計算資源可以在任務執行期間動態加入和離開。

PyTorch 提供了一個名為 torchrun 的實用程式,它提供了容錯和彈性訓練功能。當發生故障時,torchrun 會記錄錯誤並嘗試從上次儲存的訓練任務“快照”自動重啟所有程序。

快照不僅儲存模型狀態,還可以包含已執行的 epoch 數量、最佳化器狀態或訓練任務連續性所需的任何其他有狀態屬性的詳細資訊。

為何使用 torchrun

torchrun 處理分散式訓練的細節,因此你無需關心這些。例如,

  • 你無需設定環境變數或顯式傳遞 rankworld_sizetorchrun 會自動分配這些以及其他幾個環境變數

  • 你無需在指令碼中呼叫 mp.spawn;你只需要一個通用的 main() 入口點,然後使用 torchrun 啟動指令碼。這樣,同一個指令碼可以在非分散式、單節點和多節點環境中執行。

  • 從上次儲存的訓練快照優雅地重啟訓練。

優雅重啟

為了實現優雅重啟,你應該這樣組織你的訓練指令碼:

def main():
  load_snapshot(snapshot_path)
  initialize()
  train()

def train():
  for batch in iter(dataset):
    train_step(batch)

    if should_checkpoint:
      save_snapshot(snapshot_path)

如果發生故障,torchrun 將終止所有程序並重新啟動它們。每個程序入口點首先載入並初始化上次儲存的快照,然後從那裡繼續訓練。因此,在任何故障發生時,你只會丟失上次儲存快照之後的訓練進度。

在彈性訓練中,無論何時發生成員變化(新增或移除節點),torchrun 都會終止並在可用裝置上生成程序。擁有這種結構可以確保你的訓練任務可以在無需手動干預的情況下繼續進行。

multigpu.pymultigpu_torchrun.py 的差異對比

程序組初始化

- def ddp_setup(rank, world_size):
+ def ddp_setup():
-     """
-     Args:
-         rank: Unique identifier of each process
-         world_size: Total number of processes
-     """
-     os.environ["MASTER_ADDR"] = "localhost"
-     os.environ["MASTER_PORT"] = "12355"
-     init_process_group(backend="nccl", rank=rank, world_size=world_size)
+     init_process_group(backend="nccl")
     torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

使用 torchrun 提供的環境變數

- self.gpu_id = gpu_id
+ self.gpu_id = int(os.environ["LOCAL_RANK"])

儲存和載入快照

定期將所有相關資訊儲存在快照中,可以使我們的訓練任務在中斷後無縫恢復。

+ def _save_snapshot(self, epoch):
+     snapshot = {}
+     snapshot["MODEL_STATE"] = self.model.module.state_dict()
+     snapshot["EPOCHS_RUN"] = epoch
+     torch.save(snapshot, "snapshot.pt")
+     print(f"Epoch {epoch} | Training snapshot saved at snapshot.pt")

+ def _load_snapshot(self, snapshot_path):
+     snapshot = torch.load(snapshot_path)
+     self.model.load_state_dict(snapshot["MODEL_STATE"])
+     self.epochs_run = snapshot["EPOCHS_RUN"]
+     print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

在 Trainer 建構函式中載入快照

當重啟中斷的訓練任務時,你的指令碼將首先嚐試載入快照以從中恢復訓練。

class Trainer:
   def __init__(self, snapshot_path, ...):
   ...
+  if os.path.exists(snapshot_path):
+     self._load_snapshot(snapshot_path)
   ...

恢復訓練

訓練可以從上次執行的 epoch 恢復,而不是從頭開始。

def train(self, max_epochs: int):
-  for epoch in range(max_epochs):
+  for epoch in range(self.epochs_run, max_epochs):
      self._run_epoch(epoch)

執行指令碼

就像執行非多程序指令碼一樣簡單地呼叫你的入口點函式;torchrun 會自動生成程序。

if __name__ == "__main__":
   import sys
   total_epochs = int(sys.argv[1])
   save_every = int(sys.argv[2])
-  world_size = torch.cuda.device_count()
-  mp.spawn(main, args=(world_size, total_epochs, save_every,), nprocs=world_size)
+  main(save_every, total_epochs)
- python multigpu.py 50 10
+ torchrun --standalone --nproc_per_node=4 multigpu_torchrun.py 50 10

為本教程評分

© 版權所有 2024, PyTorch.

使用 Sphinx 構建,主題由 Read the Docs 提供。

文件

訪問 PyTorch 全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲取問題解答

檢視資源