捷徑

DDP 通訊鉤子

DDP 通訊鉤子是一個通用介面,用於透過覆蓋 DistributedDataParallel 中的原始 allreduce 來控制如何在工作節點之間傳遞梯度。提供了一些內建的通訊鉤子,使用者可以輕鬆應用任何這些鉤子來最佳化通訊。此外,鉤子介面還可以支援使用者自訂的通訊策略,以滿足更進階的使用案例。

如何使用通訊鉤子?

若要使用通訊鉤子,使用者只需在訓練迴圈之前讓 DDP 模型註冊鉤子,如下所示。

torch.nn.parallel.DistributedDataParallel.register_comm_hook()

通訊鉤子對什麼進行操作?

通訊鉤子提供了一種靈活的方式來進行梯度 allreduce。因此,它主要對 allreduce 之前每個副本上的梯度進行操作,這些梯度會被分桶以增加通訊和計算之間的重疊。特別是,torch.distributed.GradBucket 表示要進行 allreduce 的梯度張量桶。

類別 torch.distributed.GradBucket

此類別主要將扁平化的梯度張量(由 buffer() 返回)傳遞給 DDP 通訊鉤子。此張量可以進一步分解為此桶中每個參數張量的清單(由 get_per_parameter_tensors() 返回),以應用分層操作。

torch.distributed.GradBucket.index(self: torch._C._distributed_c10d.GradBucket) int

警告

由於桶在第一次迭代後會重建,因此不應依賴訓練開始時的索引。

返回

儲存一些連續層的梯度的桶的索引。所有梯度都會被分桶。

torch.distributed.GradBucket.buffer(self: torch._C._distributed_c10d.GradBucket) torch.Tensor
返回

扁平化的一維 torch.Tensor 緩衝區,可以進一步分解為此桶中每個參數張量的清單。

torch.distributed.GradBucket.gradients(self: torch._C._distributed_c10d.GradBucket) list[torch.Tensor]
返回

torch.Tensor 的清單。清單中的每個張量都對應一個梯度。

torch.distributed.GradBucket.is_last(self: torch._C._distributed_c10d.GradBucket) bool
返回

此桶是否是迭代中最後一個進行 allreduce 的桶。這也意味著此桶對應於正向傳遞中的前幾層。

torch.distributed.GradBucket.set_buffer(self: torch._C._distributed_c10d.GradBucket, buffer: torch.Tensor) None

使用輸入張量緩衝區替換儲存區中的張量。

torch.distributed.GradBucket.parameters(self: torch._C._distributed_c10d.GradBucket) list[torch.Tensor]
返回

一個 torch.Tensor 的清單。清單中的每個張量都對應一個模型參數。

預設通訊鉤子

預設通訊鉤子是簡單的**無狀態**鉤子,因此 register_comm_hook 中的輸入狀態是一個處理群組或 None。輸入 bucket 是一個 torch.distributed.GradBucket 物件。

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.allreduce_hook(process_group, bucket)[原始碼]

使用 GradBucket 張量呼叫 allreduce

一旦梯度張量在所有工作節點上聚合後,它的 then 回呼會取平均值並返回結果。

如果使用者註冊此 DDP 通訊鉤子,則預計 DDP 結果與未註冊鉤子的情況相同。因此,這不會改變 DDP 的行為,使用者可以將其用作參考,或修改此鉤子以記錄有用信息或任何其他目的,而不會影響 DDP 行為。

範例:
>>> ddp_model.register_comm_hook(process_group, allreduce_hook)
返回類型

Future[Tensor]

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook(process_group, bucket)[原始碼]

通過將 GradBucket 轉換為除以處理群組大小的 torch.float16 來壓縮。

此 DDP 通訊鉤子實現了一種簡單的梯度壓縮方法,將 GradBucket 張量轉換為半精度浮點格式(torch.float16),然後將其除以處理群組大小。它會對這些 float16 梯度張量進行 allreduce。壓縮後的梯度張量經過 allreduce 後,鏈式回呼 decompress 會將其轉換回輸入數據類型(例如 float32)。

