• 文件 >
  • DDP Communication Hooks
快捷方式

DDP 通訊 Hook

DDP 通訊 hook 是一個通用介面,用於透過重寫 DistributedDataParallel 中的 vanilla allreduce 來控制跨 worker 通訊梯度的方式。提供了一些內建的通訊 hook,使用者可以輕鬆應用其中任何一個來最佳化通訊。此外,該 hook 介面還支援使用者自定義通訊策略,以用於更高階的用例。

如何使用通訊 Hook?

要使用通訊 hook,使用者只需讓 DDP 模型在訓練迴圈開始前註冊該 hook,如下所示。

torch.nn.parallel.DistributedDataParallel.register_comm_hook()

通訊 Hook 對什麼進行操作?

通訊 hook 提供了一種靈活的方式來 allreduce 梯度。因此,它主要在 allreduce 之前對每個副本上的梯度進行操作,這些梯度被分桶(bucketized)以增加通訊和計算之間的重疊。特別地,torch.distributed.GradBucket 代表一個包含待 allreduce 梯度張量的桶。

class torch.distributed.GradBucket

此類主要將展平的梯度張量(由 buffer() 返回)傳遞給 DDP 通訊 hook。此張量可以進一步分解為此桶內每個引數的張量列表(由 get_per_parameter_tensors() 返回),以應用層級操作。

torch.distributed.GradBucket.index(self: torch._C._distributed_c10d.GradBucket) int

警告

由於桶在第一次迭代後會重建,因此不應依賴訓練開始時的索引。

返回值

儲存幾個連續層梯度的桶的索引。所有梯度均已分桶。

torch.distributed.GradBucket.buffer(self: torch._C._distributed_c10d.GradBucket) torch.Tensor
返回值

一個展平的 1D torch.Tensor 緩衝區,可以進一步分解為此桶內每個引數的張量列表。

torch.distributed.GradBucket.gradients(self: torch._C._distributed_c10d.GradBucket) list[torch.Tensor]
返回值

一個 torch.Tensor 列表。列表中的每個張量對應一個梯度。

torch.distributed.GradBucket.is_last(self: torch._C._distributed_c10d.GradBucket) bool
返回值

此桶是否是迭代中最後一個進行 allreduce 的桶。這也意味著此桶對應於前向傳播中的前幾層。

torch.distributed.GradBucket.set_buffer(self: torch._C._distributed_c10d.GradBucket, buffer: torch.Tensor) None

用輸入的張量緩衝區替換桶中的張量。

torch.distributed.GradBucket.parameters(self: torch._C._distributed_c10d.GradBucket) list[torch.Tensor]
返回值

一個 torch.Tensor 列表。列表中的每個張量對應一個模型引數。

預設通訊 Hook

預設通訊 hook 是簡單的 無狀態 hook,因此 register_comm_hook 中的輸入 state 要麼是一個 process group,要麼是 None。輸入 bucket 是一個 torch.distributed.GradBucket 物件。

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.allreduce_hook(process_group, bucket)[source][source]

使用 GradBucket 張量呼叫 allreduce

一旦梯度張量在所有 worker 上聚合,其 then 回撥將計算均值並返回結果。

如果使用者註冊此 DDP 通訊 hook,則 DDP 結果預計與未註冊 hook 的情況相同。因此,這不會改變 DDP 的行為,使用者可以將其用作參考或修改此 hook 來記錄有用資訊或用於任何其他目的,同時不影響 DDP 行為。

示例:
>>> ddp_model.register_comm_hook(process_group, allreduce_hook)
返回值型別

Future[Tensor]

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook(process_group, bucket)[source][source]

透過將 GradBucket 強制轉換為 torch.float16 併除以 process group 大小來壓縮。

