• 文件 >
  • FullyShardedDataParallel
捷徑

FullyShardedDataParallel

class torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=True, use_orig_params=False, ignored_states=None, device_mesh=None)[source]

用於在資料平行工作節點之間分片模組參數的包裝器。

此靈感來自於 徐等人 以及 DeepSpeed 中的 ZeRO 階段 3。FullyShardedDataParallel 通常縮寫為 FSDP。

若要瞭解 FSDP 內部結構,請參閱 FSDP 注意事項

範例

>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> torch.cuda.set_device(device_id)
>>> sharded_module = FSDP(my_module)
>>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
>>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
>>> loss = x.sum()
>>> loss.backward()
>>> optim.step()

使用 FSDP 包含包裝您的模組,然後初始化您的優化器。這是必需的,因為 FSDP 會變更參數變數。

設定 FSDP 時,您需要考慮目標 CUDA 裝置。如果裝置具有 ID (dev_id),則您有三種選擇

  • 將模組放置在該裝置上

  • 使用 torch.cuda.set_device(dev_id) 設定裝置

  • dev_id 傳遞至 device_id 建構函式引數。

這可確保 FSDP 執行個體的計算裝置為目標裝置。對於選項 1 和 3,FSDP 初始化一律在 GPU 上進行。對於選項 2,FSDP 初始化會在模組目前的裝置上進行,這可能是 CPU。

如果您使用的是 sync_module_states=True 旗標,則需要確保模組位於 GPU 上,或使用 device_id 引數指定 FSDP 建構函式中 FSDP 將移動模組至的 CUDA 裝置。這是必要的,因為 sync_module_states=True 需要 GPU 通訊。

FSDP 也會將輸入張量移動到轉發方法的 GPU 計算裝置,因此您不需要手動將它們從 CPU 移動。

對於 use_orig_params=TrueShardingStrategy.SHARD_GRAD_OP 會公開未分片的參數,而不是轉發後的分片參數,這與 ShardingStrategy.FULL_SHARD 不同。如果您想要檢查梯度,可以使用 summon_full_params 方法並設定 with_grads=True

使用 limit_all_gathers=True 時,您可能會在 FSDP 前向傳遞中看到 CPU 執行緒未發出任何核心指令的間隔。這是刻意設計的,表示速率限制器正在生效。以這種方式同步 CPU 執行緒可以防止為後續的全收集操作過度配置記憶體,並且實際上不應延遲 GPU 核心指令的執行。

基於與自動梯度相關的原因,FSDP 在前向和反向計算過程中,會將受管理模組的參數替換為 torch.Tensor 视图。如果您的模組的前向傳遞依賴於儲存的參數引用,而不是每次迭代都重新取得引用,則它將不會看到 FSDP 新建立的视图,並且自動梯度將無法正常工作。

最後,當使用 sharding_strategy=ShardingStrategy.HYBRID_SHARD 並將分片處理群組設定為節點內,而複製處理群組設定為節點間時,設定 NCCL_CROSS_NIC=1 有助於在某些叢集設定中,改善複製處理群組上的全歸約時間。

限制

使用 FSDP 時,需要注意以下幾項限制

  • 目前,當使用 CPU 卸載時,FSDP 不支援在 no_sync() 之外累積梯度。這是因為 FSDP 使用新歸約的梯度,而不是與任何現有梯度累積,這可能會導致錯誤的結果。

  • FSDP 不支援執行包含在 FSDP 實例中的子模組的前向傳遞。這是因為子模組的參數會被分片,但子模組本身不是 FSDP 實例,因此它的前向傳遞不會正確地全收集所有參數。

  • 由於 FSDP 註冊反向鉤子的方式,因此它不適用於雙重重播。

  • FSDP 在凍結參數時有一些限制。對於 use_orig_params=False,每個 FSDP 實例都必須管理全部凍結或全部未凍結的參數。對於 use_orig_params=True,FSDP 支援混合凍結和未凍結的參數,但建議避免這樣做,以防止梯度記憶體使用量高於預期。

  • 從 PyTorch 1.12 開始,FSDP 對共享參數的支援有限。如果您的使用案例需要增強的共享參數支援,請在 此議題 中留言。

  • 您應該避免在前向和反向傳遞之間修改參數而不使用 summon_full_params 上下文,因為修改可能不會持續存在。

