• 文件 >
  • 通用合併上下文管理器
捷徑

通用合併上下文管理器

通用合併上下文管理器促進了不均勻輸入上的分散式訓練。此頁面概述了相關類別的 API:JoinJoinableJoinHook。如需教學課程,請參閱使用合併上下文管理器進行不均勻輸入的分散式訓練

class torch.distributed.algorithms.Join(joinables, enable=True, throw_on_early_termination=False, **kwargs)[原始碼]

此類別定義了通用合併上下文管理器,它允許在程序合併後呼叫自訂掛鉤。

這些掛鉤應該遮蔽未合併程序的集體通訊,以防止掛起和錯誤,並確保演算法的正確性。如需有關掛鉤定義的詳細資訊,請參閱JoinHook

警告

上下文管理器要求每個參與的Joinable 在其每次迭代的集體通訊之前呼叫方法notify_join_context(),以確保正確性。

警告

上下文管理器要求JoinHook 物件中的所有 process_group 屬性都相同。如果有多個 JoinHook 物件,則使用第一個物件的 device。程序群組和裝置資訊用於檢查未合併的程序,以及在啟用 throw_on_early_termination 的情況下通知程序擲回例外,兩者都使用全簡化。

參數
  • joinables (List[Joinable]) – 參與的 Joinable 的清單;它們的掛鉤按照給定的順序進行迭代。

  • enable (bool) – 啟用不均勻輸入偵測的旗標;設定為 False 會停用上下文管理器的功能,並且只有在使用者知道輸入不會不均勻時才應設定(預設值:True)。

  • throw_on_early_termination (bool) – 控制在偵測到不均勻輸入時是否擲回例外的旗標(預設值:False)。

範例

>>> import os
>>> import torch
>>> import torch.distributed as dist
>>> import torch.multiprocessing as mp
>>> import torch.nn.parallel.DistributedDataParallel as DDP
>>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO
>>> from torch.distributed.algorithms.join import Join
>>>
>>> # On each spawned worker
>>> def worker(rank):
>>>     dist.init_process_group("nccl", rank=rank, world_size=2)
>>>     model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
>>>     optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
>>>     # Rank 1 gets one more input than rank 0
>>>     inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
>>>     with Join([model, optim]):
>>>         for input in inputs:
>>>             loss = model(input).sum()
>>>             loss.backward()
>>>             optim.step()
>>>     # All ranks reach here without hanging/erroring
static notify_join_context(joinable)[原始碼]

通知合併上下文管理器呼叫程序尚未合併。

然後,如果 throw_on_early_termination=True,則檢查是否已偵測到不均勻輸入(即,如果一個程序已合併),如果是,則擲回例外。

此方法應該從Joinable 物件的每次迭代的集體通訊之前呼叫。例如,這應該在 DistributedDataParallel 中的前向傳遞開始時呼叫。

在這個方法中,只有傳遞給上下文管理器的第一個 Joinable 對象會執行集體通訊,而對於其他對象,這個方法是無效的。

參數

joinable (Joinable) – 調用此方法的 Joinable 對象。

返回

如果 joinable 是第一個傳遞給上下文管理器的,則返回一個用於 all-reduce 的異步工作句柄,用於通知上下文管理器進程尚未加入;否則返回 None

class torch.distributed.algorithms.Joinable[source]

這定義了一個用於可加入類別的抽象基類別。

一個可加入類別(繼承自 Joinable)應該實作 join_hook(),它會返回一個 JoinHook 實例,以及返回設備和進程組信息的 join_device()join_process_group()

abstract property join_device: device

返回用於執行加入上下文管理器所需的集體通訊的設備。

abstract join_hook(**kwargs)[source]

為給定的 Joinable 返回一個 JoinHook 實例。

參數

kwargs (dict) – 一個包含任何關鍵字參數的 dict,用於在運行時修改加入鉤子的行為;所有共享同一個加入上下文管理器的 Joinable 實例都會收到相同的 kwargs 值。

返回類型

JoinHook

abstract property join_process_group: Any

返回加入上下文管理器本身所需的集體通訊的進程組。

class torch.distributed.algorithms.JoinHook[source]

這定義了一個加入鉤子,它在加入上下文管理器中提供了兩個入口點。

入口點:一個主鉤子,在存在未加入進程時重複調用,以及一個後置鉤子,在所有進程都加入後調用一次。

要為通用加入上下文管理器實作加入鉤子,請定義一個繼承自 JoinHook 的類別,並根據需要覆蓋 main_hook()post_hook()

main_hook()[source]

在存在未加入進程時調用此鉤子,以在訓練迭代中模擬集體通訊。

訓練迭代,即一次正向傳遞、反向傳遞和優化器步驟。

post_hook(is_last_joiner)[source]

在所有進程都加入後調用鉤子。

它會傳遞一個額外的 bool 參數 is_last_joiner,指示排名是否為最後加入的排名之一。

參數

is_last_joiner (bool) – 如果排名是最後加入的排名之一,則為 True;否則為 False

文件

訪問 PyTorch 的完整開發者文檔

查看文檔

教程

獲取針對初學者和高級開發者的深入教程

查看教程

資源

查找開發資源並獲得問題解答

查看資源