此 DDP 通訊 hook 實現了一種簡單的梯度壓縮方法,將 GradBucket 張量強制轉換為半精度浮點格式(torch.float16),然後除以 process group 大小。它 allreduce 這些 float16 梯度張量。一旦壓縮的梯度張量被 allreduce,鏈式回撥 decompress 將其強制轉換回輸入資料型別(例如 float32)。

示例:
>>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
返回值型別

Future[Tensor]

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_hook(process_group, bucket)[source][source]

警告:此 API 處於實驗階段,需要 NCCL 版本高於 2.9.6。

此 DDP 通訊 hook 實現了一種簡單的梯度壓縮方法,將 GradBucket 張量強制轉換為半精度 Brain 浮點格式torch.bfloat16),然後除以 process group 大小。它 allreduce 這些 bfloat16 梯度張量。一旦壓縮的梯度張量被 allreduce,鏈式回撥 decompress 將其強制轉換回輸入資料型別(例如 float32)。

示例:
>>> ddp_model.register_comm_hook(process_group, bf16_compress_hook)
返回值型別

Future[Tensor]

此外,還提供了一個通訊 hook 包裝器,以支援將 fp16_compress_hook()bf16_compress_hook() 用作包裝器,可以與其他通訊 hook 結合使用。

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_wrapper(hook)[source][source]

將輸入張量強制轉換為 torch.float16,將 hook 的結果強制轉換回輸入資料型別。

此包裝器將給定 DDP 通訊 hook 的輸入梯度張量強制轉換為半精度浮點格式(torch.float16),並將給定 hook 的結果張量強制轉換回輸入資料型別,例如 float32。因此,fp16_compress_hook 等效於 fp16_compress_wrapper(allreduce_hook)

示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
>>> ddp_model.register_comm_hook(state, fp16_compress_wrapper(powerSGD_hook))
返回值型別

Callable[[Any, GradBucket], Future[Tensor]]

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_wrapper(hook)[source][source]

警告:此 API 處於實驗階段,需要 NCCL 版本高於 2.9.6。

此包裝器將給定 DDP 通訊 hook 的輸入梯度張量強制轉換為半精度 `Brain 浮點格式 `_(torch.bfloat16),並將給定 hook 的結果張量強制轉換回輸入資料型別,例如 float32

因此,bf16_compress_hook 等效於 bf16_compress_wrapper(allreduce_hook)

示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
>>> ddp_model.register_comm_hook(state, bf16_compress_wrapper(powerSGD_hook))
返回值型別

Callable[[Any, GradBucket], Future[Tensor]]

PowerSGD 通訊 Hook

PowerSGD (Vogels 等,NeurIPS 2019) 是一種梯度壓縮演算法,可以提供非常高的壓縮率並加速頻寬受限的分散式訓練。該演算法需要同時維護一些超引數和內部狀態。因此,PowerSGD 通訊 hook 是一個 有狀態 的 hook,使用者需要提供如下定義的狀態物件。

PowerSGD 狀態

class torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.PowerSGDState(process_group, matrix_approximation_rank=1, start_powerSGD_iter=1000, min_compression_rate=2, use_error_feedback=True, warm_start=True, orthogonalization_epsilon=0, random_seed=0, compression_stats_logging_frequency=10000, batch_tensors_with_same_shape=False)[source][source]

在訓練期間儲存演算法的超引數和所有梯度的內部狀態。

特別地,matrix_approximation_rankstart_powerSGD_iter 是使用者應調整的主要超引數。為了效能,我們建議保持二進位制超引數 use_error_feedbackwarm_start 為 True。

  1. matrix_approximation_rank 控制壓縮低秩張量的大小,這決定了壓縮率。秩越低,壓縮越強。

    1.1. 如果 matrix_approximation_rank 過低,模型質量將需要更多訓練步驟才能達到,或者永遠無法達到,並導致精度損失。

    1.2. 增加 matrix_approximation_rank 會顯著增加壓縮的計算成本,並且精度可能不會超過某個 matrix_approximation_rank 閾值進一步提高。

