torch.distributed.tensor¶
注意
torch.distributed.tensor 目前處於 alpha 階段,正在開發中。我們致力於為文件中列出的大多數 API 提供向後相容性,但在必要時可能會進行 API 更改。
PyTorch DTensor (分散式張量)¶
PyTorch DTensor 提供了簡單靈活的張量分片原語,可透明地處理分散式邏輯,包括跨裝置/主機的分片儲存、運算元計算和集合通訊。DTensor 可用於構建不同的並行解決方案,並在處理多維分片時支援分片 state_dict 表示。
請參閱基於 DTensor 構建的 PyTorch 原生並行解決方案示例:
DTensor 遵循 SPMD(單程式多資料)程式設計模型,使使用者能夠像編寫具有相同收斂屬性的單裝置程式一樣編寫分散式程式。它透過指定 DeviceMesh 和 Placement 提供統一的張量分片佈局(DTensor Layout)。
DeviceMesh使用 n 維陣列表示叢集的裝置拓撲和通訊器。Placement描述了邏輯張量在DeviceMesh上的分片佈局。DTensor 支援三種類型的 Placement:Shard、Replicate和Partial。
DTensor 類 API¶
DTensor 是 torch.Tensor 的子類。這意味著一旦建立 DTensor,就可以以與 torch.Tensor 非常相似的方式使用它,包括執行不同型別的 PyTorch 運算元,就像在單個裝置上執行一樣,從而實現 PyTorch 運算元的正確分散式計算。
除了現有的 torch.Tensor 方法外,它還提供了一組額外的方法來與 torch.Tensor 互動、將 DTensor Layout redistribute 到新的 DTensor、獲取所有裝置上的完整張量內容等。
- class torch.distributed.tensor.DTensor(local_tensor, spec, *, requires_grad)¶
DTensor(分散式張量)是torch.Tensor的子類,提供類似單裝置的抽象,用於對多裝置torch.Tensor進行程式設計。它透過DeviceMesh和以下型別的Placement來描述分散式張量分片佈局(DTensor Layout):Shard:張量在DeviceMesh維度的裝置上按張量維度dim進行分片Replicate:張量在DeviceMesh維度的裝置上進行復制Partial:張量在DeviceMesh維度的裝置上等待歸約
呼叫 PyTorch 運算元時,
DTensor會覆蓋 PyTorch 運算元以執行分片計算並在必要時發出通訊。除了運算元計算之外,DTensor會正確轉換或傳播 Placement(DTensor Layout)(基於運算元本身的語義)並生成新的DTensor輸出。為了確保呼叫 PyTorch 運算元時
DTensor分片計算的數值正確性,DTensor要求運算元的每個張量引數都是 DTensor。注意
不建議在此處直接使用 Tensor 子類建構函式建立
DTensor(因為它無法正確處理自動微分,因此不是公共 API)。請參閱 create_dtensor 部分以瞭解如何建立DTensor。- 返回型別
- __create_chunk_list__()[source][source]¶
返回一個 ChunkStorageMetadata 列表,這是一個描述當前 rank 上本地分片/複製品的大小/偏移量的資料類。對於 DTensor,每個 rank 將擁有一個本地分片/複製品,因此返回列表通常只有一個元素。
這個 dunder 方法主要用於分散式檢查點目的。
- 返回值
一個 List[
ChunkStorageMetadata] 物件,表示當前 rank 上的分片大小/偏移量。
- static from_local(local_tensor, device_mesh=None, placements=None, *, run_check=False, shape=None, stride=None)[source][source]¶
根據指定的
device_mesh和placements,從每個 rank 上的本地torch.Tensor建立一個DTensor。- 引數
local_tensor (torch.Tensor) – 每個 rank 上的本地 torch.Tensor。
device_mesh (
DeviceMesh, optional) – 放置張量的 DeviceMesh,如果未指定,則必須在 DeviceMesh 上下文管理器下呼叫,預設值:Noneplacements (List[
Placement], optional) – 描述如何將本地 torch.Tensor 放置在 DeviceMesh 上的 placements,必須與device_mesh.ndim具有相同數量的元素。
- 關鍵字引數
run_check (bool, optional) – 以額外通訊為代價,執行跨 rank 的健全性檢查,檢查每個本地張量的元資訊以確保正確性。如果
placements中有Replicate,Device Mesh 維度上第一個 rank 的資料將被廣播到其他 rank。預設值:Falseshape (torch.Size, optional) – 一個整數列表,指定基於 local_tensor 構建的 DTensor 的大小。請注意,如果
local_tensor的形狀在不同 rank 上不同,則需要提供此引數。如果未提供,shape將假設給定分散式張量在 rank 之間均勻分片來計算。預設值:Nonestride (tuple, optional) – 一個整數列表,指定 DTensor 的步長。如果未提供,
stride將假設給定分散式張量在 rank 之間均勻分片來計算。預設值:None
- 返回值
一個
DTensor物件- 返回型別
注意
當
run_check=False時,使用者有責任確保傳入的本地張量在跨 rank 上是正確的(即對於Shard(dim)Placement 張量被分片,或者對於Replicate()Placement 張量被複制)。如果不是,則建立的 DTensor 的行為是未定義的。注意
from_local是可微分的,建立的 DTensor 物件的 requires_grad 將取決於 local_tensor 是否需要梯度。
- full_tensor(*, grad_placements=None)[source][source]¶
返回此 DTensor 的完整張量。它將執行必要的集合操作,從其 DeviceMesh 中的其他 rank 收集本地張量並將它們拼接在一起。它是以下程式碼的語法糖:
dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()- 關鍵字引數
grad_placements (List[
Placement], optional) – placements 描述了此函式返回的完整張量的任何梯度佈局的未來佈局。full_tensor 將 DTensor 轉換為完整的 torch.Tensor,並且返回的 torch.tensor 可能在程式碼後續部分不會像原始複製的 DTensor layout 一樣使用。此引數是使用者可以提供給自動微分的提示,以防返回張量的梯度佈局與原始複製的 DTensor layout 不匹配。如果未指定,我們將假定完整張量的梯度佈局被複制。- 返回值
一個
torch.Tensor物件,表示此 DTensor 的完整張量。- 返回型別
注意
full_tensor是可微分的。
- redistribute(device_mesh=None, placements=None, *, async_op=False)[source][source]¶
redistribute執行必要的集合操作,將當前 DTensor 從其當前 placements 重新分佈到新的 placements,或者從其當前 DeviceMesh 重新分佈到新的 DeviceMesh。例如,我們可以透過為 DeviceMesh 的每個維度指定 Replicate Placement,將一個 Sharded DTensor 轉換為 Replicated DTensor。當從一個 Device Mesh 維度上的當前 placements 重新分佈到新的 placements 時,我們將執行以下操作,包括通訊集合或區域性操作:
Shard(dim)->Replicate():all_gatherShard(src_dim)->Shard(dst_dim):all_to_allReplicate()->Shard(dim): 區域性分塊(例如torch.chunk)Partial()->Replicate():all_reducePartial()->Shard(dim):reduce_scatter
redistribute會正確地找出針對 1-D 或 N-D DeviceMesh 上建立的 DTensor 的必要重新分佈步驟。- 引數
device_mesh (
DeviceMesh, optional) – 放置 DTensor 的 DeviceMesh。如果未指定,將使用當前 DTensor 的 DeviceMesh。預設值:Noneplacements (List[
Placement], optional) – 描述如何將 DTensor 放置到 DeviceMesh 中的新 placements,必須與device_mesh.ndim具有相同數量的元素。預設值:在所有 Device Mesh 維度上覆制
- 關鍵字引數
async_op (bool, optional) – 是否非同步執行 DTensor 重新分佈操作。預設值:False
- 返回值
一個
DTensor物件- 返回型別
注意
redistribute是可微分的,這意味著使用者無需擔心重新分佈操作的向後公式。注意
redistribute目前僅支援在同一 DeviceMesh 上重新分佈 DTensor。如果您需要將 DTensor 重新分佈到不同的 DeviceMesh,請提交一個 issue。
- to_local(*, grad_placements=None)[source][source]¶
獲取此 DTensor 在其當前 rank 上的本地張量。對於分片(sharding),它返回邏輯張量檢視的本地分片(local shard),對於複製(replication),它返回其當前 rank 上的複製品(replica)。
- 關鍵字引數
grad_placements (List[
Placement], optional) – placements 描述了此函式返回的張量的任何梯度佈局的未來佈局。to_local 將 DTensor 轉換為本地張量,並且返回的本地張量在程式碼後續部分可能不會像原始 DTensor layout 一樣使用。此引數是使用者可以提供給自動微分的提示,以防返回張量的梯度佈局與原始 DTensor layout 不匹配。如果未指定,我們將假定梯度佈局與原始 DTensor 保持相同,並將其用於梯度計算。- 返回值
一個
torch.Tensor或AsyncCollectiveTensor物件。它表示在其當前 rank 上的本地張量。當返回AsyncCollectiveTensor物件時,意味著本地張量尚未準備好(即通訊尚未完成)。在這種情況下,使用者需要呼叫wait來等待本地張量準備好。- 返回型別
注意
to_local是可微分的,返回的本地張量的requires_grad將取決於 DTensor 是否需要梯度。
- property device_mesh: DeviceMesh¶
與此 DTensor 物件關聯的
DeviceMesh屬性。注意
device_mesh是隻讀屬性,無法設定。
- property placements: tuple[torch.distributed.tensor.placement_types.Placement, ...]¶
此 DTensor 的 placements 屬性,描述了此 DTensor 在其 DeviceMesh 上的佈局。
注意
placements是隻讀屬性,無法設定。
作為分散式通訊器的 DeviceMesh¶
DeviceMesh 是從 DTensor 構建的抽象,用於描述叢集的裝置拓撲並表示多維通訊器(基於 ProcessGroup)。有關如何建立/使用 DeviceMesh 的詳細資訊,請參閱 DeviceMesh recipe。
DTensor Placement 型別¶
DTensor 支援在每個 DeviceMesh 維度上使用以下型別的 Placement:
- class torch.distributed.tensor.placement_types.Shard(dim)[source][source]¶
Shard(dim)Placement 描述了 DTensor 在對應DeviceMesh維度上按張量維度dim進行的分片,其中 DeviceMesh 維度上的每個 rank 僅持有全域性張量的一個分片/部分。Shard(dim)Placement 遵循torch.chunk(dim)的語義,當張量維度不能在 DeviceMesh 維度上均勻整除時,DeviceMesh 維度上的最後幾個分片可能是空的。ShardPlacement 可以用於所有 DTensor API(例如distribute_tensor、from_local等)。- 引數
dim (int) – 描述 DTensor 在其對應 DeviceMesh 維度上分片的張量維度。
警告
在張量維度大小不能被 DeviceMesh 維度均勻整除的情況下在該張量維度上進行分片,目前是實驗性的且可能會發生變化。
- class torch.distributed.tensor.placement_types.Replicate[source][source]¶
Replicate()佈局描述了 DTensor 在對應的DeviceMesh維度上進行復制,在該 DeviceMesh 維度的每個程序上都持有一個全域性張量的副本。所有 DTensor API(例如distribute_tensor、DTensor.from_local等)都可以使用Replicate佈局。
- class torch.distributed.tensor.placement_types.Partial(reduce_op='sum')[source][source]¶
Partial(reduce_op)佈局描述了 DTensor 在指定的DeviceMesh維度上等待歸約,在該 DeviceMesh 維度的每個程序上都持有全域性張量的部分值。使用者可以使用redistribute將PartialDTensor 重新分佈到指定的DeviceMesh維度上的Replicate或Shard(dim)佈局,這將觸發底層的必要通訊操作(例如allreduce、reduce_scatter)。- 引數
reduce_op (str, optional) – 用於將 partial DTensor 歸約生成 Replicated/Sharded DTensor 的歸約操作。僅支援逐元素歸約操作,包括:“sum”、“avg”、“product”、“max”、“min”,預設值:“sum”。
注意
Partial佈局可以作為 DTensor 運算元的結果生成,並且只能由DTensor.from_localAPI 使用。
建立 DTensor 的不同方法¶
- 構造
DTensor有三種方法 distribute_tensor()從每個程序上的邏輯或“全域性”torch.Tensor建立一個DTensor。這可用於對葉子torch.Tensor(例如模型引數/buffers 和輸入)進行分片。DTensor.from_local()從每個程序上的本地torch.Tensor建立一個DTensor,這可用於從非葉子torch.Tensor(例如前向/後向過程中的中間啟用張量)建立DTensor。DTensor 提供了專用的張量工廠函式(例如
empty()、ones()、randn()等),允許透過直接指定DeviceMesh和Placement來建立不同的DTensor。與distribute_tensor()相比,這可以直接在裝置上具體化分片記憶體,而無需在初始化邏輯張量記憶體後再執行分片。
從邏輯 torch.Tensor 建立 DTensor¶
torch.distributed 中的 SPMD(單程式多資料)程式設計模型啟動多個程序(例如透過 torchrun)執行相同的程式,這意味著程式內的模型將首先在不同程序上初始化(例如模型可能初始化在 CPU 上、meta 裝置上,或者如果記憶體足夠直接初始化在 GPU 上)。
DTensor 提供了一個 distribute_tensor() API,可以將模型權重或張量分片到 DTensor 中,它會從每個程序上的“邏輯”張量建立一個 DTensor。這將使建立的 DTensor 遵守單裝置語義,這對於數值正確性至關重要。
- torch.distributed.tensor.distribute_tensor(tensor, device_mesh=None, placements=None, *, src_data_rank=0)[source]¶
根據指定的
placements將葉子torch.Tensor(例如 nn.Parameter/buffers)分佈到device_mesh。device_mesh和placements的秩(rank)必須相同。要分佈的tensor是邏輯或“全域性”張量,API 將使用 DeviceMesh 維度第一個程序的tensor作為真實來源,以保留單裝置語義。如果你想在 Autograd 計算過程中構造一個 DTensor,請改用DTensor.from_local()。- 引數
tensor (torch.Tensor) – 要分佈的 torch.Tensor。請注意,如果你想在一個維度上進行分片,而該維度的大小不能被該 mesh 維度上的裝置數量均勻整除,我們將使用
torch.chunk語義來分片張量並分散分片。不均勻分片的行為是實驗性的,可能會發生變化。device_mesh (
DeviceMesh, optional) – 用於分佈張量的 DeviceMesh,如果未指定,必須在 DeviceMesh 上下文管理器下呼叫,預設值:Noneplacements (List[
Placement], optional) – 描述如何在 DeviceMesh 上放置張量的佈局列表,其元素數量必須與device_mesh.ndim相同。如果未指定,預設情況下,我們將從 device_mesh 的每個維度的第一個程序處複製張量到整個 device_mesh。
- 關鍵字引數
src_data_rank (int, optional) – 邏輯/全域性張量的源資料程序。
distribute_tensor()使用它來向其他程序分散/廣播分片/副本。預設情況下,我們使用每個 DeviceMesh 維度上的group_rank=0作為源資料,以保留單裝置語義。如果顯式傳遞None,則distribute_tensor()只使用其本地資料,而不是嘗試透過分散/廣播保留單裝置語義。預設值:0- 返回值
一個
DTensor或XLAShardedTensor物件。- 返回型別
注意
當使用
xla裝置型別初始化 DeviceMesh 時,distribute_tensor將返回XLAShardedTensor。有關更多詳細資訊,請參見 此 issue。XLA 整合是實驗性的,可能會發生變化。
除了 distribute_tensor() 之外,DTensor 還提供了 distribute_module() API,以便在 nn.Module 級別上更容易地進行分片
- torch.distributed.tensor.distribute_module(module, device_mesh=None, partition_fn=None, input_fn=None, output_fn=None)[source]¶
該函式暴露了三個函式來控制 module 的 parameters/inputs/outputs
1. 透過指定
partition_fn在執行時執行之前對 module 執行分片(例如允許使用者根據指定的partition_fn將 Module 引數轉換為DTensor引數)。 2. 透過指定input_fn和output_fn在執行時執行期間控制 module 的輸入或輸出(例如將輸入轉換為DTensor,將輸出轉換回torch.Tensor)。- 引數
module (
nn.Module) – 使用者要進行分割槽的 module。device_mesh (
DeviceMesh) – 用於放置 module 的裝置 mesh。partition_fn (Callable) – 分割槽引數的函式(例如將某些引數分片到
device_mesh上)。如果未指定partition_fn,預設情況下我們將 module 的所有 module 引數複製到整個 mesh 上。input_fn (Callable) – 指定輸入分佈,例如可以控制 module 的輸入如何分片。
input_fn將作為 module 的forward_pre_hook(前向預處理鉤子)安裝。output_fn (Callable) – 指定輸出分佈,例如可以控制輸出如何分片,或將其轉換回 torch.Tensor。
output_fn將作為 module 的forward_hook(前向處理鉤子)安裝。
- 返回值
一個包含所有引數/buffers 都是
DTensor的 module。- 返回型別
注意
當使用
xla裝置型別初始化 DeviceMesh 時,distribute_module將返回帶有 PyTorch/XLA SPMD 註解引數的 nn.Module。有關更多詳細資訊,請參見 此 issue。XLA 整合是實驗性的,可能會發生變化。
DTensor 工廠函式¶
DTensor 還提供了專用的張量工廠函式,允許透過直接使用類似 torch.Tensor 工廠函式 API(例如 torch.ones、torch.empty 等)建立 DTensor,方法是額外指定建立的 DTensor 的 DeviceMesh 和 Placement。與 distribute_tensor() 相比,這可以直接在裝置上具體化分片記憶體,而無需在初始化邏輯張量記憶體後再執行分片。
- torch.distributed.tensor.zeros(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)[source]¶
返回一個填充標量值 0 的
DTensor。- 引數
size (int...) – 定義輸出
DTensor形狀的整數序列。可以是可變數量的引數或列表、元組等集合。例如:zeros(1,2,3..) 或 zeros([1,2,3..]) 或 zeros((1,2,3..))- 關鍵字引數
requires_grad (bool, optional) – 如果 autograd 應該記錄返回的
DTensor上的操作。預設值:False。dtype (
torch.dtype, optional) – 返回的DTensor所需的資料型別。預設值:如果為None,則使用全域性預設值(參見torch.set_default_dtype())。layout (
torch.layout, optional) – 返回的DTensor所需的佈局。預設值:torch.strided。device_mesh –
DeviceMesh型別,包含程序的 mesh 資訊placements –
Placement型別的序列:Shard、Replicate
- 返回值
每個程序上的
DTensor物件- 返回型別
- torch.distributed.tensor.ones(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)[source]¶
返回一個填充標量值 1 的
DTensor,其形狀由可變引數size定義。- 引數
size (int...) – 定義輸出
DTensor形狀的整數序列。可以是可變數量的引數或列表、元組等集合。例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))- 關鍵字引數
dtype (
torch.dtype, optional) – 返回的DTensor所需的資料型別。預設值:如果為None,則使用全域性預設值(參見torch.set_default_dtype())。layout (
torch.layout, optional) – 返回的 DTensor 所需的佈局。預設值:torch.strided。requires_grad (bool, optional) – 如果 autograd 應該記錄返回的
DTensor上的操作。預設值:False。device_mesh –
DeviceMesh型別,包含程序的 mesh 資訊placements –
Placement型別的序列:Shard、Replicate
- 返回值
每個程序上的
DTensor物件- 返回型別
- torch.distributed.tensor.empty(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)[source]¶
返回一個填充未初始化資料的
DTensor。DTensor的形狀由可變引數size定義。- 引數
size (int...) – 定義輸出
DTensor形狀的整數序列。可以是可變數量的引數或列表、元組等集合。例如:empty(1,2,3..) 或 empty([1,2,3..]) 或 empty((1,2,3..))- 關鍵字引數
dtype (
torch.dtype, optional) – 返回的DTensor所需的資料型別。預設值:如果為None,則使用全域性預設值(參見torch.set_default_dtype())。layout (torch.layout, optional):返回的DTensor所需的佈局。預設值:torch.strided。requires_grad (bool, optional) – 如果 autograd 應該記錄返回的
DTensor上的操作。預設值:False。device_mesh –
DeviceMesh型別,包含程序的 mesh 資訊placements –
Placement型別的序列:Shard、Replicate
- 返回值
每個程序上的
DTensor物件- 返回型別
- torch.distributed.tensor.full(size, fill_value, *, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)[source]¶
根據
device_mesh和placements返回一個填充fill_value的DTensor,其形狀由引數size定義。- 引數
- 關鍵字引數
dtype (
torch.dtype, optional) – 返回的DTensor所需的資料型別。預設值:如果為None,則使用全域性預設值(參見torch.set_default_dtype())。layout (
torch.layout, optional) – 返回的 DTensor 所需的佈局。預設值:torch.strided。requires_grad (bool, optional) – 如果 autograd 應該記錄返回的
DTensor上的操作。預設值:False。device_mesh –
DeviceMesh型別,包含程序的 mesh 資訊。placements –
Placement型別的序列:Shard、Replicate
- 返回值
每個程序上的
DTensor物件- 返回型別
- torch.distributed.tensor.rand(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)[source]¶
返回一個填充在區間
[0, 1)內均勻分佈的隨機數的DTensor。張量的形狀由可變引數size定義。- 引數
size (int...) – 定義輸出
DTensor形狀的整數序列。可以是可變數量的引數或列表、元組等集合。例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))- 關鍵字引數
dtype (
torch.dtype, optional) – 返回的DTensor所需的資料型別。預設值:如果為None,則使用全域性預設值(參見torch.set_default_dtype())。layout (
torch.layout, optional) – 返回的 DTensor 所需的佈局。預設值:torch.strided。requires_grad (bool, optional) – 如果 autograd 應該記錄返回的
DTensor上的操作。預設值:False。device_mesh –
DeviceMesh型別,包含程序的 mesh 資訊。placements –
Placement型別的序列:Shard、Replicate
- 返回值
每個程序上的
DTensor物件- 返回型別
- torch.distributed.tensor.randn(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)[source]¶
返回一個
DTensor,其中填充了來自均值為 0、方差為 1 的正態分佈的隨機數。張量的形狀由可變引數size定義。- 引數
size (int...) – 定義輸出
DTensor形狀的整數序列。可以是可變數量的引數或列表、元組等集合。例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))- 關鍵字引數
dtype (
torch.dtype, optional) – 返回的DTensor所需的資料型別。預設值:如果為None,則使用全域性預設值(參見torch.set_default_dtype())。layout (
torch.layout, optional) – 返回的 DTensor 所需的佈局。預設值:torch.strided。requires_grad (bool, optional) – 如果 autograd 應該記錄返回的
DTensor上的操作。預設值:False。device_mesh –
DeviceMesh型別,包含程序的 mesh 資訊。placements –
Placement型別的序列:Shard、Replicate
- 返回值
每個程序上的
DTensor物件- 返回型別
除錯¶
日誌記錄¶
啟動程式時,可以使用來自 torch._logging 的 TORCH_LOGS 環境變數開啟附加日誌記錄。
TORCH_LOGS=+dtensor 將顯示 logging.DEBUG 級別的訊息及其以上的所有級別。
TORCH_LOGS=dtensor 將顯示 logging.INFO 級別的訊息及其以上級別。
TORCH_LOGS=-dtensor 將顯示 logging.WARNING 級別的訊息及其以上級別。
除錯工具¶
為了除錯應用了 DTensor 的程式,並瞭解底層發生了哪些集合通訊操作的更多細節,DTensor 提供了一個 CommDebugMode。
- class torch.distributed.tensor.debug.CommDebugMode¶
CommDebugMode是一個上下文管理器,用於計算其上下文中功能性集合通訊操作的數量。它透過使用TorchDispatchMode來實現此功能。注意
並非所有集合通訊操作都已支援。
示例用法
mod = ... comm_mode = CommDebugMode() with comm_mode: mod.sum().backward() print(comm_mode.get_comm_counts())
- generate_comm_debug_tracing_table(noise_level=3)[source][source]¶
生成詳細表格,顯示模組級別的操作和集合通訊跟蹤資訊。資訊量取決於 noise_level。
列印模組級別的集合通訊計數
列印未包含在平凡操作中的 DTensor 操作,模組資訊
列印未包含在平凡操作中的操作
列印所有操作
為了視覺化維度少於 3 的 DTensor 的分片 (sharding),DTensor 提供了 visualize_sharding()。
實驗性功能¶
DTensor 還提供了一系列實驗性功能。這些功能要麼處於原型階段,要麼基本功能已完成但正在尋求使用者反饋。如果您對這些功能有反饋,請向 PyTorch 提交一個 issue。
- torch.distributed.tensor.experimental.context_parallel(mesh, *, buffers=None, buffer_seq_dims=None, no_restore_buffers=None)[source]¶
context_parallel是一個實驗性 API,用於啟用上下文並行 (CP)。此 API 執行兩個操作:1) 使用啟用 CP 的版本修補 SDPA (torch.nn.functional.scaled_dot_product_attention),2) 沿序列維度對buffers進行分片,並且每個 rank 將根據mesh保留相應的分片。- 引數
mesh (
DeviceMesh) – 用於上下文並行的裝置網格。buffers (Optional[List[torch.Tensor]]) – 使用依賴於序列維度的緩衝區。例如輸入批次、標籤和位置嵌入緩衝區。這些緩衝區必須沿序列維度進行分片以確保準確性。分片將原地進行,緩衝區的形狀將在上下文中改變。緩衝區將在上下文結束後恢復。
no_restore_buffers可用於指定哪些緩衝區不需要恢復。請注意,buffers不應包含任何 nn.Parameter。buffer_seq_dims (Optional[List[int]]) –
buffers的序列維度。no_restore_buffers (Optional[Set[torch.Tensor]]) – 此集合中的緩衝區在上下文退出後不會被恢復。此集合必須是
buffers的子集。如果在上下文退出後不再使用這些緩衝區,可以將它們放在此列表中以避免額外的恢復時間。
- 返回型別
Generator[None, None, None]
警告
torch.distributed._tensor.experimental.attention.context_parallel 是 PyTorch 中的一個原型功能。該 API 可能隨時更改。
- torch.distributed.tensor.experimental.local_map(func, out_placements, in_placements=None, device_mesh=None, *, redistribute_inputs=False)[source]¶
local_map()是一個實驗性 API,允許使用者將DTensor傳遞給一個旨在應用於torch.Tensor的函式。其實現方式是提取DTensor的本地分量,呼叫函式,然後根據out_placements將輸出包裝成DTensor。- 引數
func (Callable) – 應用於
DTensor的每個本地分片上的函式。out_placements (Union[PlacementType, Tuple[PlacementType, …]]) –
func的展平輸出中DTensor的期望佈局。如果展平output是單個值,則out_placements應為 PlacementType 型別。否則,如果展平output有多個值,out_placements應為 PlacementType 值的元組,與展平output一一對應。此外,對於Tensor輸出,我們使用 PlacementType 作為其佈局(一個 Tuple[Placement] 值)。對於非 Tensor 輸出,PlacementType 應為 None。請注意,唯一的例外是沒有傳入DTensor引數時。在這種情況下,即使 out_placements 不是 None,結果函式也應忽略期望的佈局,因為該函式不是在DTensor環境下執行的。in_placements (Tuple[PlacementType, …], optional) –
func的展平輸入中DTensor的所需佈局。如果指定了in_placements,local_map()將檢查每個DTensor引數的佈局是否與所需佈局相同。如果佈局不同且redistribute_inputs為False,將引發異常。否則,如果redistribute_inputs為True,引數將首先被重新分發到所需的 sharding 佈局,然後才將其本地張量傳遞給func。唯一的例外是所需佈局不為None且引數是torch.Tensor。在這種情況下,將跳過佈局檢查,引數將直接傳遞給func。如果in_placements為None,將不執行佈局檢查。預設值:Nonedevice_mesh (
DeviceMesh, optional) – 所有DTensor所在的裝置網格。如果未指定,將從輸入DTensor的裝置網格中推斷。local_map 要求所有DTensor位於同一個裝置網格上。預設值:None。redistribute_inputs (bool, optional) – 布林值,指示當輸入
DTensor的佈局與所需輸入佈局不同時是否重新分片。如果此值為False並且某個DTensor輸入具有不同的佈局,將引發異常。預設值:False。
- 返回值
一個
Callable,它將func應用於輸入DTensor的每個本地分片,並返回由func返回值構建的DTensor。- 引發
AssertionError – 如果輸入
DTensor不在同一個裝置網格上,或者如果它們所在的裝置網格與傳入的device_mesh引數不同。AssertionError – 對於任何非 DTensor 輸出,我們要求其在
out_placements中對應的輸出佈局為 None。如果不是這種情況,將引發 AssertionError。ValueError – 如果
redistribute_inputs=False但輸入DTensor需要根據in_placements進行重新分發。
示例
>>> def mm_allreduce_forward(device_mesh, W, X): >>> partial_sum_tensor = torch.mm(W, X) >>> reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh) >>> return reduced_tensor >>> >>> W = torch.randn(12, 8, requires_grad=False) >>> X = torch.randn(8, 16, requires_grad=False) >>> Y = torch.mm(W, X) >>> row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh >>> col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh >>> >>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor convertion >>> local_mm_allreduce_forward = local_map( >>> mm_allreduce_forward, >>> out_placements=[Replicate()], >>> in_placements=[col_wise, row_wise], >>> device_mesh=device_mesh, >>> ) >>> >>> W_dt = distribute_tensor( ... W, device_mesh, (col_wise) ... ) # col-wisely sharded W tensor >>> X_dt = distribute_tensor( ... X, device_mesh, (row_wise) ... ) # row-wisely sharded X tensor >>> Y_dt = local_mm_allreduce_forward( ... device_mesh, W_dt, X_dt ... ) # apply local_mm_allreduce_forward to DTensors
注意
此 API 目前是實驗性的,可能會隨時更改
- torch.distributed.tensor.experimental.register_sharding(op)[source]¶
register_sharding()是一個實驗性 API,允許使用者在張量輸入和輸出是 DTensor 時為運算子註冊分片策略。在以下情況下可能很有用:(1) 對於op不存在預設分片策略,例如當op是DTensor不支援的自定義運算子時;(2) 當用戶想要覆蓋現有運算子的預設分片策略時。- 引數
op (Union[OpOverload, List[OpOverload]]) – 要註冊自定義分片函式的運算子或運算子列表。
- 返回值
一個函式裝飾器,可用於包裝一個函式,該函式定義了
op中指定運算子的分片策略。定義的分片策略將註冊到 DTensor,如果 DTensor 已實現了該運算子,則會覆蓋預設分片策略。定製的分片函式接受與原始op相同的輸入(除了如果引數是torch.Tensor,它將被替換為 DTensor 內部使用的張量狀物件)。該函式應返回一個 2 元組序列,每個元組指定可接受的輸出佈局及其對應的輸入佈局。
示例
>>> @register_sharding(aten._softmax.default) >>> def custom_softmax_sharding(x, dim, half_to_float): >>> softmax_dim = dim if dim >= 0 else dim + x.ndim >>> acceptable_shardings = [] >>> >>> all_replicate = ([Replicate()], [Replicate(), None, None]) >>> acceptable_shardings.append(all_replicate) >>> >>> for sharding_dim in range(x.ndim): >>> if sharding_dim != softmax_dim: >>> all_sharded = ( >>> [Shard(sharding_dim)], >>> [Shard(sharding_dim), None, None], >>> ) >>> acceptable_shardings.append(all_sharded) >>> >>> return acceptable_shardings
注意
此 API 目前是實驗性的,可能會隨時更改