• 文件 >
  • 分散式檢查點 - torch.distributed.checkpoint
快捷方式

分散式檢查點 - torch.distributed.checkpoint

分散式檢查點 (DCP) 支援從多個等級並行載入和儲存模型。它處理載入時重新分片,這使您能夠在一個叢集拓撲中儲存並在另一個叢集拓撲中載入。

DCP 在以下幾個重要方面不同於 torch.savetorch.load

  • 它為每個檢查點生成多個檔案,每個等級至少有一個檔案。

  • 它在原地操作,這意味著模型應首先分配其資料,然後 DCP 使用該儲存而不是重新分配。

載入和儲存檢查點的入口點如下所示

torch.distributed.checkpoint.state_dict_saver.save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None)[source]

以 SPMD 樣式儲存分散式模型。

此函式不同於 torch.save(),因為它處理 ShardedTensorDTensor,每個等級只儲存其本地分片。

對於每個 Stateful 物件(同時具有 state_dictload_state_dict),儲存將在序列化之前呼叫 state_dict

警告

對於儲存的 state_dict,在 PyTorch 版本之間沒有向後相容性的保證。

警告

如果使用 process_group 引數,請確保只有其等級呼叫 save_state_dict 並且 state_dict 中的所有資料都屬於它。

注意

儲存 FSDP 的 ShardingStrategy.HYBRID_SHARD 的檢查點時,只有分片組中的一個應該呼叫 save_state_dict,並且需要傳入相應的程序組。

注意

如果沒有程序組可用,此函式將假設意圖是在本地程序中儲存

state_dict。

引數
  • state_dict (Dict[str, Any]) – 要儲存的 state_dict。

  • checkpoint_id (Union[str, os.PathLike, None]) – 此檢查點例項的 ID。checkpoint_id 的含義取決於儲存。它可以是資料夾或檔案的路徑。如果儲存是鍵值儲存,它也可以是鍵。(預設值:None)

  • storage_writer (Optional[StorageWriter]) – 用於執行寫入的 StorageWriter 例項。如果未指定,DCP 將根據 checkpoint_id 自動推斷寫入器。如果 checkpoint_id 也為 None,則會引發異常。(預設值:None)

  • planner (Optional[SavePlanner]) – SavePlanner 例項。如果未指定,將使用預設計劃程式。(預設值:None)

  • process_group (Optional[ProcessGroup]) – 用於跨等級同步的 ProcessGroup。(預設值:None)

返回值

儲存的檢查點的元資料物件。

返回型別

Metadata

示例

>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
>>> torch.distributed.checkpoint.save(
>>>     state_dict=state_dict,
>>>     storage_writer=fs_storage_writer,
>>> )

注意

save_state_dict 使用集體來協調跨等級的寫入。對於基於 NCCL 的程序組,在通訊發生之前,必須將物件的內部張量表示移動到 GPU 裝置。在這種情況下,使用的裝置由 torch.cuda.current_device() 給出,使用者有責任確保透過 torch.cuda.set_device() 設定此裝置,以便每個等級都有一個單獨的 GPU。

torch.distributed.checkpoint.state_dict_saver.async_save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None)[source]

save 的非同步版本。此程式碼首先在 CPU 上取消分階段 state_dict,然後在單獨的執行緒中呼叫 save

警告

此功能處於實驗階段,可能會發生變化。

引數
  • state_dict (Dict[str, Any]) – 要儲存的 state_dict。

  • checkpoint_id (Union[str, os.PathLike, None]) – 此檢查點例項的 ID。checkpoint_id 的含義取決於儲存。它可以是資料夾或檔案的路徑。如果儲存是鍵值儲存,它也可以是鍵。(預設值:None)

  • storage_writer (Optional[StorageWriter]) – 用於執行寫入的 StorageWriter 例項。如果未指定,DCP 將根據 checkpoint_id 自動推斷寫入器。如果 checkpoint_id 也為 None,則會引發異常。(預設值:None)

  • planner (Optional[SavePlanner]) – SavePlanner 例項。如果未指定,將使用預設計劃程式。(預設值:None)

  • process_group (Optional[ProcessGroup]) – 用於跨等級同步的 ProcessGroup。(預設值:None)

返回值

儲存save方法返回的元資料物件的未來。

返回型別

未來

示例

>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
>>> checkpoint_future = torch.distributed.checkpoint.async_save(
>>>     state_dict=state_dict,
>>>     storage_writer=fs_storage_writer,
>>> )
>>>
>>> # ... do some work ...
>>>
>>> checkpoint_future.result()
torch.distributed.checkpoint.state_dict_saver.save_state_dict(state_dict, storage_writer, process_group=None, coordinator_rank=0, no_dist=False, planner=None)[source]

此方法已棄用。請切換到 'save' 方法。

返回型別

Metadata

torch.distributed.checkpoint.state_dict_loader.load(state_dict, *, checkpoint_id=None, storage_reader=None, planner=None, process_group=None)[source]

在 SPMD 樣式中載入分散式 state_dict

每個程序將嘗試讀取最少的資料量來滿足請求的 state_dict。在載入 ShardedTensorDTensor 例項時,每個程序只讀取其本地分片的相關資料。

對於每個 Stateful 物件(同時具有 state_dictload_state_dict),load 方法會首先呼叫 state_dict 來進行反序列化,之後在反序列化完成後呼叫 load_state_dict

警告

state_dict 中的所有張量必須在呼叫此函式之前,在目標裝置上分配。

所有非張量資料都使用 torch.load() 進行載入,並在 state_dict 上就地修改。

警告

使用者必須在根模組上呼叫 load_state_dict 以確保載入後處理和非張量資料能夠正確傳播。