為了調優 matrix_approximation_rank,我們建議從 1 開始,並以 2 的倍數增加(例如指數網格搜尋,1、2、4…),直到達到令人滿意的精度。通常只使用一個較小的值,如 1-4。對於一些 NLP 任務(如原始論文附錄 D 所示),該值已增加到 32。

  1. start_powerSGD_iter 會將 PowerSGD 壓縮推遲到步驟 start_powerSGD_iter 之後,而 vanilla allreduce 會在步驟 start_powerSGD_iter 之前執行。這種 vanilla allreduce + PowerSGD 的混合方案可以有效提高精度,即使使用相對較小的 matrix_approximation_rank。這是因為訓練階段的開始通常對不準確的梯度非常敏感,過早壓縮梯度可能會使訓練迅速走向次優軌跡,從而對精度產生不可逆轉的影響。

為了調優 start_powerSGD_iter,我們建議從總訓練步驟的 10% 開始,並增加它直到達到令人滿意的精度。如果在訓練中有熱身(warm-up)階段,start_powerSGD_iter 通常應不小於熱身步驟的數量。

  1. min_compression_rate 是壓縮層時所需的最小壓縮率。由於壓縮帶來的計算開銷,只有當頻寬能節省足夠多時才值得壓縮張量,即 (num_rows + num_cols) * matrix_approximation_rank * min_compression_rate < num_rows * num_cols。如果指定的壓縮率閾值無法滿足,張量將直接進行 allreduce 而不進行壓縮。

PowerSGD 壓縮開始後,每 compression_stats_logging_frequency 次迭代記錄一次壓縮統計資訊。

  1. orthogonalization_epsilon 是在正交化步驟中新增到每個歸一化矩陣列的非常小的值(例如 1e-8),以防止任一列全為 0 時出現除以零錯誤。如果這已可避免(例如透過批次歸一化),建議使用 0 的 epsilon 以獲得精度。

  2. batch_tensors_with_same_shape 控制是否在批次操作中壓縮和解壓縮形狀相同的張量,以實現更高的並行性。請注意,您還應該增加桶的大小(即 DDP 建構函式中的 bucket_cap_mb 引數),以使更多形狀相同的張量出現在同一個桶中,但這可能會減少計算和通訊之間的重疊,並因堆疊相同形狀的張量而增加記憶體佔用。如果壓縮/解壓縮計算是瓶頸,請設定為 True

警告

如果啟用錯誤反饋或熱身,DDP 中允許的 start_powerSGD_iter 的最小值為 2。這是因為 DDP 中存在另一個內部最佳化,它在迭代 1 時重建桶,這可能會與重建過程之前記憶的任何張量衝突。

PowerSGD Hook

警告

PowerSGD 通常需要與模型梯度大小相同的額外記憶體來啟用錯誤反饋,這可以補償有偏的壓縮通訊並提高精度。

警告

PowerSGD 鉤子可能與 Apex 自動混合精度包 衝突。請改用 PyTorch 原生自動混合精度包

torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.powerSGD_hook(state, bucket)[source][source]

實現 PowerSGD 演算法。