範例:
>>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
返回類型

Future[Tensor]

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_hook(process_group, bucket)[原始碼]

警告:此 API 為實驗性 API,需要 2.9.6 以上版本的 NCCL。

此 DDP 通訊鉤子實現了一種簡單的梯度壓縮方法,將 GradBucket 張量轉換為半精度 Brain 浮點格式torch.bfloat16),然後將其除以處理群組大小。它會對這些 bfloat16 梯度張量進行 allreduce。壓縮後的梯度張量經過 allreduce 後,鏈式回呼 decompress 會將其轉換回輸入數據類型(例如 float32)。

範例:
>>> ddp_model.register_comm_hook(process_group, bf16_compress_hook)
返回類型

Future[Tensor]

此外,還提供了一個通訊鉤子包裝器來支援 fp16_compress_hook()bf16_compress_hook() 作為包裝器,可以與其他通訊鉤子組合使用。

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_wrapper(hook)[原始碼]

將輸入張量轉換為 torch.float16,將鉤子的結果轉換回輸入 dtype。

此包裝器將給定 DDP 通訊鉤子的輸入梯度張量轉換為半精度浮點格式(torch.float16),並將給定鉤子的結果張量轉換回輸入數據類型,例如 float32。因此,fp16_compress_hook 等效於 fp16_compress_wrapper(allreduce_hook)

範例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
>>> ddp_model.register_comm_hook(state, fp16_compress_wrapper(powerSGD_hook))
返回類型

Callable[[Any, GradBucket], Future[Tensor]]

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_wrapper(hook)[原始碼]

警告:此 API 為實驗性 API,需要 2.9.6 以上版本的 NCCL。

此包裝器將給定 DDP 通訊鉤子的輸入梯度張量轉換為半精度 Brain 浮點格式 <https://zh-tw.wikipedia.org/wiki/Bfloat16%E6%B5%AE%E7%B4%B0%E5%B0%8F%E6%95%B8%E7%B4%84%E4%BB%8B> `_ (``torch.bfloat16`),並將給定鉤子的結果張量轉換回輸入數據類型,例如 float32

因此,bf16_compress_hook 等效於 bf16_compress_wrapper(allreduce_hook)

範例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
>>> ddp_model.register_comm_hook(state, bf16_compress_wrapper(powerSGD_hook))
返回類型

Callable[[Any, GradBucket], Future[Tensor]]

PowerSGD 通訊鉤子

PowerSGD(Vogels 等人,NeurIPS 2019)是一種梯度壓縮演算法,可以提供非常高的壓縮率並加速受頻寬限制的分散式訓練。此演算法需要同時維護一些超參數和內部狀態。因此,PowerSGD 通訊鉤子是一個**有狀態**鉤子,使用者需要提供如下定義的狀態物件。

PowerSGD 狀態

class torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.PowerSGDState(process_group, matrix_approximation_rank=1, start_powerSGD_iter=1000, min_compression_rate=2, use_error_feedback=True, warm_start=True, orthogonalization_epsilon=0, random_seed=0, compression_stats_logging_frequency=10000, batch_tensors_with_same_shape=False)[原始碼]

在訓練期間儲存演算法的所有梯度的超參數和內部狀態。

特別是,matrix_approximation_rankstart_powerSGD_iter 是使用者應該調整的主要超參數。為了效能,我們建議保持二元超參數 use_error_feedbackwarm_start 為開啟狀態。

  1. matrix_approximation_rank 控制壓縮後的低秩張量的尺寸,這決定了壓縮率。秩越低,壓縮越強。

    1.1. 如果 matrix_approximation_rank 太低,完整模型的品質將需要更多訓練步驟才能達到或永遠無法達到,並導致準確性下降。

    1.2. 增加 matrix_approximation_rank 會大幅增加壓縮的計算成本,並且在超過某個 matrix_approximation_rank 門檻後,準確性可能不會進一步提高。