引數
  • state_dict (Dict[str, Any]) – 要儲存的 state_dict。

  • checkpoint_id (Union[str, os.PathLike, None]) – 此檢查點例項的 ID。checkpoint_id 的含義取決於儲存。它可以是資料夾或檔案的路徑。如果儲存是鍵值儲存,它也可以是鍵。(預設值:None)

  • storage_reader (Optional[StorageReader]) – 用於執行讀取操作的 StorageWriter 例項。如果未指定此引數,DCP 會根據 checkpoint_id 自動推斷讀取器。如果 checkpoint_id 也為 None,則會丟擲異常。(預設: None)

  • planner (Optional[LoadPlanner]) – LoadPlanner 例項。如果未指定此引數,將使用預設的規劃器。(預設: None)

  • process_group (Optional[ProcessGroup]) – 用於跨等級同步的 ProcessGroup。(預設值:None)

返回值

無。

返回型別

示例
>>> my_model = MyModule()
>>> optimizer = Adagrad(my_model.parameters())
>>> model_state_dict = my_model.state_dict()
>>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1")
>>> torch.distributed.checkpoint.load_state_dict(
>>>     state_dict=model_state_dict,
>>>     storage_reader=fs_storage_reader,
>>> )
>>> # module.load_state_dict() function might have customized steps
>>> # to flush the state_dict, must call it to
>>> # ensure correct behavior.
>>> my_model.load_state_dict(model_state_dict)

注意

load_state_dict 方法使用集體操作來協調跨程序的讀取。對於基於 NCCL 的程序組,物件的內部張量表示必須在進行通訊之前移動到 GPU 裝置上。在這種情況下,使用的裝置由 torch.cuda.current_device() 指定,使用者負責確保每個程序都擁有單獨的 GPU,方法是使用 torch.cuda.set_device()

torch.distributed.checkpoint.state_dict_loader.load_state_dict(state_dict, storage_reader, process_group=None, coordinator_rank=0, no_dist=False, planner=None)[source]

此方法已棄用。請切換到 'load' 方法。

