快捷方式

DeviceMesh 入門

創建於:Jan 24, 2024 | 最後更新於:Feb 24, 2025 | 最後驗證於:Nov 05, 2024

作者: Iris Zhang, Wanchao Liang

注意

editgithub 中檢視和編輯本教程。

前提條件

為分散式訓練設定分散式通訊器,即 NVIDIA Collective Communication Library (NCCL) 通訊器,可能是一個重大的挑戰。對於使用者需要組合不同並行模式的工作負載,使用者需要為每種並行解決方案手動設定和管理 NCCL 通訊器(例如,ProcessGroup)。這個過程可能既複雜又容易出錯。DeviceMesh 可以簡化這個過程,使其更易於管理且不易出錯。

什麼是 DeviceMesh

DeviceMesh 是一個更高級別的抽象,用於管理 ProcessGroup。它允許使用者輕鬆建立節點間和節點內程序組,而無需擔心如何為不同的子程序組正確設定 rank。使用者還可以透過 DeviceMesh 輕鬆管理底層程序組/裝置,用於多維並行。

PyTorch DeviceMesh

DeviceMesh 為何有用

DeviceMesh 在處理需要並行組合的多維並行(即 3D 並行)工作負載時非常有用。例如,當你的並行解決方案既需要跨主機通訊,也需要每臺主機內部通訊時。上圖顯示,我們可以建立一個 2D mesh,它連線每臺主機內部的裝置,並在同構設定中將每臺裝置與其在其他主機上的對應裝置連線起來。

如果沒有 DeviceMesh,使用者需要在應用任何並行之前手動設定 NCCL 通訊器和每個程序上的 cuda 裝置,這可能非常複雜。以下程式碼片段演示瞭如何在沒有 DeviceMesh 的情況下設定一個混合分片 2D 並行模式。首先,我們需要手動計算分片組 (shard group) 和複製組 (replicate group)。然後,我們需要將正確的分片組和複製組分配給每個 rank。

import os

import torch
import torch.distributed as dist

# Understand world topology
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
print(f"Running example on {rank=} in a world with {world_size=}")

# Create process groups to manage 2-D like parallel pattern
dist.init_process_group("nccl")
torch.cuda.set_device(rank)

# Create shard groups (e.g. (0, 1, 2, 3), (4, 5, 6, 7))
# and assign the correct shard group to each rank
num_node_devices = torch.cuda.device_count()
shard_rank_lists = list(range(0, num_node_devices // 2)), list(range(num_node_devices // 2, num_node_devices))
shard_groups = (
    dist.new_group(shard_rank_lists[0]),
    dist.new_group(shard_rank_lists[1]),
)
current_shard_group = (
    shard_groups[0] if rank in shard_rank_lists[0] else shard_groups[1]
)

# Create replicate groups (for example, (0, 4), (1, 5), (2, 6), (3, 7))
# and assign the correct replicate group to each rank
current_replicate_group = None
shard_factor = len(shard_rank_lists[0])
for i in range(num_node_devices // 2):
    replicate_group_ranks = list(range(i, num_node_devices, shard_factor))
    replicate_group = dist.new_group(replicate_group_ranks)
    if rank in replicate_group_ranks:
        current_replicate_group = replicate_group

要執行上述程式碼片段,我們可以利用 PyTorch Elastic。讓我們建立一個名為 2d_setup.py 的檔案。然後,執行以下 torch elastic/torchrun 命令。

torchrun --nproc_per_node=8 --rdzv_id=100 --rdzv_endpoint=localhost:29400 2d_setup.py

注意

為簡化演示,我們僅使用一個節點來模擬 2D 並行。注意,此程式碼片段也可用於多主機設定。

藉助 init_device_mesh(),我們只需兩行程式碼即可完成上述 2D 設定,並且如果需要,仍然可以訪問底層的 ProcessGroup

from torch.distributed.device_mesh import init_device_mesh
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))

# Users can access the underlying process group thru `get_group` API.
replicate_group = mesh_2d.get_group(mesh_dim="replicate")
shard_group = mesh_2d.get_group(mesh_dim="shard")

讓我們建立一個名為 2d_setup_with_device_mesh.py 的檔案。然後,執行以下 torch elastic/torchrun 命令。

torchrun --nproc_per_node=8 2d_setup_with_device_mesh.py

如何將 DeviceMesh 與 HSDP 結合使用

Hybrid Sharding Data Parallel (HSDP) 是一種 2D 策略,用於在主機內執行 FSDP,在主機間執行 DDP。

讓我們看一個例子,瞭解 DeviceMesh 如何透過簡單設定幫助將 HSDP 應用於你的模型。使用 DeviceMesh,使用者無需手動建立和管理分片組和複製組。

import torch
import torch.nn as nn

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


# HSDP: MeshShape(2, 4)
mesh_2d = init_device_mesh("cuda", (2, 4))
model = FSDP(
    ToyModel(), device_mesh=mesh_2d, sharding_strategy=ShardingStrategy.HYBRID_SHARD
)

讓我們建立一個名為 hsdp.py 的檔案。然後,執行以下 torch elastic/torchrun 命令。

torchrun --nproc_per_node=8 hsdp.py

如何將 DeviceMesh 用於自定義並行解決方案

在進行大規模訓練時,你可能會有更復雜的自定義並行訓練組合。例如,你可能需要為不同的並行解決方案切分出子 mesh。DeviceMesh 允許使用者從父 mesh 中切分出子 mesh,並重用父 mesh 初始化時已建立的 NCCL 通訊器。

from torch.distributed.device_mesh import init_device_mesh
mesh_3d = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("replicate", "shard", "tp"))

# Users can slice child meshes from the parent mesh.
hsdp_mesh = mesh_3d["replicate", "shard"]
tp_mesh = mesh_3d["tp"]

# Users can access the underlying process group thru `get_group` API.
replicate_group = hsdp_mesh["replicate"].get_group()
shard_group = hsdp_mesh["shard"].get_group()
tp_group = tp_mesh.get_group()

結論

總之,我們學習了 DeviceMeshinit_device_mesh(),以及它們如何用於描述叢集中裝置的佈局。

欲瞭解更多資訊,請參閱以下內容

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源