為了調整 matrix_approximation_rank,我們建議從 1 開始,並以 2 的倍數增加(如指數網格搜尋,1、2、4…),直到達到令人滿意的準確性。通常只使用 1-4 的小值。對於某些 NLP 任務(如原始論文的附錄 D 所示),此值已增加到 32。

  1. start_powerSGD_iter 會將 PowerSGD 壓縮延遲到步驟 start_powerSGD_iter,並在步驟 start_powerSGD_iter 之前執行 vanilla allreduce。這種 **vanilla allreduce + PowerSGD** 的混合方案可以有效提高準確性,即使使用相對較小的 matrix_approximation_rank。這是因為訓練階段的開始通常對不準確的梯度非常敏感,過早壓縮梯度可能會使訓練很快走上次優的軌跡,這可能會對準確性造成不可恢復的影響。

為了調整 start_powerSGD_iter,我們建議從總訓練步驟的 10% 開始,並增加它直到達到令人滿意的準確性。如果訓練中有一個預熱階段,start_powerSGD_iter 通常應該不小於預熱步驟的數量。

  1. min_compression_rate 是壓縮圖層時所需的最小壓縮率。由於壓縮產生的計算開銷,只有在頻寬上有足夠的節省時,張量才值得壓縮,其中 (num_rows + num_cols) * matrix_approximation_rank * min_compression_rate < num_rows * num_cols。如果無法滿足指定的壓縮率門檻,張量將直接進行 allreduce 而不壓縮。

一旦 PowerSGD 壓縮開始,每 compression_stats_logging_frequency 次迭代就會記錄壓縮統計信息。

  1. orthogonalization_epsilon 是一個非常小的值(例如 1e-8),在正交化步驟中添加到每個歸一化矩陣列,以防止在任何列全為 0 時出現除以零的錯誤。如果已經可以防止這種情況(例如,通過批次歸一化),建議使用 0 的 epsilon 以確保準確性。

  2. batch_tensors_with_same_shape 控制是否在批處理操作中壓縮和解壓縮具有相同形狀的張量以實現更高的平行度。請注意,您還應該增加桶大小(即 DDP 構造函數中的 bucket_cap_mb 參數),以使更多形狀相同的張量出現在同一個桶中,但是這可能會減少計算和通信之間的重疊,並由於堆疊相同形狀的張量而增加記憶體佔用。如果壓縮/解壓縮計算是一個瓶頸,則設置為 True

警告

如果啟用了錯誤反饋或預熱,則 DDP 中允許的 start_powerSGD_iter 最小值為 2。這是因為 DDP 中還有一個內部優化,可以在迭代 1 時重建桶,這可能會與重建過程之前記憶的任何張量衝突。

PowerSGD 鉤子

警告

PowerSGD 通常需要與模型梯度大小相同的額外記憶體來啟用錯誤反饋,這可以補償有偏差的壓縮通信並提高準確性。

警告

PowerSGD 鉤子可能會與 Apex 自動混合精度套件 衝突。請改用 PyTorch 原生自動混合精度套件

torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.powerSGD_hook(state, bucket)[原始碼]

實作 PowerSGD 演算法。