以下模組也適用於對非同步檢查點機制進行額外定製(torch.distributed.checkpoint.async_save

class torch.distributed.checkpoint.staging.AsyncStager(*args, **kwargs)[source]

此協議旨在為 dcp.async_save 提供定製和擴充套件性,允許使用者自定義在並行執行常規 dcp.save 路徑之前如何對資料進行暫存。預期的操作順序(在 torch.distributed.state_dict_saver.async_save 中具體定義)如下:

  1. AsyncStager.stage_data(state_dict)

    此呼叫使 AsyncStager 有機會對 state_dict 進行“暫存”。在此上下文中,暫存的預期和目的是建立 state_dict 的“訓練安全”表示形式,這意味著在暫存完成後對模組資料的任何更新都應該不會反映在從該方法返回的 state_dict 中。例如,在預設情況下,會在 CPU RAM 上建立整個 state_dict 的副本並在此返回,允許使用者繼續訓練而不會冒更改正在序列化的資料的風險。

  2. 並行呼叫 dcp.save 對從暫存方法返回的 state_dict 進行操作。此呼叫負責

    對 state_dict 進行序列化並將它寫入儲存。

  3. 如果 AsyncStager.should_synchronize_after_execute 為 True,則此方法將在

    序列化執行緒啟動後,在從 dcp.async_save 返回之前立即呼叫。如果此引數設定為 False,則假設使用者已經為進一步最佳化儲存延遲在訓練迴圈中定義了自定義同步點(例如,透過將暫存與前向/反向傳遞重疊),並且使用者有責任在適當的時間呼叫 AsyncStager.synchronize_staging

property should_synchronize_after_execute: bool

是否在執行暫存後進行同步。

stage(state_dict)[source]

返回 state_dict 的“暫存”副本。暫存副本的預期是,它不受在暫存呼叫完成後發生的任何更新的影響。

返回型別

Dict[str, Union[StatefulT, Any]]

synchronize_staging()[source]

stage 非同步進行的情況下,應該呼叫此方法以確保暫存完成,並且可以安全地開始修改原始 state_dict

class torch.distributed.checkpoint.staging.BlockingAsyncStager(cache_staged_state_dict=False, type_check=False)[source]

AsyncStager 的一種實現,它在 CPU RAM 上對 state_dict 進行暫存,並阻塞直到複製完成。此實現還提供了一個選項,可以使用固定記憶體來最佳化暫存延遲。

注意:在這種情況下,synchronize_staging 是一個空操作。

stage(state_dict)[source]

返回 state_dict 在 CPU 上的副本。

返回型別

Dict[str, Union[StatefulT, Any]]

synchronize_staging()[source]

空操作函式,因為暫存是阻塞的。

除了上述入口點之外,Stateful 物件(如下所述)在儲存/載入期間提供額外的定製。 .. automodule:: torch.distributed.checkpoint.stateful

class torch.distributed.checkpoint.stateful.Stateful(*args, **kwargs)[source]

用於可進行檢查點和恢復的物件的 Stateful 協議。

load_state_dict(state_dict)[source]

從提供的 state_dict 中恢復物件的狀態。

引數

state_dict (Dict[str, Any]) – 要從中恢復的 state dict

state_dict()[source]

物件應將其 state_dict 表示形式作為字典返回。該函式的輸出將被檢查點,並在以後在 load_state_dict() 中恢復。

警告

由於恢復檢查點的就地性質,此函式在 torch.distributed.checkpoint.load 期間也會被呼叫。

返回值

物件的 state dict

返回型別

Dict

示例 展示瞭如何使用 Pytorch Distributed Checkpoint 儲存 FSDP 模型。

以下型別定義了檢查點期間使用的 IO 介面

class torch.distributed.checkpoint.StorageReader[source]

load_state_dict 用於從儲存讀取資料的介面。

一個 StorageReader 例項在分散式檢查點中同時充當協調器和跟隨者。作為初始化的一部分,每個例項都會被告知其角色。

子類應該預期 load_state_dict 的以下呼叫順序

  1. (所有排名) 如果使用者傳遞有效的 checkpoint_id,則設定 checkpoint_id。

  2. (所有排名) read_metadata()

  3. (所有排名) set_up_storage_reader()

  4. (所有排名) prepare_local_plan()

  5. (協調器) prepare_global_plan()

  6. (所有排名) read_data()

abstract prepare_global_plan(plans)[source]

執行儲存載入的集中式規劃。

此方法僅在協調器例項上呼叫。

雖然此方法可以生成完全不同的計劃,但首選方法是在 LoadPlan::storage_data 中儲存特定於儲存的資料。

引數

plans (List[LoadPlan]) – LoadPlan 例項的列表,每個排名一個。

返回值

儲存全域性規劃後,經過轉換的 LoadPlan 的列表

返回型別

List[LoadPlan]

abstract prepare_local_plan(plan)[source]

執行特定於儲存的本地規劃。

雖然此方法可以生成完全不同的計劃,但推薦的方法是在 LoadPlan::storage_data 中儲存特定於儲存的資料。

引數

plan (LoadPlan) – 使用中的 LoadPlan 中的本地計劃。

返回值

儲存本地規劃後,經過轉換的 LoadPlan

返回型別

LoadPlan

abstract read_data(plan, planner)[source]

使用 plannerplan 中讀取所有專案以解析資料。

子類應呼叫 LoadPlanner::load_bytes 將 BytesIO 物件反序列化到正確的位置。

子類應呼叫 LoadPlanner::resolve_tensor 以獲取對應載入資料的張量的訪問許可權。

StorageLayer 負責正確排程所需的任何跨裝置複製操作。

引數
  • plan (LoadPlan) – 要執行的本地計劃

  • planner (LoadPlanner) – 用於解析專案的規劃器物件。

返回值

所有讀取完成後完成的 future。

返回型別

Future[None]

abstract read_metadata()[source]

讀取檢查點元資料。

返回值

與正在載入的檢查點關聯的元資料物件。

返回型別

Metadata

abstract reset(checkpoint_id=None)[source]

呼叫以指示將要進行全新的檢查點讀取。如果使用者為該檢查點讀取設定了 checkpoint_id,則可能會存在 checkpoint_id。checkpoint_id 的含義取決於儲存。它可以是資料夾/檔案的路徑,也可以是鍵值儲存的鍵。(預設值: None)

引數

checkpoint_id (Union[str, os.PathLike, None]) – 該檢查點例項的 ID。checkpoint_id 的含義取決於儲存。它可以是資料夾的路徑或檔案的路徑。如果儲存更像鍵值儲存,它也可以是鍵。

abstract set_up_storage_reader(metadata, is_coordinator)[source]

初始化此例項。

引數
  • metadata (Metadata) – 要使用的元資料模式。

  • is_coordinator (bool) – 此例項是否負責協調檢查點。

abstract classmethod validate_checkpoint_id(checkpoint_id)[source]

檢查給定的 checkpoint_id 是否受 stroage 支援。這使我們能夠啟用自動儲存選擇。

返回型別

bool

class torch.distributed.checkpoint.StorageWriter[source]

save_state_dict 用於寫入儲存的介面。

一個 StorageWriter 例項在分散式檢查點中同時充當協調器和跟隨者。作為初始化的一部分,每個例項都會被告知其角色。

子類應該預期以下呼叫順序。

  1. (所有排名) 如果使用者傳遞有效的 checkpoint_id,則設定 checkpoint_id。

  2. (所有排名) set_up_storage_writer()

  3. (所有排名) prepare_local_plan()

  4. (協調器) prepare_global_plan()

  5. (所有排名) write_data()

  6. (協調器) finish()

abstract finish(metadata, results)[source]

寫入元資料並將當前檢查點標記為成功。

用於序列化 metadata 的實際格式/模式是實現細節。唯一的要求是它可以恢復到相同的物件圖。

引數
  • metadata (Metadata) – 新檢查點的元資料

  • results (List[List[WriteResult]]) – 所有排名提供的 WriteResults 列表。

返回值

返回型別

abstract prepare_global_plan(plans)[source]

執行儲存的集中式規劃。

此方法僅在協調器例項上呼叫。

雖然此方法可以生成完全不同的計劃,但首選方法是在 SavePlan::storage_data 中儲存特定於儲存的資料。

引數

plans (List[SavePlan]) – 每個程序的 SavePlan 例項列表。

返回值

儲存全域性規劃後,經過轉換的 SavePlan 列表。

返回型別

List[SavePlan]

abstract prepare_local_plan(plan)[source]

執行特定於儲存的本地規劃。

雖然此方法可以生成完全不同的計劃,但推薦的方法是在 SavePlan::storage_data 中儲存儲存特定資料。

引數

plan (SavePlan) – 當前使用的 SavePlanner 中的本地計劃。

返回值

儲存本地規劃後,經過轉換的 SavePlan

返回型別

SavePlan

abstract reset(checkpoint_id=None)[source]

呼叫此方法表示即將發生新的檢查點寫入。如果使用者為此檢查點寫入設定了 checkpoint_id,則可能存在 checkpoint_id。checkpoint_id 的含義取決於儲存。它可以是資料夾/檔案的路徑或鍵值儲存的鍵。

引數

checkpoint_id (Union[str, os.PathLike, None]) – 此檢查點例項的 ID。checkpoint_id 的含義取決於儲存。它可以是資料夾或檔案的路徑。如果儲存是鍵值儲存,它也可以是鍵。(預設值:None)

abstract set_up_storage_writer(is_coordinator)[source]

初始化此例項。

引數

is_coordinator (bool) – 此例項是否負責協調檢查點。

storage_meta()[source]

返回儲存特定的元資料。這用於在檢查點中儲存其他資訊,這些資訊可能有助於提供請求級可觀察性。StorageMeta 在儲存呼叫期間傳遞給 SavePlanner。預設情況下返回 None。

TODO:提供示例

返回型別

Optional[StorageMeta]

abstract classmethod validate_checkpoint_id(checkpoint_id)[source]

檢查給定的 checkpoint_id 是否受 stroage 支援。這使我們能夠啟用自動儲存選擇。

返回型別

bool

abstract write_data(plan, planner)[source]

使用 planner 寫入 plan 中的所有專案以解析資料。

子類應該對計劃中的每個專案呼叫 SavePlanner::resolve_data 以訪問要寫入的底層物件。

子類應該延遲呼叫 resolve_data,因為它可能會分配記憶體。在張量的情況下,做出以下假設

  • 它們可能位於任何裝置上,包括與 WriteItem::tensor_data 上的裝置不匹配的裝置。

  • 它們可能是檢視,也可能不是連續的。只需要儲存投影。

引數
  • plan (SavePlan) – 要執行的儲存計劃。

  • planner (SavePlanner) – 用於解析專案到資料的計劃程式物件。

返回值

完成到 WriteResult 列表的未來。

返回型別

Future[List[WriteResult]]

以下型別定義了檢查點期間使用的計劃程式介面

class torch.distributed.checkpoint.LoadPlanner[source]

定義了 load_state_dict 用於規劃載入過程的協議的抽象類。

LoadPlanner 是有狀態的物件,可用於自定義整個載入過程。

LoadPlanner 充當 state_dict 的訪問代理,因此對它的任何轉換都將對整個程序可見。

在 load_state_dict 期間,計劃程式子類可以預期以下呼叫順序

  1. set_up_planner - 在所有程序上呼叫。

    表示載入檢查點開始。

  2. create_local_plan - 在所有程序上呼叫。

    處理 state_dict 並生成一個將傳送進行全域性規劃的 LoadPlan

  3. create_global_plan - 僅在協調器程序上呼叫。

    接收來自所有程序的 LoadPlan 並做出任何全域性決策。

  4. load_bytes - 在每個程序上呼叫多次

    這將對 state_dict 中的每個非張量值呼叫一次。

  5. resolve_tensor 和 commit_tensor - 在每個程序上呼叫多次

    它們將對 state_dict 中的每個張量值成對呼叫。

建議使用者擴充套件 DefaultLoadPlanner 而不是直接擴充套件此介面,因為大多數更改可以透過單個方法的更改來表達。

有兩種常見的擴充套件模式

重寫 state_dict。這是擴充套件載入過程的最簡單方法,因為它不需要了解 LoadPlan 如何工作的複雜性。我們需要保留對原始 state_dict 的引用,因為載入是在原地進行的,因此我們需要能夠在原地執行它

>>> class RenamePlanner(DefaultLoadPlanner):
>>>     def set_up_planner(
>>>         self,
>>>         state_dict: STATE_DICT_TYPE,
>>>         metadata: Metadata,
>>>         is_coordinator: bool,
>>>     ) -> None:
>>>         self.original_state_dict = state_dict
>>>         state_dict = {"foo_" + k: v for k, v in state_dict.items()}
>>>
>>>         if self.flatten_sharded_tensors:
>>>             state_dict = _flatten_sharded_tensors(state_dict)
>>>
>>>         if self.flatten_state_dict:
>>>             state_dict, self.mappings = flatten_state_dict(state_dict)
>>>
>>>         self.state_dict = state_dict
>>>         self.metadata = metadata
>>>         self.is_coordinator = is_coordinator
>>>
>>>     def load_bytes(self, read_item, value):
>>>         # Remove the "foo_" prefix
>>>         self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value)

