快捷方式

訓練指令碼

如果您的訓練指令碼適用於 torch.distributed.launch,那麼它在 torchrun 中將繼續可用,但存在以下差異

  1. 無需手動傳遞 RANKWORLD_SIZEMASTER_ADDRMASTER_PORT

  2. rdzv_backendrdzv_endpoint 可以提供。對於大多數使用者,這將設定為 c10d(參見 rendezvous)。預設的 rdzv_backend 會建立一個非彈性的 rendezvous,其中 rdzv_endpoint 儲存主地址。

  3. 請確保您的指令碼中包含 load_checkpoint(path)save_checkpoint(path) 邏輯。當任意數量的 worker 失敗時,我們將使用相同的程式引數重啟所有 worker,因此您將丟失到最近檢查點為止的進度(參見 elastic launch)。

  4. use_env 標誌已被移除。如果您之前透過解析 --local-rank 選項來解析本地 rank,您現在需要從環境變數 LOCAL_RANK 中獲取本地 rank(例如 int(os.environ["LOCAL_RANK"]))。

下面是一個說明性的訓練指令碼示例,該指令碼在每個 epoch 進行檢查點儲存,因此在發生故障時最壞情況下丟失的進度是一個完整的 epoch 的訓練量。

def main():
     args = parse_args(sys.argv[1:])
     state = load_checkpoint(args.checkpoint_path)
     initialize(state)

     # torch.distributed.run ensures that this will work
     # by exporting all the env vars needed to initialize the process group
     torch.distributed.init_process_group(backend=args.backend)

     for i in range(state.epoch, state.total_num_epochs)
          for batch in iter(state.dataset)
              train(batch, state.model)

          state.epoch += 1
          save_checkpoint(state)

有關符合 torchelastic 的訓練指令碼的具體示例,請訪問我們的示例頁面。

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲取問題解答

檢視資源