訓練腳本¶
如果您的訓練腳本適用於 torch.distributed.launch,它將繼續適用於 torchrun,但有一些差異
- 無需手動傳遞 - RANK、- WORLD_SIZE、- MASTER_ADDR和- MASTER_PORT。
- 可以提供 - rdzv_backend和- rdzv_endpoint。對於大多數用戶,這將設定為- c10d(請參閱會合)。預設的- rdzv_backend會建立一個非彈性會合,其中- rdzv_endpoint保存主機地址。
- 請確保您的腳本中有 - load_checkpoint(path)和- save_checkpoint(path)邏輯。當任何數量的 worker 失敗時,我們會使用相同的程式參數重新啟動所有 worker,因此您將失去到最近檢查點的進度(請參閱彈性啟動)。
- use_env旗標已移除。如果您透過解析- --local-rank選項來解析本地端排名,則需要從環境變數- LOCAL_RANK中取得本地端排名(例如- int(os.environ["LOCAL_RANK"]))。
以下是一個訓練腳本的說明性範例,該腳本在每個 epoch 進行檢查點,因此在失敗時最壞情況下損失的進度是一個完整 epoch 的訓練。
def main():
     args = parse_args(sys.argv[1:])
     state = load_checkpoint(args.checkpoint_path)
     initialize(state)
     # torch.distributed.run ensures that this will work
     # by exporting all the env vars needed to initialize the process group
     torch.distributed.init_process_group(backend=args.backend)
     for i in range(state.epoch, state.total_num_epochs)
          for batch in iter(state.dataset)
              train(batch, state.model)
          state.epoch += 1
          save_checkpoint(state)
有關符合 torchelastic 的訓練腳本的具體範例,請造訪我們的範例頁面。