此 DDP 通信鉤子實作了 論文 中描述的 PowerSGD 梯度壓縮演算法。一旦跨所有工作節點聚合梯度張量,此鉤子將按如下方式應用壓縮

  1. 將輸入的一維扁平化梯度張量視為每個參數張量的列表,並將所有張量分為兩組

    1.1. 在 allreduce 之前應該壓縮的張量,因為壓縮可以節省足夠的頻寬。

    1.2. 其餘的張量將直接進行 allreduce 而不壓縮,包括所有向量張量(用於偏差)。

  2. 處理未壓縮的張量

    2.1. 為這些未壓縮的張量分配連續的記憶體,並將所有未壓縮的張量作為一個批次進行 allreduce,而不進行壓縮;

    2.2. 將單個未壓縮的張量從連續記憶體複製回輸入張量。

  3. 處理應該通過 PowerSGD 壓縮壓縮的張量

    3.1. 對於每個張量 M,創建兩個低秩張量 P 和 Q 來分解 M,使得 M = PQ^T,其中 Q 從標準正態分布初始化並正交化;

    3.2. 計算 Ps 中的每個 P,它等於 MQ;

    3.3. 將 Ps 作為一個批次進行 allreduce;

    3.4. 正交化 Ps 中的每個 P;

    3.5. 計算 Qs 中的每個 Q,它大約等於 M^TP;

    3.6. 將 Qs 作為一個批次進行 allreduce;

    3.7. 計算所有壓縮張量中的每個 M,它大約等於 PQ^T。

請注意,此通信鉤子在前 state.start_powerSGD_iter 次迭代中強制執行 vanilla allreduce。這不僅使用戶可以更好地控制加速和準確性之間的權衡,還可以為未來的通信鉤子開發人員抽象出 DDP 內部優化的一些複雜性。

參數
  • **state** (PowerSGDState) – 配置壓縮率並支持錯誤反饋、熱啟動等的狀態信息。要調整壓縮配置,主要需要調整 matrix_approximation_rankstart_powerSGD_itermin_compression_rate

  • **bucket** (dist.GradBucket) – 存儲一維扁平化梯度張量的桶,該張量批處理多個每個變量的張量。請注意,由於 DDP comm 鉤子僅支持單進程單設備模式,因此此桶中僅存儲一個張量。

返回

通信的未來處理程序,它會就地更新梯度。

返回類型

Future[Tensor]

範例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1,
                          start_powerSGD_iter=10, min_compression_rate=0.5)
>>> ddp_model.register_comm_hook(state, powerSGD_hook)
torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.batched_powerSGD_hook(state, bucket)[原始碼]

實作簡化的 PowerSGD 演算法。

此 DDP 通信鉤子實作了 論文 中描述的簡化 PowerSGD 梯度壓縮演算法。此變體不逐層壓縮梯度,而是壓縮批處理所有梯度的扁平化輸入張量。因此,它比 powerSGD_hook() **更快**,但通常會導致 **準確性低得多**,除非 matrix_approximation_rank 為 1。

警告

增加此處的 matrix_approximation_rank 可能不一定会提高準確性,因為批處理每個參數張量而不进行列/行对齐可能會破壞低秩結構。因此,用戶應始終首先考慮 powerSGD_hook(),並且只有在 matrix_approximation_rank 為 1 時才能達到令人滿意的準確性時才考慮此變體。

一旦跨所有工作節點聚合梯度張量,此鉤子將按如下方式應用壓縮

  1. 將輸入的一維扁平化梯度張量視為具有 0 填充的方形張量 M;

  2. 創建兩個低秩張量 P 和 Q 來分解 M,使得 M = PQ^T,其中 Q 從標準正態分布初始化並正交化;

  3. 計算 P,它等於 MQ;

  4. 對 P 進行 Allreduce;

  5. 對 P 進行正交化;

  6. 計算 Q,它大約等於 M^TP;

  7. 對 Q 進行 Allreduce;

  8. 計算 M,它大約等於 PQ^T。

  9. 將輸入張量截斷為原始長度。

請注意,此通信鉤子在前 state.start_powerSGD_iter 次迭代中強制執行 vanilla allreduce。這不僅使用戶可以更好地控制加速和準確性之間的權衡,還可以為未來的通信鉤子開發人員抽象出 DDP 內部優化的一些複雜性。

參數
  • **state** (PowerSGDState) – 配置壓縮率並支持錯誤反饋、熱啟動等的狀態信息。要調整壓縮配置,主要需要調整 matrix_approximation_rankstart_powerSGD_iter

  • **bucket** (dist.GradBucket) – 存儲一維扁平化梯度張量的桶,該張量批處理多個每個變量的張量。請注意,由於 DDP comm 鉤子僅支持單進程單設備模式,因此此桶中僅存儲一個張量。

