torch.distributed.fsdp.fully_shard¶
PyTorch FSDP2 (fully_shard)¶
PyTorch FSDP2 提供全分片資料並行(FSDP)實現,以提升 performant eager-mode 的效能,同時透過按引數分片提高易用性。
如果您是 FSDP 新手,我們建議您從 FSDP2 開始,因為它具有更高的易用性。
如果您當前使用 FSDP1,請評估以下差異,看是否應切換到 FSDP2
與 PyTorch FSDP1 (FullyShardedDataParallel) 相比
FSDP2 使用基於
DTensor的 dim-0 按引數分片,與 FSDP1 的平面引數分片相比,分片表示更簡單,同時保持相似的吞吐量效能。更具體地說,FSDP2 透過torch.chunk(dim=0)在 dim-0 上將每個引數分塊到資料並行工作節點上,而 FSDP1 則將一組張量扁平化、連線並分塊在一起,使得理解每個工作節點上存在哪些資料以及重新分片到不同的並行模式變得複雜。按引數分片提供了更直觀的使用者體驗,放寬了對凍結引數的限制,並允許使用無需通訊的(分片)狀態字典,這在 FSDP1 中則需要 all-gather 操作。FSDP2 實現了不同的記憶體管理方法來處理多流使用,避免了
torch.Tensor.record_stream。這確保了確定性和預期的記憶體使用,並且不像 FSDP1 的limit_all_gathers=True那樣需要阻塞 CPU。FSDP2 提供了 API,允許手動控制預取和集合操作排程,為高階使用者提供更多定製選項。有關詳細資訊,請參閱下面的
FSDPModule方法。FSDP2 簡化了一些 API 表面:例如,FSDP2 不直接支援完整狀態字典。使用者可以使用
DTensorAPI(如DTensor.full_tensor())或使用更高級別的 API(如 PyTorch Distributed Checkpoint 的分散式狀態字典 API)自行將包含DTensor的分片狀態字典重新分片為完整狀態字典。此外,還移除了一些其他引數;有關詳細資訊,請參閱此處。
如果您是首次接觸 FSDP 或以上任何一點符合您的用例,我們建議您考慮使用 FSDP2。
有關係統設計和實現的詳細資訊,請參閱此 RFC。
注意
torch.distributed.fsdp.fully_shard 目前處於原型階段並正在開發中。核心 API 可能不會改變,但如有必要,我們可能會進行一些 API 更改。
前端 API 是 fully_shard,可以在 module 上呼叫
- torch.distributed.fsdp.fully_shard(module, *, mesh=None, reshard_after_forward=True, shard_placement_fn=None, mp_policy=MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True), offload_policy=OffloadPolicy(), ignored_params=None)[原始碼]¶
將全分片資料並行 (FSDP) 應用於
module,其中 FSDP 在資料並行工作節點之間分片模組引數、梯度和最佳化器狀態,以犧牲通訊開銷來節省記憶體。初始化時,FSDP 會根據
mesh在資料並行工作節點之間分片模組引數。在前向計算之前,FSDP 會在資料並行工作節點之間 all-gather 分片引數以獲取未分片引數用於前向計算。如果reshard_after_forward為True,則 FSDP 會在前向計算後釋放未分片引數,並在後向計算梯度之前重新 all-gather 它們。梯度計算完成後,FSDP 會釋放未分片引數,並在資料並行工作節點之間 reduce-scatter 未分片梯度。此實現將分片引數表示為在 dim-0 上分片的
DTensor,而未分片引數將與module上的原始引數類似(例如,如果原始引數是torch.Tensor,則仍是torch.Tensor)。模組 forward pre-hook 在module上 all-gather 引數,模組 forward hook 在module上釋放它們(如果需要)。類似的 backward hook 會 all-gather 引數,然後釋放參數並 reduce-scatter 梯度。由於將多個張量組合在一起進行一次集合操作對於通訊效率至關重要,此實現將這種分組視為首要功能。在
module上呼叫fully_shard()會構建一個組,該組包含module.parameters()中的引數,但那些已在子模組的早期呼叫中分配給其他組的引數除外。這意味著fully_shard()應該在模型上自底向上呼叫。每個組的引數在一次集合操作中 all-gather,其梯度在一次集合操作中 reduce-scatter。將模型劃分為多個組(“逐層”)可以實現峰值記憶體節省和通訊/計算重疊。使用者通常不應該只在最頂層的根模組上呼叫fully_shard()。- 引數
module (Union[nn.Module, List[nn.Module]) – 要使用 FSDP 分片並分組進行通訊的模組或模組列表。
mesh (Optional[DeviceMesh]) – 此資料並行網格定義了分片和裝置。如果為 1D,則引數在 1D 網格(FSDP)上使用
(Shard(0),)放置進行全分片。如果為 2D,則引數在第 1 維上分片,並在第 0 維上覆制(HSDP),使用(Replicate(), Shard(0))放置。網格的裝置型別指定了用於通訊的裝置型別;如果是 CUDA 或類似 CUDA 的裝置型別,則使用當前裝置。reshard_after_forward (Union[bool, int]) –
控制前向計算後的引數行為,可權衡記憶體和通訊開銷:
如果為
True,則在前向計算後重新分片引數,並在後向計算中重新 all-gather。如果為
False,則在前向計算後在記憶體中保留未分片引數,並避免後向計算中的 all-gather。如果為
int,則表示前向計算後重新分片到的世界大小。它應該是網格分片維度大小的一個非平凡因子(即排除 1 和維度大小本身)。一個選擇可以是節點內大小(例如torch.cuda.device_count())。這允許後向計算中的 all-gather 在較小的世界大小上進行,但代價是記憶體使用高於設定為True的情況。根 FSDP 狀態的值被特別設定為
False作為啟發式處理,因為其引數通常會立即進行 all-gather 以用於後向計算。前向計算後,註冊到模組的引數取決於此設定:如果為
True,則註冊的引數是分片引數;如果為False,則為未分片引數;否則為重新分片到較小網格的引數。要在前向和後向計算之間修改引數,註冊的引數必須是分片引數。對於False或int,可以透過手動呼叫reshard()進行重新分片。
shard_placement_fn (Optional[Callable[[nn.Parameter], Optional[Shard]]]) – 此可呼叫物件可用於覆蓋引數的分片放置,以便在 dim-0 以外的維度上分片引數。如果此可呼叫物件返回
Shard放置(而非None),則 FSDP 將根據該放置進行分片(例如Shard(1))。如果在非零維度上分片,我們當前要求均勻分片,即該維度上的張量維度大小必須能被 FSDP 分片網格大小整除。mp_policy (MixedPrecisionPolicy) – 控制混合精度策略,為此模組提供引數/歸約混合精度。有關詳細資訊,請參閱
MixedPrecisionPolicy。offload_policy (OffloadPolicy) – 控制解除安裝策略,為此模組提供引數/梯度/最佳化器狀態解除安裝。有關詳細資訊,請參閱
OffloadPolicy及其子類。ignored_params (Optional[set[nn.Parameter]]) – 可選(Set[nn.Parameter]):不希望使用 FSDP 進行分片的引數集合。
- 返回值
應用了 FSDP 的模組(就地修改)。
- 返回型別
呼叫 fully_shard(module) 會動態構建一個新類,該類是 type(module) 和 FSDP 類 FSDPModule 的子類。例如,如果我們在模組 linear: nn.Linear 上呼叫 fully_shard(linear),則 FSDP 會構建一個新類 FSDPLinear 並將 linear 的型別更改為此新類。否則,fully_shard 不會改變模組結構和引數的完全限定名稱。FSDPModule 類允許在模組上提供一些 FSDP 特有的方法。
- class torch.distributed.fsdp.FSDPModule(*args, **kwargs)¶
-
- set_all_reduce_hook(hook, *, stream=None)[原始碼][原始碼]¶
- 引數
hook (Callable[[torch.Tensor], None]) – 使用者定義的 all-reduce 鉤子,期望的簽名是
hook(reduce_output: torch.Tensor) -> None,其中reduce_output是僅使用 FSDP 時的 reduce-scatter 輸出,或使用原生 HSDP 時的 all-reduce 輸出。stream (Optional[torch.cuda.Stream]) – 執行 all-reduce 鉤子的流。僅在不使用原生 HSDP 時應設定此引數。如果使用原生 HSDP,鉤子將在原生 HSDP all-reduce 使用的內部定義的 all-reduce 流中執行。
- set_is_last_backward(is_last_backward)[原始碼][原始碼]¶
設定下一個後向計算是否是最後一個。在最後一個後向計算中,FSDP 會等待未完成的梯度歸約,並清除用於後向預取的內部資料結構。這對於微批處理非常有用。
- set_modules_to_backward_prefetch(modules)[原始碼][原始碼]¶
設定此 FSDP 模組應在後向計算中顯式預取 all-gather 的 FSDP 模組。這會覆蓋根據逆前向計算後順序預取下一個 FSDP 模組的預設後向預取實現。
傳遞包含前一個 FSDP 模組的單元素列表,會獲得與預設重疊行為相同的 all-gather 重疊行為。傳遞至少包含兩個元素的列表,可以實現更積極的重疊,並且會使用更多預留記憶體。
- 引數
modules (List[FSDPModule]) – 要預取的 FSDP 模組列表。
- set_modules_to_forward_prefetch(modules)[原始碼][原始碼]¶
設定此 FSDP 模組應在前向計算中顯式預取 all-gather 的 FSDP 模組。預取在此模組的 all-gather copy-out 後執行。
傳遞包含下一個 FSDP 模組的單元素列表,會獲得與預設重疊行為相同的 all-gather 重疊行為,但預取的 all-gather 會從 CPU 更早發出。傳遞至少包含兩個元素的列表,可以實現更積極的重疊,並且會使用更多預留記憶體。
- 引數
modules (List[FSDPModule]) – 要預取的 FSDP 模組列表。
- set_post_optim_event(event)[原始碼][原始碼]¶
為根 FSDP 模組設定一個最佳化器步驟後事件,以便等待 all-gather 流。
預設情況下,根 FSDP 模組會在當前流上等待 all-gather 流,以確保最佳化器步驟在 all-gather 之前完成。然而,如果在最佳化器步驟後有不相關的計算,這可能會引入錯誤依賴。此 API 允許使用者提供自己的事件來等待。根模組等待事件後,事件將被丟棄,因此每次迭代都應使用新事件呼叫此 API。
- 引數
event (torch.Event) – 在最佳化器步驟之後記錄的事件,用於等待全域性收集流。
- set_reduce_scatter_divide_factor(factor)[source][source]¶
設定 reduce-scatter 的自定義除數因子。這會成為使用 NCCL 的 PreMulSum 的自定義規約操作,允許在規約前乘以該因子。
- 引數
factor (浮點型) – 自定義除數因子。
- set_requires_all_reduce(requires_all_reduce, *, recurse=True)[source][source]¶
設定模組是否應該對梯度進行 all-reduce(全域性規約)。這可用於實現僅使用 reduce-scatter 而不使用 all-reduce 的 HSDP 梯度累積。
- set_requires_gradient_sync(requires_gradient_sync, *, recurse=True)[source][source]¶
設定模組是否應該同步梯度。這可用於實現*不進行通訊*的梯度累積。對於 HSDP,這同時控制 reduce-scatter 和 all-reduce。這相當於 FSDP1 中的 no_sync。
- set_reshard_after_backward(reshard_after_backward, *, recurse=True)[source][source]¶
設定模組是否應該在反向傳播後重新分片引數。這可在梯度累積期間使用,以權衡更高的記憶體消耗來減少通訊,因為未分片的引數在下一次前向傳播前無需重新全域性收集。
- set_unshard_in_backward(unshard_in_backward)[source][source]¶
設定 FSDP 模組的引數在反向傳播中是否需要取消分片。這可在專家場景中使用,當用戶知道此 FSDP 模組引數組中的所有引數都不需要用於反向計算時(例如,嵌入層)。
- unshard(async_op=False)[source][source]¶
透過分配記憶體和全域性收集引數來取消模組的引數分片。此方法*不是*遞迴的。取消分片遵循
MixedPrecisionPolicy,因此如果設定了param_dtype,它將根據param_dtype進行全域性收集。- 引數
async_op (布林型別) – 如果為
True,則返回一個UnshardHandle,它具有wait()方法來等待取消分片操作。如果為False,則返回None並在函式內部等待 handle。- 返回型別
注意
如果
async_op=True,則 FSDP 將在模組的前向傳播前(pre-forward)為使用者等待待處理的取消分片操作。使用者僅需在等待必須發生在前向傳播前(pre-forward)時才顯式呼叫wait()。
- class torch.distributed.fsdp.UnshardHandle¶
用於等待
FSDPModule.unshard()操作的 handle。
- torch.distributed.fsdp.register_fsdp_forward_method(module, method_name)[source]¶
在
module上註冊一個方法,使其被視為 FSDP 的前向方法。FSDP 在前向傳播前(pre-forward)全域性收集引數,並可選地在後向傳播後(post-forward)釋放參數(取決於
reshard_after_forward)。預設情況下,FSDP 只知道對nn.Module.forward()執行此操作。此函式修補一個使用者指定的方法,使其分別在該方法之前/之後執行前向傳播前/後(pre/post-forward)鉤子。如果module不是FSDPModule,則此操作無效(no-op)。
- class torch.distributed.fsdp.MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True)¶
這配置了 FSDP 的混合精度。與 autocast 不同,它在模組級別而不是操作級別應用混合精度,這意味著低精度啟用會為反向傳播保留,並且高精度到低精度的轉換僅發生在模組邊界。
FSDP 與模組級混合精度配合良好,因為它無論如何都會在記憶體中保留高精度的分片引數。換句話說,FSDP 不需要額外的記憶體來為最佳化器步驟保留引數的高精度副本。
- 變數
param_dtype (Optional[torch.dtype]) – 這指定了未分片引數的 dtype,因此也指定了前向/反向計算以及引數全域性收集的 dtype。如果為
None,則未分片引數使用原始 dtype。最佳化器步驟使用原始 dtype 中的分片引數。(預設值:None)reduce_dtype (Optional[torch.dtype]) – 這指定了梯度規約的 dtype(即 reduce-scatter 或 all-reduce)。如果為
None但param_dtype不為None,則規約使用計算 dtype。這可用於在全精度下執行梯度規約,同時使用低精度進行計算。如果透過set_requires_gradient_sync()也停用了梯度規約,則 FSDP 將使用reduce_dtype累積梯度。(預設值:None)output_dtype (Optional[torch.dtype]) – 這指定了浮點前向輸出的轉換 dtype。這可用於幫助實現不同模組具有不同混合精度策略的場景。(預設值:
None)cast_forward_inputs (布林型別) – 這指定 FSDP 是否應將前向的浮點輸入張量轉換為
param_dtype。
- class torch.distributed.fsdp.OffloadPolicy¶
此基類表示不進行解除安裝的策略,僅用作
offload_policy引數的預設值。
- class torch.distributed.fsdp.CPUOffloadPolicy(pin_memory=True)¶
此解除安裝策略將引數、梯度和最佳化器狀態解除安裝到 CPU。分片引數在全域性收集前從主機複製到裝置。全域性收集的引數根據
reshard_after_forward釋放。分片梯度在反向傳播中從裝置複製到主機,最佳化器步驟在 CPU 上使用 CPU 最佳化器狀態執行。- 變數
pin_memory (布林型別) – 是否鎖定分片引數和梯度記憶體。鎖定記憶體可以提高 H2D/D2H 複製效率,並使複製與計算重疊。但是,其他程序無法使用鎖定記憶體。如果 CPU 記憶體不足,請將此項設定為
False。(預設值:True)