這個 DDP 通訊鉤子實現了 論文中描述的 PowerSGD 梯度壓縮演算法。一旦梯度張量在所有工作程序中聚合完畢,這個鉤子會按如下方式應用壓縮:

  1. 將輸入的展平一維梯度張量視為一個包含每個引數張量的列表,並將所有張量分為兩組:

    1.1 在 allreduce 之前應該被壓縮的張量,因為壓縮可以顯著節省頻寬。

    1.2 剩餘的張量將不經壓縮直接進行 allreduce,包括所有向量張量(用於偏置)。

  2. 處理未壓縮的張量:

    2.1. 為這些未壓縮張量分配連續記憶體,並將所有未壓縮張量作為一個批次進行 allreduce,不進行壓縮;

    2.2. 將單個未壓縮張量從未壓縮張量的連續記憶體複製回輸入張量。

  3. 處理應該透過 PowerSGD 壓縮的張量:

    3.1. 對於每個張量 M,建立兩個低秩張量 P 和 Q 來分解 M,使得 M = PQ^T,其中 Q 從標準正態分佈初始化並正交化;

    3.2. 計算 Ps 中的每個 P,其等於 MQ;

    3.3. 將 Ps 作為一個批次進行 allreduce;

    3.4. 對 Ps 中的每個 P 進行正交化;

    3.5. 計算 Qs 中的每個 Q,其約等於 M^TP;

    3.6. 將 Qs 作為一個批次進行 allreduce;

    3.7. 計算所有被壓縮張量中的每個 M,其約等於 PQ^T。

請注意,這個通訊鉤子在前 state.start_powerSGD_iter 次迭代中強制執行原版 allreduce。這不僅讓使用者可以更好地控制速度提升和精度之間的權衡,還有助於未來的通訊鉤子開發者抽象 DDP 內部最佳化的一些複雜性。

引數
  • state (PowerSGDState) – 用於配置壓縮率並支援誤差反饋、熱啟動等的狀態資訊。要調整壓縮配置,主要需要調整 matrix_approximation_rankstart_powerSGD_itermin_compression_rate

  • bucket (dist.GradBucket) – 儲存展平的一維梯度張量的桶,該張量批處理了多個按變數劃分的張量。請注意,由於 DDP 通訊鉤子僅支援單程序單裝置模式,因此此桶中僅儲存一個張量。

返回值

通訊的 Future handler,它會就地更新梯度。

返回值型別

Future[Tensor]

示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1,
                          start_powerSGD_iter=10, min_compression_rate=0.5)
>>> ddp_model.register_comm_hook(state, powerSGD_hook)
torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.batched_powerSGD_hook(state, bucket)[source][source]

實現簡化的 PowerSGD 演算法。

這個 DDP 通訊鉤子實現了 論文中描述的簡化的 PowerSGD 梯度壓縮演算法。這個變體不按層壓縮梯度,而是壓縮批處理所有梯度的展平輸入張量。因此,它比 powerSGD_hook() **更快**,但通常會導致**精度大大降低**,除非 matrix_approximation_rank 為 1。

警告

在這裡增加 matrix_approximation_rank 可能不一定能提高精度,因為在沒有列/行對齊的情況下批處理每個引數張量可能會破壞低秩結構。因此,使用者應始終首先考慮 powerSGD_hook(),僅當 matrix_approximation_rank 為 1 時可以達到令人滿意的精度時,才考慮使用此變體。

一旦梯度張量在所有工作程序中聚合完畢,這個鉤子會按如下方式應用壓縮:

  1. 將輸入的展平一維梯度張量視為一個帶 0 填充的方形張量 M;

  2. 建立兩個低秩張量 P 和 Q 來分解 M,使得 M = PQ^T,其中 Q 從標準正態分佈初始化並正交化;

  3. 計算 P,其等於 MQ;

  4. 對 P 進行 allreduce;

  5. 對 P 進行正交化;

  6. 計算 Q,其約等於 M^TP;

  7. 對 Q 進行 allreduce;

  8. 計算 M,其約等於 PQ^T。

  9. 將輸入張量截斷到原始長度。

請注意,這個通訊鉤子在前 state.start_powerSGD_iter 次迭代中強制執行原版 allreduce。這不僅讓使用者可以更好地控制速度提升和精度之間的權衡,還有助於未來的通訊鉤子開發者抽象 DDP 內部最佳化的一些複雜性。