修改 resolve_tensor 和 commit_tensor 以處理載入時轉換。

>>> class MetaModelMaterialize(DefaultSavePlanner):
>>>     def resolve_tensor(self, read_item):
>>>         tensor = super().resolve_tensor(read_item)
>>>         return torch.empty_like(tensor, device="cpu")
>>>
>>>     def commit_tensor(self, read_item, tensor):
>>>         self.state_dict[read_item.dest_index.fqn] = tensor
abstract commit_tensor(read_item, tensor)[source]

在 StorageReader 完成將資料載入到 tensor 中後呼叫一次。

提供的張量與呼叫 resolve_tensor 返回的張量相同。僅當此 LoadPlanner 需要在將 tensor 複製回 state_dict 中的張量之前對其進行後處理時,才需要此方法。

張量的內容將遵循其裝置同步模型。

abstract create_global_plan(global_plan)[source]

計算全域性載入計劃並返回每個程序的計劃。

. 注意:這僅在協調器程序上呼叫

返回型別

List[LoadPlan]

abstract create_local_plan()[source]

根據 state_dict 和 set_up_planner 提供的元資料建立 LoadPlan。

. 注意:這在每個程序上呼叫。

返回型別

LoadPlan

abstract finish_plan(central_plan)[source]

接受來自協調器的計劃並返回最終的 LoadPlan。

返回型別

LoadPlan

abstract load_bytes(read_item, value)[source]

載入由 read_item``和 ``value 描述的專案。

此方法預計將修改底層 state_dict 的位置。

value 的內容由用於生成正在載入的檢查點的 SavePlanner 定義。

resolve_bytes(read_item)[source]

返回 StorageReader 用於載入 read_item 的 BytesIO。

BytesIO 應該與底層 state_dict 中的 BytesIO 具有別名,因為 StorageReader 將替換其內容。

返回型別

BytesIO

abstract resolve_tensor(read_item)[source]

返回由 read_item 描述的張量,以便 StorageReader 用於載入 read_item

張量應該與基礎 state_dict 中的一個張量建立別名,因為 StorageReader 將替換其內容。 如果由於任何原因,這不可行,規劃器可以使用 commit_tensor 方法將資料複製回 state_dict 中的張量。

返回型別

張量

abstract set_up_planner(state_dict, metadata=None, is_coordinator=False)[source]

初始化此例項以將資料載入到 state_dict 中。

. 注意:這在每個程序上呼叫。