參數
  • module (nn.Module) – 這是要用 FSDP 包裝的模組。

  • process_group (Optional[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]) – 這是模型在其上進行分片的處理群組,因此也是用於 FSDP 的全收集和歸約散佈集體通訊的群組。如果為 None,則 FSDP 使用預設的處理群組。對於混合分片策略(例如 ShardingStrategy.HYBRID_SHARD),使用者可以傳入處理群組的元組,分別表示要分片和複製的群組。如果為 None,則 FSDP 會為使用者建構處理群組,以便在節點內分片並在節點間複製。(預設值:None

  • sharding_strategy (Optional[ShardingStrategy]) – 這會設定分片策略,這可能會在記憶體節省和通訊開銷之間進行權衡。有關詳細資訊,請參閱 ShardingStrategy。(預設值:FULL_SHARD

  • cpu_offload (Optional[CPUOffload]) – 這會設定 CPU 卸載。如果設定為 None,則不會進行 CPU 卸載。有關詳細資訊,請參閱 CPUOffload。(預設值:None

  • auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy, CustomPolicy]]) –

    這指定了將 FSDP 套用到 module 的子模組的策略,這是通訊和計算重疊所必需的,因此會影響效能。如果為 None,則 FSDP 僅套用到 module,並且使用者應該手動將 FSDP 套用到父模組本身(由下而上)。為了方便起見,這可以直接接受 ModuleWrapPolicy,這允許使用者指定要包裝的模組類別(例如 transformer 區塊)。否則,這應該是一個可呼叫的物件,它接受三個參數 module: nn.Modulerecurse: boolnonwrapped_numel: int,並且應該返回一個 bool,指定如果 recurse=False,是否應該將 FSDP 套用到傳入的 module,或者如果 recurse=True,則遍歷是否應該繼續到模組的子樹中。使用者可以在可呼叫的物件中新增其他參數。torch.distributed.fsdp.wrap.py 中的 size_based_auto_wrap_policy 提供了一個可呼叫的物件範例,如果模組子樹中的參數超過 1 億個 numel,則將 FSDP 套用到該模組。我們建議在套用 FSDP 後列印模型並根據需要進行調整。

    範例

    >>> def custom_auto_wrap_policy(
    >>>     module: nn.Module,
    >>>     recurse: bool,
    >>>     nonwrapped_numel: int,
    >>>     # Additional custom arguments
    >>>     min_num_params: int = int(1e8),
    >>> ) -> bool:
    >>>     return nonwrapped_numel >= min_num_params
    >>> # Configure a custom `min_num_params`
    >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
    

  • backward_prefetch (Optional[BackwardPrefetch]) – 這會設定全收集的顯式反向預取。如果為 None,則 FSDP 不進行反向預取,並且在反向傳遞中沒有通訊和計算重疊。有關詳細資訊,請參閱 BackwardPrefetch。(預設值:BACKWARD_PRE

  • mixed_precision (Optional[MixedPrecision]) – 這會為 FSDP 設定原生混合精度。如果設定為 None,則不使用混合精度。否則,可以設定參數、緩衝區和梯度歸約的 dtype。有關詳細資訊,請參閱 MixedPrecision。(預設值:None

  • ignored_modules (Optional[Iterable[torch.nn.Module]]) – 此實例會忽略其自身參數以及子模組的參數和緩衝區的模組。ignored_modules 中的任何模組都不應該是 FullyShardedDataParallel 實例,並且如果已經建構的 FullyShardedDataParallel 實例嵌套在此實例下,則不會忽略它們。當使用 auto_wrap_policy 或參數分片不是由 FSDP 管理時,可以使用此參數來避免在模組粒度上分片特定的參數。(預設值:None

  • param_init_fn (Optional[Callable[[nn.Module], None]]) –

    一個 Callable[torch.nn.Module] -> None,指定如何將當前位於元裝置上的模組初始化到實際裝置上。從 v1.12 開始,FSDP 通過 is_meta 檢測具有位於元裝置上的參數或緩衝區的模組,如果指定了 param_init_fn,則套用它,否則呼叫 nn.Module.reset_parameters()。對於這兩種情況,實作都應該只初始化模組的參數/緩衝區,而不是初始化其子模組的參數/緩衝區。這是為了避免重新初始化。此外,FSDP 還支援通過 torchdistX 的 (https://github.com/pytorch/torchdistX) deferred_init() API 進行延遲初始化,其中延遲的模組通過呼叫 param_init_fn(如果指定)或 torchdistX 的預設 materialize_module() 進行初始化。如果指定了 param_init_fn,則會將其套用到所有元裝置模組,這意味著它應該根據模組類型進行處理。FSDP 在參數扁平化和分片之前呼叫初始化函數。

    範例

    >>> module = MyModule(device="meta")
    >>> def my_init_fn(module: nn.Module):
    >>>     # E.g. initialize depending on the module type
    >>>     ...
    >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy)
    >>> print(next(fsdp_model.parameters()).device) # current CUDA device
    >>> # With torchdistX
    >>> module = deferred_init.deferred_init(MyModule, device="cuda")
    >>> # Will initialize via deferred_init.materialize_module().
    >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)
    

  • device_id (Optional[Union[int, torch.device]]) – 一個 inttorch.device,指定 FSDP 初始化在其上發生的 CUDA 裝置,包括模組初始化(如果需要)和參數分片。如果 module 位於 CPU 上,則應指定此參數以提高初始化速度。如果設定了預設的 CUDA 裝置(例如,通過 torch.cuda.set_device),則使用者可以將 torch.cuda.current_device 傳遞給它。(預設值:None

  • sync_module_states (bool) – 如果為 True,則每個 FSDP 模組都會將模組參數和緩衝區從 rank 0 廣播,以確保它們在各個 rank 中複製(將通訊成本添加到此建構函式)。這有助於以節省記憶體的方式透過 load_state_dict 載入 state_dict 检查點。請參閱 FullStateDictConfig 以取得範例。(預設值:False

  • forward_prefetch (bool) – 如果為 True,則 FSDP 會在目前的正向計算之前,*明確地* 預先提取下一個正向傳遞的 all-gather。這僅適用於 CPU 密集型工作負載,在這種情況下,更早發出下一個 all-gather 可以改善重疊。這應該只用於靜態圖形模型,因為預先提取會遵循第一次迭代的執行順序。(預設值:False

  • limit_all_gathers (bool) – 如果為 True,則 FSDP 會明確同步 CPU 執行緒,以確保 GPU 記憶體使用量僅來自*兩個*連續的 FSDP 实例(執行計算的當前实例和預先提取其 all-gather 的下一個实例)。如果為 False,則 FSDP 允許 CPU 執行緒在沒有任何額外同步的情況下發出 all-gather。(預設值:True)我們經常將此功能稱為「速率限制器」。此旗標應僅針對記憶體壓力較低的特定 CPU 密集型工作負載設定為 False,在這種情況下,CPU 執行緒可以積極發出所有內核,而無需擔心 GPU 記憶體使用量。

  • use_orig_params (bool) – 將此設定為 True 會讓 FSDP 使用 module 的原始參數。FSDP 透過 nn.Module.named_parameters() 向使用者公開這些原始參數,而不是 FSDP 內部的 FlatParameter。這表示優化器步驟會在原始參數上執行,啟用每個原始參數的超參數。FSDP 會保留原始參數變數,並在其未分片和分片形式之間操作其資料,其中它們始終分別是底層未分片或分片 FlatParameter 的视图。使用目前的演算法,分片形式始終為 1D,會失去原始張量結構。對於給定的 rank,原始參數可能會呈現全部、部分或完全沒有資料。在沒有資料的情況下,其資料將類似於大小為 0 的空張量。使用者不應編寫依賴於給定原始參數在其分片形式中呈現哪些資料的程式。True 是使用 torch.compile() 所必需的。將此設定為 False 會透過 nn.Module.named_parameters() 向使用者公開 FSDP 內部的 FlatParameter。(預設值:False

  • ignored_states (Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]) – 不會由此 FSDP 实例管理的已忽略參數或模組,這表示參數不會分片,並且其梯度不會跨 rank 減少。此參數與現有的 ignored_modules 參數統一,我們可能會很快棄用 ignored_modules。為了向後相容,我們保留了 ignored_statesignored_modules,但 FSDP 只允許其中一個指定為非 None

  • device_mesh (Optional[DeviceMesh]) – DeviceMesh 可以用作 process_group 的替代方案。當傳遞 device_mesh 時,FSDP 將使用底層處理群組進行 all-gather 和 reduce-scatter 集體通訊。因此,這兩個參數需要互斥。對於混合分片策略,例如 ShardingStrategy.HYBRID_SHARD,使用者可以傳入 2D DeviceMesh 而不是處理群組的元組。對於 2D FSDP + TP,使用者需要傳入 device_mesh 而不是 process_group。如需更多 DeviceMesh 資訊,請造訪:https://pytorch.com.tw/tutorials/recipes/distributed_device_mesh.html

apply(fn)[來源]

fn 遞迴地應用於每個子模組(由 .children() 返回)以及自身。

典型的用途包括初始化模型的參數(另請參閱 torch.nn.init)。

torch.nn.Module.apply 相比,此版本還會在應用 fn 之前收集完整的參數。不應從另一個 summon_full_params 上下文中呼叫它。

參數

fn (Module -> None) – 要應用於每個子模組的函式

傳回值

自身

傳回類型

Module

check_is_root()[來源]

檢查此实例是否為根 FSDP 模組。

傳回類型

bool

clip_grad_norm_(max_norm, norm_type=2.0)[來源]

裁剪所有參數的梯度範數。

範數是在所有參數的梯度上計算的,這些梯度被視為單個向量,並且梯度會被就地修改。

參數
  • max_norm (floatint) – 梯度的最大範數

  • norm_type (floatint) – 使用的 p 範數的類型。對於無限範數,可以是 'inf'

傳回值

參數的總範數(視為單個向量)。

傳回類型

Tensor

如果每個 FSDP 实例都使用 NO_SHARD,這表示沒有梯度在 rank 間分片,那麼您可以直接使用 torch.nn.utils.clip_grad_norm_()

如果至少有一個 FSDP 实例使用分片策略(即除了 NO_SHARD 之外的任何策略),那麼您應該使用此方法而不是 torch.nn.utils.clip_grad_norm_(),因為此方法處理了梯度在 rank 間分片的事實。

返回的總範數將具有所有參數/梯度中「最大」的 dtype,如 PyTorch 的類型提升語義所定義。例如,如果*所有*參數/梯度都使用低精度 dtype,則返回的範數的 dtype 將是該低精度 dtype,但如果存在至少一個使用 FP32 的參數/梯度,則返回的範數的 dtype 將是 FP32。

警告

由於此方法使用集體通訊,因此需要在所有 rank 上呼叫它。

static flatten_sharded_optim_state_dict(sharded_optim_state_dict, model, optim)[來源]

展平分片的優化器狀態字典。

API 與 shard_full_optim_state_dict() 類似。唯一的區別是輸入 sharded_optim_state_dict 應該從 sharded_optim_state_dict() 返回。因此,每個 rank 上都會有 all-gather 呼叫來收集 ShardedTensor

參數
傳回值

參考 shard_full_optim_state_dict()

傳回類型

Dict[str, Any]

forward(*args, **kwargs)[原始碼]

執行已包裝模組的前向傳遞,插入 FSDP 特定的前向和後向分片邏輯。

傳回類型

任何

static fsdp_modules(module, root_only=False)[原始碼]

返回所有巢狀的 FSDP 實例。

這可能包含 module 本身,並且如果 root_only=True,則僅包含 FSDP 根模組。

參數
  • module (torch.nn.Module) – 根模組,它可能是或可能不是 FSDP 模組。

  • root_only (bool) – 是否僅返回 FSDP 根模組。(預設值: False)

傳回值

巢狀在輸入 module 中的 FSDP 模組。

傳回類型

List[FullyShardedDataParallel]

static full_optim_state_dict(model, optim, optim_input=None, rank0_only=True, group=None)[原始碼]

返回完整的優化器狀態字典。

在 rank 0 上整合完整的優化器狀態,並將其作為 dict 返回,遵循 torch.optim.Optimizer.state_dict() 的慣例,即使用鍵 "state""param_groups"FSDP 模組中包含的扁平化參數(包含在 model 中)會映射回其未扁平化的參數。

這需要在所有 rank 上呼叫,因為它使用集體通訊。但是,如果 rank0_only=True,則狀態字典僅在 rank 0 上填充,而所有其他 rank 都返回空的 dict

torch.optim.Optimizer.state_dict() 不同,此方法使用完整的參數名稱作為鍵,而不是參數 ID。

如同在 torch.optim.Optimizer.state_dict() 中,優化器狀態字典中包含的張量不會被複製,因此可能會出現別名意外。為了獲得最佳實務,請考慮立即儲存返回的優化器狀態字典,例如使用 torch.save()

參數
  • model (torch.nn.Module) – 根模組(它可能是或可能不是 FullyShardedDataParallel 實例),其參數已傳遞至優化器 optim

  • optim (torch.optim.Optimizer) – model 參數的優化器。

  • optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 傳遞至優化器 optim 的輸入,表示參數群組的 list 或參數的迭代器;如果為 None,則此方法假設輸入為 model.parameters()。此參數已棄用,不再需要傳遞它。(預設值: None)

  • rank0_only (bool) – 如果為 True,則僅在 rank 0 上儲存填充的 dict;如果為 False,則在所有 rank 上儲存。(預設值: True)

  • group (dist.ProcessGroup) – 模型的處理程序群組,如果使用預設處理程序群組,則為 None。(預設值: None)

傳回值

包含 model 的原始未扁平化參數的優化器狀態的 dict,並包含鍵“state”和“param_groups”,遵循 torch.optim.Optimizer.state_dict() 的慣例。如果 rank0_only=True,則非零 rank 返回空的 dict

傳回類型

Dict[str, Any]

static get_state_dict_type(module)[原始碼]

取得以 module 為根的 FSDP 模組的 state_dict_type 和對應的配置。

目標模組不必是 FSDP 模組。

傳回值

包含目前設定的 state_dict_type 和 state_dict / optim_state_dict 配置的 StateDictSettings

引發
  • AssertionError` 如果不同 FSDP 子模組的 StateDictSettings 不同。

傳回類型

StateDictSettings

property module: Module

返回已包裝的模組。

named_buffers(*args, **kwargs)[原始碼]

返回模組緩衝區的迭代器,同時產生緩衝區的名稱和緩衝區本身。

summon_full_params() 上下文管理器內部時,攔截緩衝區名稱並移除所有出現的 FSDP 特定扁平化緩衝區前綴。

傳回類型

Iterator[Tuple[str, Tensor]]

named_parameters(*args, **kwargs)[原始碼]

返回模組參數的迭代器,同時產生參數的名稱和參數本身。

summon_full_params() 上下文管理器內部時,攔截參數名稱並移除所有出現的 FSDP 特定扁平化參數前綴。

傳回類型

Iterator[Tuple[str, Parameter]]

no_sync()[原始碼]

停用跨 FSDP 實例的梯度同步。

在此上下文中,梯度將累積在模組變數中,稍後將在退出上下文後的第一個前向-後向傳遞中同步。這應該僅在根 FSDP 實例上使用,並且將遞迴地應用於所有子 FSDP 實例。

備註

這可能會導致更高的記憶體使用量,因為 FSDP 會累積完整的模型梯度(而不是梯度分片),直到最終同步。

備註

當與 CPU 卸載一起使用時,在上下文管理器內,梯度不會被卸載到 CPU。相反的,它們只會在最終同步後立即卸載。

傳回類型

產生器

static optim_state_dict(model, optim, optim_state_dict=None, group=None)[原始碼]

轉換對應於分片模型的優化器的狀態字典。

給定的狀態字典可以轉換為三種類型之一:1) 完整優化器 state_dict,2) 分片優化器 state_dict,3) 本地優化器 state_dict。

對於完整優化器 state_dict,所有狀態都未扁平化且未分片。僅限 Rank0 和僅限 CPU 可以透過 state_dict_type() 指定,以避免 OOM。

對於分片優化器 state_dict,所有狀態都未扁平化但已分片。僅限 CPU 可以透過 state_dict_type() 指定,以進一步節省記憶體。

對於本地 state_dict,將不會執行任何轉換。但是狀態將從 nn.Tensor 轉換為 ShardedTensor 以表示其分片性質(目前尚不支援)。

範例

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> from torch.distributed.fsdp import FullStateDictConfig
>>> from torch.distributed.fsdp import FullOptimStateDictConfig
>>> # Save a checkpoint
>>> model, optim = ...
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)
>>> save_a_checkpoint(state_dict, optim_state_dict)
>>> # Load a checkpoint
>>> model, optim = ...
>>> state_dict, optim_state_dict = load_a_checkpoint()
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> model.load_state_dict(state_dict)
>>> optim_state_dict = FSDP.optim_state_dict_to_load(
>>>     model, optim, optim_state_dict
>>> )
>>> optim.load_state_dict(optim_state_dict)
參數
  • model (torch.nn.Module) – 根模組(它可能是或可能不是 FullyShardedDataParallel 實例),其參數已傳遞至優化器 optim

  • optim (torch.optim.Optimizer) – model 參數的優化器。

  • optim_state_dict (Dict[str, Any]) – 要轉換的目標優化器 state_dict。如果值為 None,則將使用 optim.state_dict()。( 預設值: None)

  • group (dist.ProcessGroup) – 模型的處理群組,參數在其中進行分片,如果使用預設處理群組,則為 None。( 預設值: None)

傳回值

一個包含 model 的優化器狀態的 dict。優化器狀態的分片基於 state_dict_type

傳回類型

Dict[str, Any]

static optim_state_dict_to_load(model, optim, optim_state_dict, is_named_optimizer=False, load_directly=False, group=None)[原始碼]

轉換優化器狀態字典,以便可以將其載入與 FSDP 模型關聯的優化器。

給定一個透過 optim_state_dict() 轉換的 optim_state_dict,它會被轉換為可以載入到 optim 的扁平化優化器狀態字典,optimmodel 的優化器。model 必須由 FullyShardedDataParallel 分片。

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> from torch.distributed.fsdp import FullStateDictConfig
>>> from torch.distributed.fsdp import FullOptimStateDictConfig
>>> # Save a checkpoint
>>> model, optim = ...
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> state_dict = model.state_dict()
>>> original_osd = optim.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(
>>>     model,
>>>     optim,
>>>     optim_state_dict=original_osd
>>> )
>>> save_a_checkpoint(state_dict, optim_state_dict)
>>> # Load a checkpoint
>>> model, optim = ...
>>> state_dict, optim_state_dict = load_a_checkpoint()
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> model.load_state_dict(state_dict)
>>> optim_state_dict = FSDP.optim_state_dict_to_load(
>>>     model, optim, optim_state_dict
>>> )
>>> optim.load_state_dict(optim_state_dict)
參數
  • model (torch.nn.Module) – 根模組(它可能是或可能不是 FullyShardedDataParallel 實例),其參數已傳遞至優化器 optim

  • optim (torch.optim.Optimizer) – model 參數的優化器。

  • optim_state_dict (Dict[str, Any]) – 要載入的優化器狀態。

  • is_named_optimizer (bool) – 此優化器是 NamedOptimizer 還是 KeyedOptimizer。僅在 optim 是 TorchRec 的 KeyedOptimizer 或 torch.distributed 的 NamedOptimizer 時,才設定為 True。

  • load_directly (bool) – 如果設定為 True,則此 API 也會在返回結果之前呼叫 optim.load_state_dict(result)。否則,使用者有責任呼叫 optim.load_state_dict() (預設值: False)

  • group (dist.ProcessGroup) – 模型的處理群組,參數在其中進行分片,如果使用預設處理群組,則為 None。( 預設值: None)

傳回類型

Dict[str, Any]

register_comm_hook(state, hook)[原始碼]

註冊一個通訊鉤子。

這是一個增強功能,它為使用者提供了一個靈活的鉤子,他們可以在其中指定 FSDP 如何跨多個工作器聚合梯度。此鉤子可用於實作多種演算法,例如 GossipGrad 和梯度壓縮,這些演算法在使用 FullyShardedDataParallel 訓練時涉及不同的參數同步通訊策略。

警告

FSDP 通訊鉤子應該在執行初始正向傳遞之前註冊,並且只能註冊一次。

參數
  • state (object) –

    傳遞給鉤子以在訓練過程中維護任何狀態資訊。範例包括梯度壓縮中的錯誤回饋、在 GossipGrad 中接下來要通訊的對等節點等。它由每個工作器本地儲存,並由工作器上的所有梯度張量共用。

  • hook (Callable) – 可呼叫物件,它具有以下其中一個簽章:1) hook: Callable[torch.Tensor] -> None:此函數接收一個 Python 張量,它表示關於此 FSDP 單元包裝的模型(未被其他 FSDP 子單元包裝)的所有變數的完整、扁平化、未分片的梯度。然後它會執行所有必要的處理並返回 None;2) hook: Callable[torch.Tensor, torch.Tensor] -> None:此函數接收兩個 Python 張量,第一個表示關於此 FSDP 單元包裝的模型(未被其他 FSDP 子單元包裝)的所有變數的完整、扁平化、未分片的梯度。後者表示一個預先調整大小的張量,用於在縮減後儲存分片梯度的區塊。在這兩種情況下,可呼叫物件都會執行所有必要的處理並返回 None。具有簽章 1 的可呼叫物件預計會處理 NO_SHARD 情況的梯度通訊。具有簽章 2 的可呼叫物件預計會處理分片情況的梯度通訊。

static rekey_optim_state_dict(optim_state_dict, optim_state_key_type, model, optim_input=None, optim=None)[原始碼]

重新設定優化器狀態字典 optim_state_dict 的鍵,以使用鍵類型 optim_state_key_type

這可以用於在具有 FSDP 執行個體和沒有 FSDP 執行個體的模型的優化器狀態字典之間實現相容性。

要重新設定 FSDP 完整優化器狀態字典(即來自 full_optim_state_dict())的鍵以使用參數 ID 並且可載入到未包裝的模型

>>> wrapped_model, wrapped_optim = ...
>>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
>>> nonwrapped_model, nonwrapped_optim = ...
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
>>> nonwrapped_optim.load_state_dict(rekeyed_osd)

要重新設定來自未包裝模型的普通優化器狀態字典的鍵,使其可載入到包裝的模型

>>> nonwrapped_model, nonwrapped_optim = ...
>>> osd = nonwrapped_optim.state_dict()
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
>>> wrapped_model, wrapped_optim = ...
>>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
>>> wrapped_optim.load_state_dict(sharded_osd)
傳回值

使用 optim_state_key_type 指定的參數鍵重新設定鍵的優化器狀態字典。

傳回類型

Dict[str, Any]

static scatter_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None, group=None)[原始碼]

將完整優化器狀態字典從 rank 0 分散到所有其他 rank。

返回每個 rank 上的分片優化器狀態字典。返回值與 shard_full_optim_state_dict() 相同,並且在 rank 0 上,第一個參數應該是 full_optim_state_dict() 的返回值。

範例

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)  # only non-empty on rank 0
>>> # Define new model with possibly different world size
>>> new_model, new_optim, new_group = ...
>>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group)
>>> new_optim.load_state_dict(sharded_osd)

備註

可以使用 shard_full_optim_state_dict()scatter_full_optim_state_dict() 來取得分片的優化器狀態字典以供載入。假設完整的優化器狀態字典位於 CPU 記憶體中,前者需要每個 rank 在 CPU 記憶體中都擁有完整的字典,其中每個 rank 個別對字典進行分片,而無需任何通訊;而後者只需要 rank 0 在 CPU 記憶體中擁有完整的字典,其中 rank 0 將每個分片移至 GPU 記憶體(對於 NCCL),並將其適當傳輸給各個 rank。因此,前者具有較高的總體 CPU 記憶體成本,而後者具有較高的通訊成本。

參數
  • full_optim_state_dict (Optional[Dict[str, Any]]) – 對應於未扁平化參數的優化器狀態字典,如果在 rank 0 上,則持有完整的非分片優化器狀態;在非零 rank 上,此參數將被忽略。

  • model (torch.nn.Module) – 根模組(可能是或不是 FullyShardedDataParallel 實例),其參數對應於 full_optim_state_dict 中的優化器狀態。

  • optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 傳遞給優化器的輸入,表示參數群組的 list 或參數的可迭代物件;如果為 None,則此方法假設輸入為 model.parameters()。此參數已棄用,且不再需要傳遞。 (預設值: None)

  • optim (Optional[torch.optim.Optimizer]) – 將載入此方法所返回的狀態字典的優化器。建議使用此參數,而非 optim_input。 (預設值: None)

  • group (dist.ProcessGroup) – 模型的處理程序群組,如果使用預設處理程序群組,則為 None。(預設值: None)

傳回值

完整的優化器狀態字典現在重新映射到扁平化參數,而不是未扁平化參數,並且僅限於包含此 rank 的優化器狀態部分。

傳回類型

Dict[str, Any]

static set_state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source]

設定目標模組所有後代 FSDP 模組的 state_dict_type

同時也接受模型和優化器狀態字典的(可選)配置。目標模組不一定要是 FSDP 模組。如果目標模組是 FSDP 模組,則其 state_dict_type 也將被更改。

備註

此 API 應僅針對最上層(根)模組呼叫。

備註

此 API 使使用者能夠透明地使用傳統的 state_dict API 來取得模型檢查點,以防根 FSDP 模組被另一個 nn.Module 包裝的情況。例如,以下操作將確保在所有非 FSDP 實例上呼叫 state_dict,同時將 FSDP 分派到 sharded_state_dict 實作

範例

>>> model = DDP(FSDP(...))
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.SHARDED_STATE_DICT,
>>>     state_dict_config = ShardedStateDictConfig(offload_to_cpu=True),
>>>     optim_state_dict_config = OptimStateDictConfig(offload_to_cpu=True),
>>> )
>>> param_state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)
參數
  • module (torch.nn.Module) – 根模組。

  • state_dict_type (StateDictType) – 要設定的 state_dict_type

  • state_dict_config (Optional[StateDictConfig]) – 目標 state_dict_type 的配置。

  • optim_state_dict_config (Optional[OptimStateDictConfig]) – 優化器狀態字典的配置。

傳回值

包含模組先前狀態字典類型和配置的 StateDictSettings。

傳回類型

StateDictSettings

static shard_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None)[source]

分片一個完整的優化器狀態字典。

full_optim_state_dict 中的狀態重新映射到扁平化參數,而不是未扁平化參數,並限制為僅包含此 rank 的優化器狀態部分。第一個參數應該是 full_optim_state_dict() 的返回值。

範例

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)
>>> torch.save(full_osd, PATH)
>>> # Define new model with possibly different world size
>>> new_model, new_optim = ...
>>> full_osd = torch.load(PATH)
>>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
>>> new_optim.load_state_dict(sharded_osd)

備註

可以使用 shard_full_optim_state_dict()scatter_full_optim_state_dict() 來取得分片的優化器狀態字典以供載入。假設完整的優化器狀態字典位於 CPU 記憶體中,前者需要每個 rank 在 CPU 記憶體中都擁有完整的字典,其中每個 rank 個別對字典進行分片,而無需任何通訊;而後者只需要 rank 0 在 CPU 記憶體中擁有完整的字典,其中 rank 0 將每個分片移至 GPU 記憶體(對於 NCCL),並將其適當傳輸給各個 rank。因此,前者具有較高的總體 CPU 記憶體成本,而後者具有較高的通訊成本。

參數
  • full_optim_state_dict (Dict[str, Any]) – 對應於未扁平化參數的優化器狀態字典,並持有完整的非分片優化器狀態。

  • model (torch.nn.Module) – 根模組(可能是或不是 FullyShardedDataParallel 實例),其參數對應於 full_optim_state_dict 中的優化器狀態。

  • optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 傳遞給優化器的輸入,表示參數群組的 list 或參數的可迭代物件;如果為 None,則此方法假設輸入為 model.parameters()。此參數已棄用,且不再需要傳遞。 (預設值: None)

  • optim (Optional[torch.optim.Optimizer]) – 將載入此方法所返回的狀態字典的優化器。建議使用此參數,而非 optim_input。 (預設值: None)

傳回值

完整的優化器狀態字典現在重新映射到扁平化參數,而不是未扁平化參數,並且僅限於包含此 rank 的優化器狀態部分。

傳回類型

Dict[str, Any]

static sharded_optim_state_dict(model, optim, group=None)[source]

以分片形式返回優化器狀態字典。

此 API 類似於 full_optim_state_dict(),但此 API 會將所有非零維度狀態分塊為 ShardedTensor 以節省記憶體。僅當使用上下文管理器 with state_dict_type(SHARDED_STATE_DICT): 衍生模型 state_dict 時,才應使用此 API。

如需詳細用法,請參閱 full_optim_state_dict()

警告

返回的狀態字典包含 ShardedTensor,不能直接被常規的 optim.load_state_dict 使用。

傳回類型

Dict[str, Any]

static state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source]

設定目標模組所有後代 FSDP 模組的 state_dict_type

此上下文管理器與 set_state_dict_type() 具有相同的功能。有關詳細信息,請閱讀 set_state_dict_type() 的文檔。

範例

>>> model = DDP(FSDP(...))
>>> with FSDP.state_dict_type(
>>>     model,
>>>     StateDictType.SHARDED_STATE_DICT,
>>> ):
>>>     checkpoint = model.state_dict()
參數
  • module (torch.nn.Module) – 根模組。

  • state_dict_type (StateDictType) – 要設定的 state_dict_type

  • state_dict_config (Optional[StateDictConfig]) – 目標 state_dict_type 的模型 state_dict 配置。

  • optim_state_dict_config (Optional[OptimStateDictConfig]) – 目標 state_dict_type 的優化器 state_dict 配置。

傳回類型

產生器

static summon_full_params(module, recurse=True, writeback=True, rank0_only=False, offload_to_cpu=False, with_grads=False)[source]

使用此上下文管理器公開 FSDP 實例的完整參數。

在模型進行正向/反向傳播 *後*,這一點可能會有用,可以取得參數以進行額外的處理或檢查。它可以採用非 FSDP 模組,並會根據 recurse 引數,為所有包含的 FSDP 模組及其子模組召喚完整參數。

備註

這可以用於內部 FSDP。

備註

這 *不能* 在正向或反向傳播過程中使用。也不能從這個上下文開始正向和反向傳播。

備註

當上下文管理器結束後,參數將恢復為其本地分片,儲存行為與正向傳播相同。

備註

完整的參數可以被修改,但只有對應於本地參數分片的部份會在上下文管理器結束後保留(除非 writeback=False,在這種情況下,更改將被捨棄)。如果 FSDP 沒有對參數進行分片,目前只有在 world_size == 1NO_SHARD 配置時,才會保留修改,而不管 writeback 為何。

備註

此方法適用於本身不是 FSDP 但可能包含多個獨立 FSDP 單元的模組。在這種情況下,給定的引數將適用於所有包含的 FSDP 單元。

警告

請注意,目前不支援 rank0_only=Truewriteback=True 的組合,並且會引發錯誤。這是因為模型參數形狀在上下文中的各個 rank 之間會有所不同,寫入這些參數可能會導致在上下文結束時,各個 rank 之間的不一致。

警告

請注意,offload_to_cpurank0_only=False 會導致完整參數被冗餘地複製到同一台機器上的 GPU 的 CPU 記憶體中,這可能會導致 CPU OOM 的風險。建議將 offload_to_cpurank0_only=True 一起使用。

參數
  • recurse (bool, 選用) – 遞迴地召喚巢狀 FSDP 實例的所有參數(預設值:True)。

  • writeback (bool, 選用) – 如果為 False,則在上下文管理器結束後捨棄對參數的修改;停用此選項可以稍微提高效率(預設值:True)

  • rank0_only (bool, 選用) – 如果為 True,則完整參數只會在全域 rank 0 上實現。這表示在上下文中,只有 rank 0 會擁有完整參數,而其他 rank 將擁有分片參數。請注意,不支援將 rank0_only=Truewriteback=True 一起設定,因為模型參數形狀在上下文中的各個 rank 之間會有所不同,寫入這些參數可能會導致在上下文結束時,各個 rank 之間的不一致。

  • offload_to_cpu (bool, 選用) – 如果為 True,則完整參數會卸載到 CPU。請注意,目前只有在參數已分片時才會進行卸載(只有在 world_size = 1 或 NO_SHARD 配置時才不會進行分片)。建議將 offload_to_cpurank0_only=True 一起使用,以避免將模型參數的冗餘副本卸載到相同的 CPU 記憶體。

  • with_grads (bool, 選用) – 如果為 True,則梯度也會與參數一起取消分片。目前,只有在將 use_orig_params=True 傳遞給 FSDP 建構函數,並且將 offload_to_cpu=False 傳遞給此方法時才支援此功能。(預設值:False

傳回類型

產生器

類別 torch.distributed.fsdp.BackwardPrefetch(value)[來源]

這會配置顯式反向預取,它可以透過在反向傳播中啟用通訊和計算重疊來提高吞吐量,但會稍微增加記憶體使用量。

  • BACKWARD_PRE:這可以實現最大的重疊,但也會增加最多的記憶體使用量。這會在計算當前參數集的梯度 *之前* 預取下一組參數。這會重疊 *下一個 all-gather* 和 *當前梯度計算*,並且在高峰期,它會在記憶體中保存當前參數集、下一組參數和當前梯度集。

  • BACKWARD_POST:這可以實現較少的重疊,但需要的記憶體使用量也較少。這會在計算當前參數集的梯度 *之後* 預取下一組參數。這會重疊 *當前 reduce-scatter* 和 *下一個梯度計算*,並且它會在釋放當前參數集後才分配記憶體給下一組參數,只在高峰期在記憶體中保存下一組參數和當前梯度集。

  • FSDP 的 backward_prefetch 引數接受 None,這會完全停用反向預取。這沒有重疊,也不會增加記憶體使用量。一般來說,我們不建議使用此設定,因為它可能會顯著降低吞吐量。

更多技術背景:對於使用 NCCL 後端的單一行程群組,任何集合操作,即使是從不同串流發出的,都會爭奪相同的每個裝置 NCCL 串流,這意味著發出集合操作的相對順序對於重疊來說很重要。兩個反向預取值對應於不同的發出順序。

類別 torch.distributed.fsdp.ShardingStrategy(value)[來源]

這指定了 FullyShardedDataParallel 用於分散式訓練的分片策略。

  • FULL_SHARD:參數、梯度和優化器狀態都會被分片。對於參數,此策略會在正向傳播之前取消分片(透過 all-gather)、在正向傳播之後重新分片、在反向計算之前取消分片,以及在反向計算之後重新分片。對於梯度,它會在反向計算之後同步並分片它們(透過 reduce-scatter)。分片的優化器狀態會在每個 rank 本地更新。

  • SHARD_GRAD_OP:梯度和優化器狀態在計算過程中會被分片,此外,參數在計算之外也會被分片。對於參數,此策略會在正向傳播之前取消分片、在正向傳播之後不會重新分片,並且只會在反向計算之後重新分片。分片的優化器狀態會在每個 rank 本地更新。在 no_sync() 內部,參數在反向計算之後不會重新分片。

  • NO_SHARD:參數、梯度和優化器狀態不會被分片,而是像 PyTorch 的 DistributedDataParallel API 一樣在各個 rank 之間複製。對於梯度,此策略會在反向計算之後同步它們(透過 all-reduce)。未分片的優化器狀態會在每個 rank 本地更新。

  • HYBRID_SHARD:在節點內應用 FULL_SHARD,並在節點之間複製參數。這會減少通訊量,因為昂貴的 all-gather 和 reduce-scatter 只會在節點內完成,這對於中型模型來說可能更有效率。

  • _HYBRID_SHARD_ZERO2:在節點內應用 SHARD_GRAD_OP,並在節點之間複製參數。這類似於 HYBRID_SHARD,但這可能會提供更高的吞吐量,因為未分片的參數在正向傳播之後不會被釋放,從而在反向傳播之前節省了 all-gather 的成本。

類別 torch.distributed.fsdp.MixedPrecision(param_dtype=None, reduce_dtype=None, buffer_dtype=None, keep_low_precision_grads=False, cast_forward_inputs=False, cast_root_forward_inputs=True, _module_classes_to_ignore=(<class 'torch.nn.modules.batchnorm._BatchNorm'>, ))[來源]

這會配置 FSDP 原生的混合精度訓練。

變數
  • param_dtype (選用[torch.dtype]) – 這指定了模型參數在正向和反向傳播過程中的 dtype,因此也是正向和反向計算的 dtype。在正向和反向傳播之外,*分片* 參數會保持完整精度(例如,對於優化器步驟),並且對於模型檢查點,參數始終以完整精度儲存。(預設值:None

  • reduce_dtype (選用[torch.dtype]) – 這指定了梯度縮減的 dtype(即 reduce-scatter 或 all-reduce)。如果這是 None,但 param_dtype 不是 None,則它會採用 param_dtype 值,仍然以低精度執行梯度縮減。允許它與 param_dtype 不同,例如強制梯度縮減以完整精度執行。(預設值:None

  • buffer_dtype (Optional[torch.dtype]) – 這指定緩衝區的 dtype。FSDP 不會對緩衝區進行分片。而是在第一次前向傳遞時將其轉換為 buffer_dtype,並在之後保持該 dtype。對於模型檢查點,緩衝區會以全精度儲存,但 LOCAL_STATE_DICT 除外。(預設值: None)

  • keep_low_precision_grads (bool) – 如果為 False,則 FSDP 會在反向傳遞後將梯度提升至全精度,以便為優化器步驟做好準備。如果為 True,則 FSDP 會將梯度保持在用於梯度縮減的 dtype 中,如果使用支援以低精度執行的自定義優化器,則可以節省記憶體。(預設值: False)

  • cast_forward_inputs (bool) – 如果為 True,則此 FSDP 模組會將其前向參數和關鍵字參數轉換為 param_dtype。這是為了確保參數和輸入 dtype 與前向計算相符,這是許多操作所要求的。當僅對某些但並非所有 FSDP 模組套用混合精度時,可能需要將其設定為 True,在這種情況下,混合精度 FSDP 子模組需要重新轉換其輸入。(預設值: False)

  • cast_root_forward_inputs (bool) – 如果為 True,則根 FSDP 模組會將其前向參數和關鍵字參數轉換為 param_dtype,覆寫 cast_forward_inputs 的值。對於非根 FSDP 模組,這不會執行任何操作。(預設值: True)

  • _module_classes_to_ignore (Sequence[Type[torch.nn.modules.module.Module]]) – (Sequence[Type[nn.Module]]): 這指定在使用 auto_wrap_policy 時要忽略的混合精度模組類別:這些類別的模組將單獨套用 FSDP 並停用混合精度(表示最終的 FSDP 結構將偏離指定的策略)。如果未指定 auto_wrap_policy,則這不會執行任何操作。此 API 仍處於實驗階段,可能會有所變更。(預設值: (_BatchNorm,))

備註

此 API 仍處於實驗階段,可能會有所變更。

備註

只有浮點張量會轉換為其指定的 dtype。

備註

summon_full_params 中,參數會強制為全精度,但緩衝區則不會。

備註

即使層歸一化和批次歸一化的輸入是低精度(例如 float16bfloat16),它們也會累積在 float32 中。停用這些歸一化模組的 FSDP 混合精度僅表示仿射參數會保持在 float32 中。但是,這會為這些歸一化模組產生單獨的全收集和縮減散佈,這可能會很沒效率,因此如果工作負載允許,使用者應該盡可能仍然對這些模組套用混合精度。

備註

根據預設,如果使用者傳遞具有任何 _BatchNorm 模組的模型並指定 auto_wrap_policy,則批次歸一化模組將單獨套用 FSDP 並停用混合精度。請參閱 _module_classes_to_ignore 參數。

備註

MixedPrecision 預設為 cast_root_forward_inputs=Truecast_forward_inputs=False。對於根 FSDP 執行個體,其 cast_root_forward_inputs 優先於其 cast_forward_inputs。對於非根 FSDP 執行個體,將會忽略其 cast_root_forward_inputs 值。預設設定足以應付每個 FSDP 執行個體具有相同 MixedPrecision 設定且只需要在模型的前向傳遞開始時將輸入轉換為 param_dtype 的典型情況。

備註

對於具有不同 MixedPrecision 設定的巢狀 FSDP 執行個體,我們建議設定個別的 cast_forward_inputs 值,以便在每個執行個體的前向傳遞之前設定是否轉換輸入。在這種情況下,由於轉換發生在每個 FSDP 執行個體的前向傳遞之前,因此父 FSDP 執行個體應該在其 FSDP 子模組之前執行其非 FSDP 子模組,以避免由於 MixedPrecision 設定不同而導致啟用 dtype 被更改。

範例

>>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
>>> model[1] = FSDP(
>>>     model[1],
>>>     mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
>>> )
>>> model = FSDP(
>>>     model,
>>>     mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
>>> )

以上顯示了一個有效的範例。另一方面,如果將 model[1] 替換為 model[0],表示使用不同 MixedPrecision 的子模組先執行其前向傳遞,則 model[1] 將會錯誤地看到 float16 啟用,而不是 bfloat16 啟用。

class torch.distributed.fsdp.CPUOffload(offload_params=False)[來源]

這會設定 CPU 卸載。

變數

offload_params (bool) – 這指定在不參與計算時是否要將參數卸載到 CPU。如果為 True,則這也會將梯度卸載到 CPU,表示優化器步驟會在 CPU 上執行。

class torch.distributed.fsdp.StateDictConfig(offload_to_cpu=False)[來源]

StateDictConfig 是所有 state_dict 設定類別的基類。使用者應該實例化子類別(例如 FullStateDictConfig),以便為 FSDP 支援的對應 state_dict 類型設定設定。

變數

offload_to_cpu (bool) – 如果為 True,則 FSDP 會將狀態字典值卸載到 CPU,如果為 False,則 FSDP 會將其保留在 GPU 上。(預設值: False)

class torch.distributed.fsdp.FullStateDictConfig(offload_to_cpu=False, rank0_only=False)[來源]

FullStateDictConfig 是一個設定類別,旨在與 StateDictType.FULL_STATE_DICT 搭配使用。我們建議在儲存完整狀態字典時啟用 offload_to_cpu=Truerank0_only=True,以分別節省 GPU 記憶體和 CPU 記憶體。此設定類別旨在透過 state_dict_type() 上下文管理器使用,如下所示

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> fsdp = FSDP(model, auto_wrap_policy=...)
>>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
>>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
>>>     state = fsdp.state_dict()
>>>     # `state` will be empty on non rank 0 and contain CPU tensors on rank 0.
>>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
>>> model = model_fn() # Initialize model in preparation for wrapping with FSDP
>>> if dist.get_rank() == 0:
>>>     # Load checkpoint only on rank 0 to avoid memory redundancy
>>>     state_dict = torch.load("my_checkpoint.pt")
>>>     model.load_state_dict(state_dict)
>>> # All ranks initialize FSDP module as usual. `sync_module_states` argument
>>> # communicates loaded checkpoint states from rank 0 to rest of the world.
>>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
>>> # After this point, all ranks have FSDP model with loaded checkpoint.
變數

rank0_only (bool) – 如果為 True,則只有排名 0 會儲存完整狀態字典,而排名非零則會儲存空字典。如果為 False,則所有排名都會儲存完整狀態字典。(預設值: False)

class torch.distributed.fsdp.ShardedStateDictConfig(offload_to_cpu=False, _use_dtensor=False)[來源]

ShardedStateDictConfig 是一個設定類別,旨在與 StateDictType.SHARDED_STATE_DICT 搭配使用。

變數

_use_dtensor (bool) – 如果為 True,則 FSDP 會將狀態字典值儲存為 DTensor,如果為 False,則 FSDP 會將其儲存為 ShardedTensor。(預設值: False)

警告

_use_dtensorShardedStateDictConfig 的私有欄位,FSDP 使用它來決定狀態字典值的類型。使用者不應手動修改 _use_dtensor

class torch.distributed.fsdp.LocalStateDictConfig(offload_to_cpu: bool = False)[原始碼]
class torch.distributed.fsdp.OptimStateDictConfig(offload_to_cpu=True)[原始碼]

OptimStateDictConfig 是所有 optim_state_dict 配置類別的基底類別。使用者應實例化一個子類別(例如 FullOptimStateDictConfig)以便為 FSDP 支援的相應 optim_state_dict 類型配置設定。

變數

offload_to_cpu (bool) – 如果為 True,則 FSDP 會將狀態字典的張量值卸載到 CPU,如果為 False,則 FSDP 會將它們保留在原始裝置上(除非啟用了參數 CPU 卸載,否則為 GPU)。(預設值:True

class torch.distributed.fsdp.FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False)[原始碼]
變數

rank0_only (bool) – 如果為 True,則只有排名 0 會儲存完整狀態字典,而排名非零則會儲存空字典。如果為 False,則所有排名都會儲存完整狀態字典。(預設值: False)

class torch.distributed.fsdp.ShardedOptimStateDictConfig(offload_to_cpu=True, _use_dtensor=False)[原始碼]

ShardedOptimStateDictConfig 是一個配置類別,旨在與 StateDictType.SHARDED_STATE_DICT 一起使用。

變數

_use_dtensor (bool) – 如果為 True,則 FSDP 會將狀態字典值儲存為 DTensor,如果為 False,則 FSDP 會將其儲存為 ShardedTensor。(預設值: False)

警告

_use_dtensorShardedOptimStateDictConfig 的私有欄位,FSDP 使用它來確定狀態字典值的類型。使用者不應手動修改 _use_dtensor

class torch.distributed.fsdp.LocalOptimStateDictConfig(offload_to_cpu: bool = False)[原始碼]
class torch.distributed.fsdp.StateDictSettings(state_dict_type: torch.distributed.fsdp.api.StateDictType, state_dict_config: torch.distributed.fsdp.api.StateDictConfig, optim_state_dict_config: torch.distributed.fsdp.api.OptimStateDictConfig)[原始碼]

文件

取得 PyTorch 的完整開發者文件

查看文件

教學

取得適用於初學者和進階開發者的深入教學

查看教學

資源

尋找開發資源並取得問題解答

查看資源