快捷方式

支援 TorchScript 的分散式最佳化器¶

創建於:2021 年 4 月 26 日 | 最後更新於:2024 年 12 月 02 日 | 最後驗證於:2024 年 11 月 05 日

警告

TorchScript 不再積極開發。

在本 Recipe 中,你將學習

  • 支援 TorchScript 的分散式最佳化器的高階概念以及此功能帶來的益處

  • 如何編寫支援 TorchScript 的自定義分散式最佳化器

要求¶

什麼是分散式最佳化器?¶

DistributedOptimizer 接受一個遠端引數列表 (RRef),並在引數所在的 worker 上本地執行最佳化器,這通常與分散式 RPC/Autograd 一起用於模型並行訓練。它可以使用任何本地最佳化器演算法(無論是 torch.optim 中提供的預定義演算法還是自定義演算法)來對每個 worker 上的梯度應用更新。

什麼是支援 TorchScript 的分散式最佳化器?¶

分散式最佳化器廣泛應用於分散式模型並行訓練,在一些常見用例中,出於效能考慮和資源利用率,訓練需要以多執行緒方式而不是多程序方式進行(或者至少是部分多執行緒,例如引數伺服器託管模型和引數的一部分,每個請求由新執行緒更新引數)。PyTorch 本身不原生支援多執行緒訓練,因為它受到 Python 全域性直譯器鎖 (GIL) 的影響,但它可以利用 TorchScript 來擺脫 GIL 並以多執行緒方式執行模型。

對於關鍵模型訓練工作負載,提高訓練效能是一個重要課題。研究人員通常希望透過圖表示(即透過運算元融合)實現不同的最佳化策略,或實現自定義運算元核以加速訓練。

支援 TorchScript 的分散式最佳化器可以幫助擺脫 GIL,從而提高 PyTorch 在多執行緒環境中的訓練效能,它還釋放了利用 TorchScript 提供的先進編譯器技術(即 CPU/GPU 融合)進一步提升效能的潛力。

如何編寫支援 TorchScript 的自定義分散式最佳化器?¶

下面的程式碼展示瞭如何在現有本地最佳化器實現的基礎上編寫自定義分散式最佳化器,從而解鎖 TorchScript 的優勢,包括消除 GIL 和提高效能的機會。

假設你已經有一個在訓練期間當前使用的本地最佳化器,在此示例中,我們將使用擬雙曲動量 (QHM) 作為例子來展示如何啟用 TorchScript 支援,請注意,這也適用於任何繼承自 torch.optim.Optimizer 的自定義最佳化器。

首先,我們需要將計算和狀態管理與最佳化器實現分開,這樣我們就可以提取計算部分並使其成為一個自由函式,這對於 TorchScript 是友好的。這有兩個好處:1. 計算邏輯更容易檢查,它允許我們快速將引數更新/計算部分轉換為 TorchScript,並利用 TorchScript IR 進行進一步最佳化(運算元融合等)。2. 分散式最佳化器底層使用不同的機制來獲取梯度和更新引數(我們單獨儲存梯度,而不是在反向傳播期間直接填充 param.grad 欄位)。將計算分離出來使得分散式最佳化器能夠在多執行緒模式下進行最佳化器更新,因為它消除了對 param.grad 可能存在的競爭條件。

import torch
from torch import Tensor
from typing import List


def qhm_update(params: List[Tensor],
            dp_list: List[Tensor],
            momentum_buffer_list: List[Tensor],
            lr: float,
            nu: float,
            weight_decay: float,
            weight_decay_type: str,
            momentum: float):

    for p, d_p, momentum_buffer in zip(params, dp_list, momentum_buffer_list):
        if weight_decay != 0:
            if weight_decay_type == "grad":
                d_p.add_(weight_decay, p)
            elif weight_decay_type == "direct":
                p.mul_(1.0 - lr * weight_decay)
            else:
                raise ValueError("Invalid weight decay type provided")

        momentum_buffer.mul_(momentum).add_(1.0 - momentum, d_p)

        p.data.add_(-lr * nu, momentum_buffer)
        p.data.add_(-lr * (1.0 - nu), d_p)

接下來,我們將定義一個具有 TorchScript 相容性的分散式函式式最佳化器,用於管理最佳化器狀態並呼叫我們上面定義的 TorchScript 相容更新函式。請注意,一些約定與普通自定義最佳化器不同:1. 我們不繼承 torch.optim.Optimizer,因為 TorchScript 不支援多型。2. step 接受梯度列表而不是損失閉包。

import torch
from torch import Tensor
from typing import List, Optional, Dict

# define this as a TorchScript class
@torch.jit.script
class FunctionalQHM(object):
    def __init__(self,
                params: List[Tensor],
                lr: float,
                momentum: float,
                nu: float,
                weight_decay: float = 0.0,
                weight_decay_type: str = "grad"):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if weight_decay_type not in ("grad", "direct"):
            raise ValueError("Invalid weight_decay_type value: {}".format(weight_decay_type))

        self.defaults = {
            "lr": lr,
            "momentum": momentum,
            "nu": nu,
            "weight_decay": weight_decay,
        }
        self.weight_decay_type = weight_decay_type

        # NOTE: we only have one param_group here and don't allow user to add additional
        # param group as it's not a common use case.
        self.param_group = {"params": params}

        self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})

    def step(self, gradients: List[Optional[Tensor]]):
        params = self.param_group['params']
        params_with_grad = []
        grads = []
        momentum_buffer_list: List[Tensor] = []

        if len(params) != len(gradients):
            raise ValueError(
                "the gradients passed in does not equal to the size of the parameters!"
                + f"Params length: {len(params)}. "
                + f"Gradients length: {len(gradients)}"
            )

        for param, gradient in zip(self.param_group['params'], gradients):
            if gradient is not None:
                params_with_grad.append(param)
                grads.append(gradient)
                state = self.state[param]
                state['momentum_buffer'] = torch.zeros_like(param, memory_format=torch.preserve_format)
                momentum_buffer_list.append(state['momentum_buffer'])

        # calls into the update function we just defined
        with torch.no_grad():
            qhm_update(params_with_grad,
                    grads,
                    momentum_buffer_list,
                    self.defaults['lr'],
                    self.defaults['nu'],
                    self.defaults['weight_decay'],
                    self.weight_decay_type,
                    self.defaults['momentum'])

最後,我們將新定義的分散式函式式最佳化器註冊到 functional_optim_map 中。這樣,DistributedOptimizer 將嘗試使用我們的自定義實現,而不是預定義的預設實現。

from torch.distributed.optim import DistributedOptimizer

DistributedOptimizer.functional_optim_map[QHM] = FunctionalQHM

現在,你可以在分散式訓練中像往常一樣使用 QHM 最佳化器,將其傳遞給 DistributedOptimizer

...
remote_params_list = [...]
dist_optim = DistributedOptimizer(
    QHM, remote_params_list, *args, **kwargs
)

DistributedOptimizer 會在底層自動將 QHM 最佳化器轉換為 FunctionalQHM,並啟用 TorchScript 支援。這將釋放由多執行緒訓練帶來的效能提升,也為進一步改進(例如 TorchScript 融合等)提供了更多潛力。

請注意,大多數 PyTorch 內建最佳化器已經使用這種方法來加速分散式訓練。如果你看到關於某些最佳化器尚未轉換的警告,可以按照本 recipe 編寫自己的轉換。

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源