通用合併上下文管理器¶
通用合併上下文管理器促進了不均勻輸入上的分散式訓練。此頁面概述了相關類別的 API:Join、Joinable 和 JoinHook。如需教學課程,請參閱使用合併上下文管理器進行不均勻輸入的分散式訓練。
- class torch.distributed.algorithms.Join(joinables, enable=True, throw_on_early_termination=False, **kwargs)[原始碼]¶
- 此類別定義了通用合併上下文管理器,它允許在程序合併後呼叫自訂掛鉤。 - 這些掛鉤應該遮蔽未合併程序的集體通訊,以防止掛起和錯誤,並確保演算法的正確性。如需有關掛鉤定義的詳細資訊,請參閱 - JoinHook。- 警告 - 上下文管理器要求每個參與的 - Joinable在其每次迭代的集體通訊之前呼叫方法- notify_join_context(),以確保正確性。- 警告 - 上下文管理器要求 - JoinHook物件中的所有- process_group屬性都相同。如果有多個- JoinHook物件,則使用第一個物件的- device。程序群組和裝置資訊用於檢查未合併的程序,以及在啟用- throw_on_early_termination的情況下通知程序擲回例外,兩者都使用全簡化。- 參數
 - 範例 - >>> import os >>> import torch >>> import torch.distributed as dist >>> import torch.multiprocessing as mp >>> import torch.nn.parallel.DistributedDataParallel as DDP >>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO >>> from torch.distributed.algorithms.join import Join >>> >>> # On each spawned worker >>> def worker(rank): >>> dist.init_process_group("nccl", rank=rank, world_size=2) >>> model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank]) >>> optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01) >>> # Rank 1 gets one more input than rank 0 >>> inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)] >>> with Join([model, optim]): >>> for input in inputs: >>> loss = model(input).sum() >>> loss.backward() >>> optim.step() >>> # All ranks reach here without hanging/erroring 
- class torch.distributed.algorithms.Joinable[source]¶
- 這定義了一個用於可加入類別的抽象基類別。 - 一個可加入類別(繼承自 - Joinable)應該實作- join_hook(),它會返回一個- JoinHook實例,以及返回設備和進程組信息的- join_device()和- join_process_group()。