class torch.distributed.checkpoint.LoadPlan(items: List[torch.distributed.checkpoint.planner.ReadItem], storage_data: Any = None, planner_data: Any = None)[source]
class torch.distributed.checkpoint.ReadItem(type: torch.distributed.checkpoint.planner.LoadItemType, dest_index: torch.distributed.checkpoint.metadata.MetadataIndex, dest_offsets: torch.Size, storage_index: torch.distributed.checkpoint.metadata.MetadataIndex, storage_offsets: torch.Size, lengths: torch.Size)[source]
class torch.distributed.checkpoint.SavePlanner[source]

抽象類,定義 save_state_dict 用於規劃儲存過程的協議。

SavePlanner 是有狀態的物件,可用於自定義整個儲存過程。

SavePlanner 充當 state_dict 的訪問代理,因此對它的任何轉換都將對整個過程可見。

規劃器子類可以預期在 save_state_dict 期間以下列順序呼叫:

  1. set_up_planner - 在所有程序上呼叫。

    發出檢查點儲存開始的訊號。

  2. create_local_plan - 在所有程序上呼叫。

    處理 state_dict 並生成一個 SavePlan,該計劃將被髮送用於全域性規劃。

  3. create_global_plan - 僅在協調器程序上呼叫。

    獲取所有等級的 SavePlan,並做出任何全域性決策。

  4. finish_plan - 在所有等級上呼叫。

    這使每個等級有機會適應全域性規劃決策。

  5. resolve_data - 在每個等級上呼叫多次。

    state_dict 上查詢儲存層要寫入的值。

建議使用者擴充套件 DefaultSavePlanner 而不是直接擴充套件此介面,因為大多數更改可以透過單個方法中的更改來表達。

擴充套件通常有 3 種模式:

重寫 state_dict。 這是擴充套件儲存過程的最簡單方法,因為它不需要了解 SavePlan 的工作原理。

>>> class RenamePlanner(DefaultSavePlanner):
>>>     def set_up_planner(
>>>         self,
>>>         state_dict: STATE_DICT_TYPE,
>>>         storage_meta: Optional[StorageMeta],
>>>         is_coordinator: bool,
>>>     ) -> None:
>>>         # prefix all keys with `foo_``
>>>         super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator)

同時修改本地計劃和查詢。 這在需要精確控制如何持久化資料時很有用。

>>> class FP16Planner(DefaultSavePlanner):
>>>     def create_local_plan(self):
>>>         plan = super().create_local_plan()
>>>         for p in plan:
>>>             if p.tensor_data is not None:
>>>                 p.tensor_data.properties.dtype = torch.float16
>>>         return plan
>>>
>>>     def resolve_data(self, write_item):
>>>         item = super().resolve_data(write_item)
>>>         return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)

使用全域性規劃步驟做出不能由每個等級單獨做出的集中決策。

>>> from itertools import islice
>>> from dataclasses import replace
>>> class DDPLoadBalancingPlanner(DefaultSavePlanner):
>>>     # This uses the default local plan behavior of having all non-sharded writes in rank 0
>>>     # This sample doesn't handle ShardedTensors
>>>     def create_global_plan(self, all_plans):
>>>         def chunk(it, size):
>>>             it = iter(it)
>>>         return list(iter(lambda: tuple(islice(it, size)), ()))
>>>         all_plans = [
>>>             replace(plan, items=items) for plan, items in
>>>                 zip(all_plans, chunk(all_plans[0].items, len(all_plans)))
>>>         ]
>>>         return super().create_global_plan(all_plans)

最後,一些規劃器需要在檢查點中儲存額外的元資料,這是透過讓每個等級在其本地計劃中貢獻其資料項並讓全域性規劃器聚合它們來完成的。

>>> class SaveExtraDataPlanner(DefaultSavePlanner):
>>>     def create_local_plan(self) -> SavePlan:
>>>         plan = super().create_local_plan()
>>>         return replace(plan, planner_data="per-rank-data")
>>>
>>>     def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]:
>>>         global_plan, metadata = super().create_global_plan(all_plans)
>>>         merged_data = [p.planner_data for p in global_plan]
>>>         metadata = replace(metadata, planner_data=merged_data)
>>>         return global_plan, metadata
abstract create_global_plan(all_plans)[source]

計算全域性檢查點計劃並返回每個等級的本地計劃。

僅在協調器等級上呼叫。

返回型別

Tuple[List[SavePlan], Metadata]

abstract create_local_plan()[source]

計算當前等級的儲存計劃。

這將被聚合並傳遞給 create_global_plan。 規劃器特定資料可以透過 SavePlan::planner_data 傳遞。

在所有等級上呼叫。

返回型別

SavePlan

abstract finish_plan(new_plan)[source]

合併由 create_local_plan 建立的計劃和 create_global_plan 的結果。

在所有等級上呼叫。

返回型別

SavePlan

abstract resolve_data(write_item)[source]

轉換和準備 write_item(來自 state_dict)以供儲存,確保冪等性和執行緒安全性。

state_dict 中查詢與 write_item 關聯的物件,並在儲存層使用它之前應用任何轉換(如序列化)。

在每個等級上呼叫多次,每個等級至少呼叫一次(在最終 SavePlan 中的每個 WriteItem 上呼叫一次)。

此方法應該是冪等且執行緒安全的。 StorageWriter 實現可以根據需要儘可能頻繁地呼叫它。

任何分配記憶體的轉換都應該在呼叫此方法時延遲完成,以減少檢查點所需的峰值記憶體。

返回張量時,它們可以位於任何裝置或格式上,它們也可以是檢視。 儲存層負責弄清楚如何儲存它們。

返回型別

Union[Tensor, BytesIO]

abstract set_up_planner(state_dict, storage_meta=None, is_coordinator=False)[source]

初始化此規劃器以儲存 state_dict

實現應該儲存這些值,因為它們在儲存過程中不會被提供。

