• 文件 >
  • 通用 Join 上下文管理器
快捷方式

通用 Join 上下文管理器

通用 join 上下文管理器有助於在不均勻輸入上進行分散式訓練。本頁概述了相關類的 API:JoinJoinableJoinHook。有關教程,請參閱 使用 Join 上下文管理器進行不均勻輸入的分散式訓練

class torch.distributed.algorithms.Join(joinables, enable=True, throw_on_early_termination=False, **kwargs)[原始碼][原始碼]

此類定義了通用 join 上下文管理器,允許在程序加入後呼叫自定義鉤子。

這些鉤子應該模擬未加入程序的集合通訊,以防止掛起和錯誤,並確保演算法正確性。有關鉤子定義的詳細資訊,請參閱 JoinHook

警告

上下文管理器要求每個參與的 Joinable 在其自身的每次迭代集合通訊之前呼叫方法 notify_join_context(),以確保正確性。

警告

上下文管理器要求所有 JoinHook 物件中的 process_group 屬性都相同。如果存在多個 JoinHook 物件,則使用第一個的 device。程序組和裝置資訊用於檢查未加入的程序,並在啟用 throw_on_early_termination 時通知程序丟擲異常,這兩者都使用 all-reduce。

引數
  • joinables (List[Joinable]) – 參與的 Joinable 列表;其鉤子按給定順序進行迭代。

  • enable (bool) – 一個標誌,用於啟用不均勻輸入檢測;將其設定為 False 會停用上下文管理器的功能,僅當用戶知道輸入不會不均勻時才應設定(預設值:True)。

  • throw_on_early_termination (bool) – 一個標誌,控制在檢測到不均勻輸入時是否丟擲異常(預設值:False)。

示例

>>> import os
>>> import torch
>>> import torch.distributed as dist
>>> import torch.multiprocessing as mp
>>> import torch.nn.parallel.DistributedDataParallel as DDP
>>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO
>>> from torch.distributed.algorithms.join import Join
>>>
>>> # On each spawned worker
>>> def worker(rank):
>>>     dist.init_process_group("nccl", rank=rank, world_size=2)
>>>     model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
>>>     optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
>>>     # Rank 1 gets one more input than rank 0
>>>     inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
>>>     with Join([model, optim]):
>>>         for input in inputs:
>>>             loss = model(input).sum()
>>>             loss.backward()
>>>             optim.step()
>>>     # All ranks reach here without hanging/erroring
static notify_join_context(joinable)[原始碼][原始碼]

通知 join 上下文管理器呼叫程序尚未加入。

然後,如果 throw_on_early_termination=True,則檢查是否已檢測到不均勻輸入(即是否有程序已加入),如果檢測到則丟擲異常。

此方法應在 Joinable 物件的每次迭代集合通訊之前呼叫。例如,應在 DistributedDataParallel 的前向傳播開始時呼叫此方法。

只有傳遞給上下文管理器的第一個 Joinable 物件在此方法中執行集合通訊,對於其他物件,此方法是空操作。

引數

joinable (Joinable) – 呼叫此方法的 Joinable 物件。

返回值

如果 joinable 是傳遞給上下文管理器的第一個物件,則返回用於通知上下文管理器程序尚未加入的 all-reduce 的非同步工作控制代碼;否則返回 None

class torch.distributed.algorithms.Joinable[原始碼][原始碼]

這定義了一個可加入類的抽象基類。

可加入類(繼承自 Joinable)除了應實現返回裝置和程序組資訊的 join_device()join_process_group() 方法外,還應實現返回 JoinHook 例項的 join_hook() 方法。

abstract property join_device: device

返回執行 join 上下文管理器所需的集合通訊的裝置。

abstract join_hook(**kwargs)[原始碼][原始碼]

為給定的 Joinable 返回一個 JoinHook 例項。

引數

kwargs (dict) – 一個 dict,包含用於在執行時修改 join 鉤子行為的任意關鍵字引數;所有共享同一 join 上下文管理器的 Joinable 例項都將收到相同的 kwargs 值。

返回型別

JoinHook

abstract property join_process_group: Any

返回 join 上下文管理器自身所需的集合通訊的程序組。

class torch.distributed.algorithms.JoinHook[原始碼][原始碼]

這定義了一個 join 鉤子,它在 join 上下文管理器中提供了兩個入口點。

入口點:一個主鉤子 (main hook),當存在未加入的程序時會重複呼叫;以及一個後鉤子 (post-hook),在所有程序都加入後呼叫一次。

要為通用 join 上下文管理器實現 join 鉤子,請定義一個繼承自 JoinHook 的類,並根據需要重寫 main_hook()post_hook() 方法。

main_hook()[原始碼][原始碼]

當存在未加入的程序時呼叫此鉤子,以模擬訓練迭代中的集合通訊。

訓練迭代,即一次前向傳播、一次後向傳播和一次最佳化器步。

post_hook(is_last_joiner)[原始碼][原始碼]

在所有程序加入後呼叫鉤子。

它會接收一個額外的 bool 引數 is_last_joiner,指示該 rank 是否是最後加入的之一。

引數

is_last_joiner (bool) – 如果該 rank 是最後加入的之一,則為 True;否則為 False

文件

檢視 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源