張量平行 - torch.distributed.tensor.parallel¶
張量平行 (TP) 建構在 PyTorch DistributedTensor (DTensor) 之上,並提供不同的平行樣式:列式、行式和序列平行。
警告
張量平行 API 仍處於實驗階段,可能會有所變更。
使用張量平行來平行化 nn.Module 的進入點是
- torch.distributed.tensor.parallel.parallelize_module(module, device_mesh, parallelize_plan)[原始碼]¶
根據使用者指定的計畫,透過平行化模組或子模組,在 PyTorch 中應用張量平行。
我們根據 parallelize_plan 平行化模組或子模組。parallelize_plan 包含
ParallelStyle,它表示使用者希望如何平行化模組或子模組。使用者也可以針對每個模組完整合格名稱 (FQN) 指定不同的平行樣式。
請注意,
parallelize_module只接受一維DeviceMesh,如果您有二維或 N 維DeviceMesh,請先將 DeviceMesh 切片成一維子 DeviceMesh,然後再傳遞給此 API(例如device_mesh["tp"])。- 參數
module (
nn.Module) – 要平行化的模組。device_mesh (
DeviceMesh) – 描述 DTensor 裝置網格拓撲的物件。parallelize_plan (Union[
ParallelStyle, Dict[str,ParallelStyle]]) – 用於平行化模組的計畫。它可以是包含我們如何為張量平行準備輸入/輸出的ParallelStyle物件,也可以是模組 FQN 及其對應ParallelStyle物件的字典。
- 傳回值
平行化的
nn.Module物件。- 傳回類型
- 範例:
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> >>> # Define the module. >>> m = Model(...) >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()}) >>>
備註
對於像注意力、MLP 層這樣的複雜模組架構,我們建議將不同的 ParallelStyles 組合在一起(例如
ColwiseParallel和RowwiseParallel),並作為 parallelize_plan 傳遞,以實現所需的切片計算。
張量平行支援以下平行樣式
- class torch.distributed.tensor.parallel.ColwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True)[原始碼]¶
以列式方式分割相容的 nn.Module。目前支援 nn.Linear 和 nn.Embedding。使用者可以將其與 RowwiseParallel 組合在一起,以實現更複雜模組的切片。(例如 MLP、注意力)
- 關鍵字引數
input_layouts (Placement, 選用) – nn.Module 輸入張量的 DTensor 配置,用於註釋輸入張量以成為 DTensor。如果未指定,我們假設輸入張量會被複製。
output_layouts (Placement, 選用) – nn.Module 輸出的 DTensor 配置,用於確保 nn.Module 的輸出具有使用者想要的配置。如果未指定,則輸出張量會在最後一個維度上切片。
use_local_output (bool, 選用) – 是否針對模組輸出使用本機
torch.Tensor而不是DTensor,預設值:True。
- 傳回值
一個表示 nn.Module 的 Colwise 分片(sharding)的
ParallelStyle物件。
- 範例:
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> m = Model(...) # m is a nn.Module that contains a "w1" nn.Linear submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # By default, the input of the "w1" Linear will be converted to Replicated DTensor >>> # and the output of "w1" will return :class:`torch.Tensor` that shards on the last dim. >>> >>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()}) >>> ...
備註
預設情況下,如果未指定
output_layouts,ColwiseParallel輸出會在最後一個維度上進行分片,如果存在需要特定張量形狀的運算子(例如,在配對的RowwiseParallel之前),請注意,如果輸出已分片,則運算子可能需要調整為分片大小。
- class torch.distributed.tensor.parallel.RowwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True)[source]¶
以行式方式分割相容的 nn.Module。目前支援 nn.Linear 和 nn.Embedding。使用者可以將其與 ColwiseParallel 組合使用,以實現更複雜模組(例如 MLP、Attention)的分片。
- 關鍵字引數
input_layouts (Placement, 可選) – nn.Module 輸入張量的 DTensor 布局,用於標註輸入張量以成為 DTensor。如果未指定,我們假設輸入張量在最後一個維度上進行分片。
output_layouts (Placement, 可選) – nn.Module 輸出的 DTensor 布局,用於確保 nn.Module 的輸出具有使用者所需的布局。如果未指定,則複製輸出張量。
use_local_output (bool, 選用) – 是否針對模組輸出使用本機
torch.Tensor而不是DTensor,預設值:True。
- 傳回值
一個表示 nn.Module 的 Rowwise 分片的
ParallelStyle物件。
- 範例:
>>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> m = Model(...) # m is a nn.Module that contains a "w2" nn.Linear submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # By default, the input of the "w2" Linear will be converted to DTensor that shards on the last dim >>> # and the output of "w2" will return a replicated :class:`torch.Tensor`. >>> >>> sharded_mod = parallelize_module(m, tp_mesh, {"w2": RowwiseParallel()}), >>> ...
- class torch.distributed.tensor.parallel.SequenceParallel(*, sequence_dim=1, use_local_output=False)[source]¶
SequenceParallel 複製相容的
nn.Module參數,並使用在序列維度上分片的輸入運行分片計算。目前支援nn.LayerNorm、nn.Dropout和 RMSNorm Python 實作此風格實現了論文 Reducing Activation Recomputation in Large Transformer Models 中描述的操作。
nn.Module的輸入和輸出都將在序列維度上進行分片。- 關鍵字引數
sequence_dim (int, 可選) –
nn.Module輸入張量的序列維度,用於標註輸入張量以成為在序列維度上分片的 DTensor,預設值:1。use_local_output (bool, 可選) – 是否對模組輸出使用本地
torch.Tensor而不是DTensor,預設值:False。
- 傳回值
一個表示
nn.Module的序列平行(Sequence Parallel)的ParallelStyle物件。
- 範例:
>>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> m = Model(...) # m is a nn.Module that contains a "norm" nn.LayerNorm submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim >>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`. >>> >>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}), >>> ...
備註
如果 nn.Module 中存在權重(例如
nn.LayerNorm或RMSNorm,它們預設使用 ones 初始化),SequenceParallel 風格會假設使用 ones 初始化。如果您對這些模組上的權重有自定義初始化,則需要在平行化之前/之後廣播權重,以確保它們被複製。
要使用 DTensor 布局簡單配置 nn.Module 的輸入和輸出,並執行必要的布局重新分配,而無需將模組參數分發到 DTensor,可以在調用 parallelize_module 時,在 parallelize_plan 中使用以下 ParallelStyle:
- class torch.distributed.tensor.parallel.PrepareModuleInput(*, input_layouts=None, desired_input_layouts=None, input_kwarg_layouts=None, desired_input_kwarg_layouts=None, use_local_output=False)[source]¶
配置 nn.Module 的輸入,以在運行時根據
input_layouts將 nn.Module 的輸入張量轉換為 DTensor,並根據desired_input_layouts執行布局重新分配。- 關鍵字引數
input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – nn.Module 輸入張量的 DTensor 布局,用於將輸入張量轉換為 DTensor。如果某些輸入不是 torch.Tensor 或不需要轉換為 DTensor,則需要將
None指定為佔位符。預設值:None。desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – nn.Module 輸入張量的期望 DTensor 布局,用於確保 nn.Module 的輸入具有期望的 DTensor 布局。此參數的長度需要與
input_layouts相同。預設值:None。input_kwarg_layouts (Dict[str, Placement]) – nn.Module 輸入 kwargs 的 DTensor 布局,用於將輸入 kwarg 張量轉換為 DTensor。預設值:None。
desired_input_kwarg_layouts – (Dict[str, Placement]):nn.Module 輸入 kwargs 的期望 DTensor 布局,用於確保 nn.Module 的輸入具有期望的 DTensor 布局。預設值:None。
use_local_output (bool, 可選) – 是否對模組輸入使用本地
torch.Tensor而不是DTensor,預設值:False。
- 傳回值
一個準備 nn.Module 輸入的分片布局的
ParallelStyle物件。
- 範例:
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # According to the style specified below, the first input of attn will be annotated to Sharded DTensor >>> # and then redistributed to Replicated DTensor. >>> parallelize_module( >>> block, # this can be a submodule or module >>> tp_mesh, >>> parallelize_plan={ >>> "attn": PrepareModuleInput( >>> input_layouts=(Shard(0), None, None, ...), >>> desired_input_layouts=(Replicate(), None, None, ...) >>> ), >>> } >>> )
- class torch.distributed.tensor.parallel.PrepareModuleOutput(*, output_layouts, desired_output_layouts, use_local_output=True)[source]¶
配置 nn.Module 的輸出,以在運行時根據
output_layouts將 nn.Module 的輸出張量轉換為 DTensor,並根據desired_output_layouts執行布局重新分配。- 關鍵字引數
output_layouts (Union[Placement, Tuple[Placement]]) – nn.Module 輸出張量的 DTensor 布局,用於在輸出張量為
torch.Tensor時將其轉換為 DTensor。如果某些輸出不是 torch.Tensor 或不需要轉換為 DTensor,則需要將None指定為佔位符。desired_output_layouts (Union[Placement, Tuple[Placement]]) – nn.Module 輸出張量的期望 DTensor 布局,用於確保 nn.Module 的輸出具有期望的 DTensor 布局。
use_local_output (bool, 可選) – 是否對模組輸出使用本地
torch.Tensor而不是DTensor,預設值:True。
- 傳回值
一個準備 nn.Module 輸出的分片布局的 ParallelStyle 物件。
- 範例:
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # According to the style specified below, the output of the TransformerBlock will be converted to Replicated DTensor >>> # and then redistributed to Sharded DTensor. >>> parallelize_module( >>> block, # this can be a submodule or module >>> tp_mesh, >>> parallelize_plan = PrepareModuleOutput( >>> output_layouts=Replicate(), >>> desired_output_layouts=Shard(0) >>> ) >>> )
備註
當對上述 ParallelStyle 使用 Shard(dim) 作為輸入/輸出布局時,我們假設輸入/輸出激活張量在 TP 運行的 DeviceMesh 上的張量維度 dim 上均勻分片。例如,由於 RowwiseParallel 接受在最後一個維度上分片的輸入,因此它假設輸入張量已經在最後一個維度上均勻分片。對於不均勻分片的激活張量,可以將 DTensor 直接傳遞給分區的模組,並使用 use_local_output=False 在每個 ParallelStyle 之後返回 DTensor,其中 DTensor 可以跟踪不均勻的分片信息。
對於像 Transformer 這樣的模型,我們建議用戶在 parallelize_plan 中同時使用 ColwiseParallel 和 RowwiseParallel,以實現整個模型(例如 Attention 和 MLP)的所需分片。
通過以下上下文管理器支持平行化交叉熵損失計算(損失平行化):
- torch.distributed.tensor.parallel.loss_parallel()[原始碼]¶
一個啟用損失平行化的上下文管理器,當輸入在類別維度上進行分片時,可以使用此管理器來執行高效的平行化損失計算。目前僅支援交叉熵損失。
在此上下文管理器中,可以使用
cross_entropy()或CrossEntropyLoss,並對輸入參數做出以下假設。相應的backward()呼叫(如果有的話)也需要在此上下文管理器下進行。- 參數
輸入 (
DTensor) – 輸入邏輯值。假設在類別維度上進行分片。目標 (Union[
torch.Tensor,DTensor]) – 必須是真實類別索引(目前不支援類別機率)。假設在DeviceMesh上進行複製。權重 (Union[
torch.Tensor,DTensor], 可選) – 如果給定,則假設在DeviceMesh上進行複製。標籤平滑 – 目前不支援。
- 傳回值
一個複製的
DTensor。
範例
此處手動創建了一個分片的 DTensor 以展示其用法。在實務中,它通常是 TP 模組的輸出。
>>> from torch.distributed.tensor.parallel import loss_parallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> device_mesh = init_device_mesh("cuda", (8,)) >>> input = torch.randn(4, 16, device="cuda", requires_grad=True) >>> dist_input = distribute_tensor(input, device_mesh, placements=[Shard(1)]) >>> target = torch.randint(16, (4,), device="cuda") >>> with loss_parallel(): >>> loss = F.cross_entropy(dist_input, target, reduction="mean") >>> loss.backward() >>> ...
警告
loss_parallel API 仍處於實驗階段,可能會有所變更。