在所有等級上呼叫。

class torch.distributed.checkpoint.SavePlan(items: List[torch.distributed.checkpoint.planner.WriteItem], storage_data: Any = None, planner_data: Any = None)[source]
class torch.distributed.checkpoint.planner.WriteItem(index, type, tensor_data=None)[source]

用於儲存儲存資訊的 Dataclass。

tensor_storage_size()[source]

計算底層張量的儲存大小,如果這不是張量寫入,則返回 None。

返回值

可選 [int] 儲存大小,以位元組為單位,如果存在底層張量。

返回型別

Optional[int]

我們提供基於檔案系統的儲存層

class torch.distributed.checkpoint.FileSystemReader(path)[source]
property checkpoint_id: Union[str, PathLike]

返回用於儲存檢查點的 checkpoint_id。

class torch.distributed.checkpoint.FileSystemWriter(path, single_file_per_rank=True, sync_files=True, thread_count=1, per_thread_copy_ahead=10000000, cache_staged_state_dict=False, overwrite=True)[source]

使用檔案 IO 的 StorageWriter 的基本實現。

此實現做出以下假設和簡化

  • 檢查點路徑為空或不存在的目錄。

  • 檔案建立是原子的

檢查點包含每個寫入請求一個檔案,以及一個帶有序列化元資料的 .metadata 檔案。

stage(state_dict)[source]

AsyncStager.stage 的重寫

返回型別

Dict[str, Union[StatefulT, Any]]

我們提供 LoadPlannerSavePlanner 的預設實現,可以處理所有 torch.distributed 結構,例如 FSDP、DDP、ShardedTensor 和 DistributedTensor。

class torch.distributed.checkpoint.DefaultSavePlanner(flatten_state_dict=True, flatten_sharded_tensors=True, dedup_replicated_tensors=None, dedup_save_to_lowest_rank=False)[source]
lookup_object(index)[source]

從 planner 介面擴充套件,使其易於擴充套件預設 planner。

返回型別

任何

transform_object(write_item, object)[source]

從 planner 介面擴充套件,使其易於擴充套件預設 planner。

class torch.distributed.checkpoint.DefaultLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True, allow_partial_load=False)[source]

在 LoadPlanner 之上新增多個功能的 DefaultLoadPlanner。

特別是它添加了以下內容

flatten_state_dict: 處理具有巢狀字典的 state_dict flatten_sharded_tensors: 對於 2D 並行模式下的 FSDP allow_partial_load: 如果為 False,則如果 state_dict 中存在鍵,但在檢查點中不存在,則會引發執行時錯誤。

lookup_tensor(index)[source]

從 planner 介面擴充套件,使其易於擴充套件預設 planner。

返回型別

張量

transform_tensor(read_item, tensor)[source]

從 planner 介面擴充套件,使其易於擴充套件預設 planner。

由於遺留的設計決策,FSDPDDP 的狀態字典可能具有不同的鍵或完全限定名稱(例如,layer1.weight),即使原始的非並行模型相同。此外,FSDP 提供了各種型別的模型狀態字典,例如完整狀態字典和分片狀態字典。此外,最佳化器狀態字典使用引數 ID 而不是完全限定名稱來標識引數,這可能在使用並行性(例如,管道並行性)時會導致問題。

為了解決這些挑戰,我們為使用者提供了一組 API,以便輕鬆管理 state_dict。get_model_state_dict 返回一個模型狀態字典,其鍵與非並行模型狀態字典返回的鍵一致。類似地,get_optimizer_state_dict 提供最佳化器狀態字典,其鍵在應用的所有並行性中保持一致。為了實現這種一致性,get_optimizer_state_dict 將引數 ID 轉換為與非並行模型狀態字典中找到的完全限定名稱相同的名稱。

請注意,這些 API 返回的結果可以直接與 torch.distributed.checkpoint.save()torch.distributed.checkpoint.load() 方法一起使用,無需任何其他轉換。

請注意,此功能是實驗性的,API 簽名將來可能會更改。

torch.distributed.checkpoint.state_dict.get_state_dict(model, optimizers, *, submodules=None, options=None)[source]

返回模型 state_dict 和最佳化器 state_dict。

get_state_dict 可以處理由 PyTorch FSDP/fully_shard、DDP/replicate、tensor_parallel/parallelize_module 以及這些並行性的任何組合並行的任何模組。 get_state_dict 的主要功能是:1.) 返回可以用不同數量的訓練器和/或不同並行性重新分片的模型和最佳化器 state_dict。 2.) 隱藏特定於並行性的 state_dict API。使用者無需呼叫這些 API。 3.) 對結果 state_dict 進行健全性檢查。

結果狀態字典的鍵是規範的 FQN(完全限定名稱)。規範的 FQN 指的是基於引數在 nn.Module 層次結構中的位置的 FQN。更具體地說,引數的規範 FQN 是當模組未由任何並行性分佈時由 module.named_parameters()module.named_buffers() 返回的 FQN。

由於最佳化器在內部使用引數 ID 來表示引數,因此在呼叫此 API 時,將從引數 ID 轉換為規範的 FQN。

示例

>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> from torch.distributed.checkpoint.state_dict import get_state_dict
>>> fsdp_model = FSDP(copy.deepcopy(model))
>>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> ddp_model = DDP(copy.deepcopy(model))
>>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
>>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)
>>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(),
>>> # the asserts will fail.
>>> assert ddp_state_dict == fsdp_state_dict
>>> assert ddp_optim_state == fsdp_optim_state_dict
引數
  • get_state_dict 也可以處理未並行的模組。在這種情況下,get_state_dict 只執行一個功能——將最佳化器引數 ID 轉換為規範的 FQN。

  • 最佳化器 (Union[None, 最佳化器, Iterable[最佳化器]]) – 用於最佳化 model 的最佳化器。

  • 子模組 (已棄用) – Optional[Set[nn.Module]]: 僅返回屬於子模組的模型引數。

  • 選項 (StateDictOptions) – 用於控制模型狀態字典和最佳化器狀態字典返回值的選項。有關詳細資訊,請參閱 StateDictOptions