返回

通信的未來處理程序,它會就地更新梯度。

返回類型

Future[Tensor]

範例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
>>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)

調試通信鉤子

顧名思義,調試通信鉤子**僅**用於調試和性能優化目的。

警告

調試通信鉤子不一定輸出正確的結果。

torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks.noop_hook(_, bucket)[原始碼]

回傳一個包裝輸入的 future,使其成為不產生任何通訊開銷的無效操作。

此掛鉤只能用於 allreduce 優化的空間分析,而不是用於正常的梯度同步。例如,如果在註冊此掛鉤後,訓練時間的加速不到 10%,則通常意味著 allreduce 不是此情況下的效能瓶頸。如果 GPU 追蹤難以取得或追蹤分析因 allreduce 與計算之間的重疊或各個等級之間的去同步化等因素而變得複雜,則此類檢測會特別有用。

範例:
>>> ddp_model.register_comm_hook(None, noop_hook)
返回類型

Future[Tensor]

通訊掛鉤的檢查點

狀態通訊掛鉤可以作為模型檢查點的一部分儲存,以啟用訓練器重新啟動。為了使掛鉤可序列化,應定義 __setstate____getstate__

警告

__getstate__ 應從回傳的字典中排除不可序列化的屬性。

警告

__setstate__ 應正確初始化從提供的 state 中排除的不可序列化屬性。

PowerSGDState 已實作 __setstate____getstate__,並且可以用作參考。

類別 torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.PowerSGDState[原始碼]
__getstate__()[原始碼]

回傳將被醃製和儲存的 Dict[str, Any]

process_group 不可序列化,並且從回傳的狀態中排除。

__setstate__(state)[原始碼]

取得提供的 state 並設定為此 PowerSGDState 執行個體。

process_group 設定為預設值。

以下是一個簡單的端到端範例,說明如何儲存和重新載入 PowerSGD 狀態和掛鉤。

import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(24,24)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(24,12)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def run_demo(demo_fn, world_size):
    mp.spawn(
        demo_fn,
        args=(world_size,),
        nprocs=world_size,
        join=True)

def demo_serialization(rank, world_size):
    setup(rank, world_size)

    CHECKPOINT = tempfile.gettempdir() + "/checkpoint.pt"

    model = SimpleModel().to(rank)
    ddp_model = DistributedDataParallel(model, device_ids=[rank])

    powersgd_hook = powerSGD.powerSGD_hook
    powersgd_state = powerSGD.PowerSGDState(process_group=None)

    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

    state = {
        'state_dict': ddp_model.state_dict(),
        'comm_hook': powersgd_hook,
        'comm_hook_state': powersgd_state}

    if rank == 0:
        torch.save(state, CHECKPOINT)

    dist.barrier()
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    checkpoint = torch.load(CHECKPOINT, map_location=map_location)

    new_ddp_model = DistributedDataParallel(SimpleModel().to(rank), device_ids=[rank])
    new_ddp_model.load_state_dict(checkpoint['state_dict'])
    powersgd_hook = checkpoint['comm_hook']
    powersgd_state = checkpoint['comm_hook_state']

    new_ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

    if rank == 0:
        os.remove(CHECKPOINT)

    cleanup()

if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = n_gpus
    run_demo(demo_serialization, world_size)

致謝

非常感謝 PowerSGD 論文作者Thijs Vogels 對 PowerSGD 通訊掛鉤的程式碼審查,以及比較實驗,這些實驗顯示 PowerSGD 通訊掛鉤的效能與原始論文中的實作不相上下。

文件

存取 PyTorch 的完整開發人員文件

查看文件

教學課程

取得適用於初學者和進階開發人員的深入教學課程

查看教學課程

資源

尋找開發資源並取得問題解答

查看資源