引數
  • state (PowerSGDState) – 用於配置壓縮率並支援誤差反饋、熱啟動等的狀態資訊。要調整壓縮配置,主要需要調整 matrix_approximation_rankstart_powerSGD_iter

  • bucket (dist.GradBucket) – 儲存展平的一維梯度張量的桶,該張量批處理了多個按變數劃分的張量。請注意,由於 DDP 通訊鉤子僅支援單程序單裝置模式,因此此桶中僅儲存一個張量。

返回值

通訊的 Future handler,它會就地更新梯度。

返回值型別

Future[Tensor]

示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
>>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)

除錯通訊鉤子

顧名思義,除錯通訊鉤子**僅**用於除錯和效能最佳化目的。

警告

除錯通訊鉤子不一定會輸出正確的結果。

torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks.noop_hook(_, bucket)[source][source]

返回一個包裝輸入的 Future,因此它是一個不產生任何通訊開銷的空操作(no-op)。

這個鉤子**僅**應用於 allreduce 最佳化上限的分析,而不是正常的梯度同步。例如,如果在註冊此鉤子後訓練時間僅提速不到 10%,通常意味著在這種情況下 allreduce 不是效能瓶頸。如果 GPU 軌跡難以獲取,或者軌跡分析因 allreduce 與計算的重疊或跨 rank 的不同步等因素而變得複雜,這種檢測會特別有用。

示例:
>>> ddp_model.register_comm_hook(None, noop_hook)
返回值型別

Future[Tensor]

通訊鉤子的檢查點儲存

有狀態的通訊鉤子可以作為模型檢查點的一部分儲存,以支援訓練器重啟。要使鉤子可序列化,應該定義 __setstate____getstate__

警告

__getstate__ 應該從返回的字典中排除不可序列化的屬性。

警告

__setstate__ 應該正確初始化從提供的 state 中排除的不可序列化屬性。

PowerSGDState 實現了 __setstate____getstate__,可以作為參考。

class torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.PowerSGDState[source][source]
__getstate__()[source][source]

返回一個 Dict[str, Any],它將被 pickle 化並儲存。

process_group 不可序列化,因此從返回的狀態中排除。

__setstate__(state)[source][source]

接受一個提供的 state 並設定到此 PowerSGDState 例項。

process_group 被設定為預設值。

這裡是一個簡單、端到端儲存和重新載入 PowerSGD state 和 hook 的示例。

import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(24,24)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(24,12)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def run_demo(demo_fn, world_size):
    mp.spawn(
        demo_fn,
        args=(world_size,),
        nprocs=world_size,
        join=True)

def demo_serialization(rank, world_size):
    setup(rank, world_size)

    CHECKPOINT = tempfile.gettempdir() + "/checkpoint.pt"

    model = SimpleModel().to(rank)
    ddp_model = DistributedDataParallel(model, device_ids=[rank])

    powersgd_hook = powerSGD.powerSGD_hook
    powersgd_state = powerSGD.PowerSGDState(process_group=None)

    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

    state = {
        'state_dict': ddp_model.state_dict(),
        'comm_hook': powersgd_hook,
        'comm_hook_state': powersgd_state}

    if rank == 0:
        torch.save(state, CHECKPOINT)

    dist.barrier()
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    checkpoint = torch.load(CHECKPOINT, map_location=map_location)

    new_ddp_model = DistributedDataParallel(SimpleModel().to(rank), device_ids=[rank])
    new_ddp_model.load_state_dict(checkpoint['state_dict'])
    powersgd_hook = checkpoint['comm_hook']
    powersgd_state = checkpoint['comm_hook_state']

    new_ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

    if rank == 0:
        os.remove(CHECKPOINT)

    cleanup()

if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = n_gpus
    run_demo(demo_serialization, world_size)

致謝

非常感謝 PowerSGD 論文作者 **Thijs Vogels** 對 PowerSGD 通訊鉤子程式碼進行了審查,以及提供了對比實驗,這些實驗表明 PowerSGD 通訊鉤子的效能與原始論文中的實現相當。

文件

訪問 PyTorch 全面的開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源