返回值

Tuple,其中包含模型狀態字典和最佳化器狀態字典。

返回型別

Tuple[Dict[str, ValueType], OptimizerStateType]

torch.distributed.checkpoint.state_dict.get_model_state_dict(model, *, submodules=None, options=None)[source]

返回 model 的模型狀態字典。

有關詳細資訊,請參閱 get_state_dict

引數
  • get_state_dict 也可以處理未並行的模組。在這種情況下,get_state_dict 只執行一個功能——將最佳化器引數 ID 轉換為規範的 FQN。

  • 子模組 (已棄用) – Optional[Set[nn.Module]]: 僅返回屬於子模組的模型引數。

  • 選項 (StateDictOptions) – 用於控制模型狀態字典和最佳化器狀態字典返回值的選項。有關詳細資訊,請參閱 StateDictOptions

返回值

model 的狀態字典。

返回型別

Dict[str, ValueType]

torch.distributed.checkpoint.state_dict.get_optimizer_state_dict(model, optimizers, *, submodules=None, options=None)[source]

返回最佳化器的組合狀態字典。

有關詳細資訊,請參閱 get_state_dict

引數
  • get_state_dict 也可以處理未並行的模組。在這種情況下,get_state_dict 只執行一個功能——將最佳化器引數 ID 轉換為規範的 FQN。

  • 最佳化器 (Union[None, 最佳化器, Iterable[最佳化器]]) – 用於最佳化 model 的最佳化器。

  • 子模組 (已棄用) – Optional[Set[nn.Module]]: 僅返回屬於子模組的模型引數。

  • 選項 (StateDictOptions) – 用於控制模型狀態字典和最佳化器狀態字典返回值的選項。有關詳細資訊,請參閱 StateDictOptions

返回值

optimizers 的狀態字典。

返回型別

OptimizerStateType

torch.distributed.checkpoint.state_dict.set_state_dict(model, optimizers, *, model_state_dict, optim_state_dict, options=None)[source]

載入模型狀態字典和最佳化器狀態字典。

get_state_dict 的對應函式,用於將狀態字典設定到模型和最佳化器。給定的 model_state_dictoptim_state_dict 不必由 get_state_dict 返回,但必須滿足以下要求:1) 所有 FQN 都是 get_state_dict 中定義的規範 FQN,2) 如果張量被分片,則它必須是 ShardedTensor 或 DTensor,3) 最佳化器狀態字典不能包含引數 ID;鍵應該是規範的 FQN。

引數
  • get_state_dict 也可以處理未並行的模組。在這種情況下,get_state_dict 只執行一個功能——將最佳化器引數 ID 轉換為規範的 FQN。

  • 最佳化器 (Union[最佳化器, Iterable[最佳化器]]) – 用於最佳化 model 的最佳化器。

  • model_state_dict (Dict[str, ValueType]) – (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]): 要載入的模型狀態字典。如果 model_state_dict 的鍵是 nn.Module,則鍵是 model 的子模組,而值應該是子模組的狀態字典。載入狀態字典時,將子模組的字首附加到狀態字典。

  • optim_state_dict (OptimizerStateType) – OptimizerStateType: 要載入的最佳化器狀態字典。

  • 選項 (StateDictOptions) – 用於控制模型狀態字典和最佳化器狀態字典載入方式的選項。有關詳細資訊,請參閱 StateDictOptions

返回值

  • missing_keys 是一個包含模型狀態字典中缺失鍵的 str 列表。

  • unexpected_keys 是一個包含模型狀態字典中意外部索引鍵的 str 列表。

返回型別

NamedTuple,其中包含 missing_keysunexpected_keys 欄位

torch.distributed.checkpoint.state_dict.set_model_state_dict(model, model_state_dict, *, options=None)[source]

載入模型狀態字典。

get_model_state_dict 的對應函式,用於將狀態字典設定到模型。有關詳細資訊,請參閱 set_state_dict

引數
  • get_state_dict 也可以處理未並行的模組。在這種情況下,get_state_dict 只執行一個功能——將最佳化器引數 ID 轉換為規範的 FQN。

  • model_state_dict (Dict[str, ValueType]) – (Dict[str, ValueType]): 要載入的模型狀態字典。如果 model_state_dict 的鍵是 nn.Module,則鍵是 model 的子模組,而值應該是子模組的狀態字典。載入狀態字典時,將子模組的字首附加到狀態字典。

  • 選項 (StateDictOptions) – 用於控制模型狀態字典和最佳化器狀態字典載入方式的選項。有關詳細資訊,請參閱 StateDictOptions

返回值

  • missing_keys 是一個包含缺失鍵的 str 列表

  • unexpected_keys 是一個包含意外部索引鍵的 str 列表

返回型別

NamedTuple,其中包含 missing_keysunexpected_keys 欄位

torch.distributed.checkpoint.state_dict.set_optimizer_state_dict(model, optimizers, optim_state_dict, *, options=None)[source]

載入最佳化器狀態字典。

get_optimizer_state_dict 的對應函式,用於將狀態字典設定到最佳化器。有關詳細資訊,請參閱 set_state_dict

引數
  • get_state_dict 也可以處理未並行的模組。在這種情況下,get_state_dict 只執行一個功能——將最佳化器引數 ID 轉換為規範的 FQN。

  • 最佳化器 (Union[最佳化器, Iterable[最佳化器]]) – 用於最佳化 model 的最佳化器。

  • optim_state_dict (OptimizerStateType) – OptimizerStateType: 要載入的最佳化器狀態字典。

  • 選項 (StateDictOptions) – 用於控制模型狀態字典和最佳化器狀態字典載入方式的選項。有關詳細資訊,請參閱 StateDictOptions

返回值

返回型別

class torch.distributed.checkpoint.state_dict.StateDictOptions(full_state_dict=False, cpu_offload=False, ignore_frozen_params=False, keep_submodule_prefixes=True, strict=True, broadcast_from_rank0=False, flatten_optimizer_state_dict=False)[source]

此資料類指定 get_state_dict/set_state_dict 的工作方式。

  • full_state_dict: 如果將其設定為 True,則返回的狀態字典中的所有張量將被收集。返回的狀態字典中不會包含任何 ShardedTensor 和 DTensor。

  • cpu_offload: 將所有張量解除安裝到 CPU。為了防止 CPU 記憶體不足,如果 full_state_dict 也為 True,則只有 rank0 將獲得狀態字典,而所有其他 rank 將獲得空狀態字典。

  • ignore_frozen_params: 如果值為 True,則返回的狀態字典將不包含任何凍結引數 - requires_grad 為 False。預設值為 False。

  • keep_submodule_prefixes (已棄用): 當 submodules 不為 None 時,此選項指示是否保留 state_dict 鍵中的子模組字首。例如,如果子模組為 module.pretrain 且引數的完整 FQN 為 pretrain.layer1.weight。當此選項為 True 時,引數在返回的 state_dict 中的鍵將為 pretrain.layer1.weight。如果選項為 False,則鍵將為 layer1.weight。請注意,如果 keep_submodule_prefixes 為 False,則可能存在衝突的 FQN,因此 submodules 中應只有一個子模組。

  • strict: 當 set_state_dict 呼叫 model.load_state_dict() 時的 strict 選項。

  • broadcast_from_rank0: 當此選項為 True 時,rank0 應接收一個

    完整的 state_dict 並將 state_dict/optim_state_dict 中的張量逐個廣播到其他 rank。其他 rank 將接收這些張量並根據模型和最佳化器中的本地分片進行分片。full_state_dict 在使用此選項時必須設定為 True。此選專案前僅支援 DTensor,不支援舊的 ShardedTensor。

對於習慣於使用和共享 torch.save 格式的模型的使用者,提供了以下方法,這些方法提供了在格式之間轉換的離線實用程式。

torch.distributed.checkpoint.format_utils.dcp_to_torch_save(dcp_checkpoint_dir, torch_save_path)[source]

給定包含 DCP 檢查點的目錄,此函式將其轉換為 Torch 儲存檔案。

引數
  • dcp_checkpoint_dir (Union[str, PathLike]) – 包含 DCP 檢查點的目錄。

  • torch_save_path (Union[str, PathLike]) – 用於儲存轉換後的 Torch 儲存檔案的檔名。

警告

為了避免 OOM,建議僅在一個 rank 上執行此函式。

torch.distributed.checkpoint.format_utils.torch_save_to_dcp(torch_save_path, dcp_checkpoint_dir)[source]

給定 torch 儲存檔案的路徑,將其轉換為 DCP 檢查點。

引數
  • torch_save_path (Union[str, PathLike]) – Torch 儲存檔案的名稱。

  • dcp_checkpoint_dir (Union[str, PathLike]) – 用於儲存 DCP 檢查點的目錄。

警告

為了避免 OOM,建議僅在一個 rank 上執行此函式。

以下類也可以用於從 torch.save 格式線上載入和重新分片模型。

class torch.distributed.checkpoint.format_utils.BroadcastingTorchSaveReader(checkpoint_id=None, coordinator_rank=0)[source]

用於讀取 Torch Save 檔案的 StorageReader。此 reader 將在協調器 rank 上讀取整個檢查點,然後將每個張量廣播並分片到所有 rank。

. 注意,旨在與 DynamicMetaLoadPlanner 一起使用

警告

當前實現僅支援載入張量。

>>> sd = {"mode": model}
>>> dcp.load(
>>>    sd,
>>>    storage_reader=BroadcastingTorchSaveReader(),
>>>    planner=DynamicMetaLoadPlanner(),
>>>    checkpoint_id="path_to_model.pt"
>>> )
prepare_global_plan(global_plan)[source]

StorageReader 方法的實現

返回型別

List[LoadPlan]

prepare_local_plan(plan)[source]

StorageReader 方法的實現

返回型別

LoadPlan

read_data(plan, planner)[source]

在協調器 rank 上讀取 torch save 資料,並在之後廣播,這會產生通訊成本,但避免了在每個 rank 上載入整個檢查點,有望防止 OOM 問題

返回型別

Future[None]

read_metadata()[source]

擴充套件預設 StorageReader 以支援構建元資料檔案

返回型別

Metadata

reset(checkpoint_id=None)[source]

StorageReader 方法的實現

set_up_storage_reader(metadata, is_coordinator)[source]

StorageReader 方法的實現

classmethod validate_checkpoint_id(checkpoint_id)[source]

StorageReader 方法的實現

返回型別

bool

class torch.distributed.checkpoint.format_utils.DynamicMetaLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True, allow_partial_load=False)[source]

DefaultLoadPlanner 的擴充套件,它根據傳入的 state dict 建立一個新的 Metadata 物件,避免了從磁碟讀取元資料的需要。這在讀取沒有元資料檔案的格式(如 Torch Save 檔案)時非常有用。

. 注意,旨在與 BroadcastingTorchSaveReader 一起使用

警告

當前實現僅支援載入張量。

>>> sd = {"mode": model}
>>> dcp.load(
>>>    sd,
>>>    storage_reader=BroadcastingTorchSaveReader(),
>>>    planner=DynamicMetaLoadPlanner(),
>>>    checkpoint_id="path_to_model.pt"
>>> )
set_up_planner(state_dict, metadata=None, is_coordinator=False)[source]

規劃器的設定,透過從 state dict 建立 Metadata 物件來擴充套件預設行為

以下實驗介面可用於在生產環境中提高可觀察性

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題的解答

檢視資源