• 文件 >
  • 使用 autograd.Function 擴充套件 torch.func
快捷方式

使用 autograd.Function 擴充套件 torch.func

您可能希望將 torch.autograd.Functiontorch.func 變換(如 torch.vmap()torch.func.grad() 等)一起使用。

主要有兩個用例

  • 您希望呼叫不包含 PyTorch 運算元的程式碼並使其與函式變換一起工作。也就是說,torch.autograd.Function 的 forward/backward/等方法呼叫其他系統(如 C++、CUDA、numpy)中的函式。

  • 您希望指定自定義梯度規則,例如 JAX 的 custom_vjp/custom_jvp

PyTorch 將這兩個概念結合到 torch.autograd.Function 中。

基本用法

本指南假定您熟悉 擴充套件 torch.autograd,其中解釋瞭如何使用 torch.autograd.Function

torch.autograd.Function 可以有一個接受 ctx 物件的 forward() 方法,或者有一個不接受 ctx 的獨立的 forward() 方法和一個修改 ctx 物件的 setup_context() 靜態方法。

函式變換隻支援後者

  • forward() 是執行操作的程式碼,它不應該接受 ctx 物件。

  • setup_context(ctx, inputs, output) 是您可以在 ctx 物件上呼叫方法的地方。您應該在這裡儲存用於反向傳播的張量(透過呼叫 ctx.save_for_backward(*tensors)),或者儲存非張量物件(透過將它們賦值給 ctx 物件)。

因為 setup_context() 只接受 inputsoutput,所以可以儲存的數量只能是 inputs 或 outputs 中的物件(如張量),或是從它們派生的數量(如 Tensor.shape)。如果您希望儲存 Function.forward() 中非輸入的中間啟用用於反向傳播,則需要將其作為 forward() 的輸出返回,以便傳遞給 setup_context()

根據不同的變換,

為了使 torch.autograd.Function 能夠與函式變換任意組合,我們建議除了 forward()setup_context() 之外的所有其他靜態方法都必須是可變換的:也就是說,它們必須僅由 PyTorch 運算元組成或呼叫其他 torch.autograd.Function(這些 Function 可能呼叫 C++/CUDA/等)。

讓我們來看一些常見用例的示例。

示例 1:autograd.Function 呼叫其他系統

一個常見情況是,torch.autograd.Functionforward()backward() 都呼叫其他系統(如 C++、CUDA、numpy、triton)中的函式。

import torch
import numpy as np

def to_numpy(tensor):
    return tensor.cpu().numpy()

class NumpySort(torch.autograd.Function):
    # Note that forward does not take ctx
    @staticmethod
    def forward(x, dim):
        device = x.device
        x = to_numpy(x)
        ind = np.argsort(x, axis=dim)
        ind_inv = np.argsort(ind, axis=dim)
        result = np.take_along_axis(x, ind, axis=dim)
        # Any intermediates to be saved in backward must be returned as
        # outputs.
        return (
            # The desired output
            torch.tensor(result, device=device),
            # intermediate to save for backward
            torch.tensor(ind, device=device),
            # intermediate to save for backward
            torch.tensor(ind_inv, device=device),
        )

    # setup_context is responsible for calling methods and/or assigning to
    # the ctx object. Please do not do additional compute (e.g. add
    # Tensors together) in setup_context.
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, dim = inputs
        # Note that output is whatever you returned from forward.
        # If you returned multiple values, then output is a Tuple of multiple values.
        # If you returned a single Tensor, then output is a Tensor.
        # If you returned a Tuple with a single Tensor, then output is a
        # Tuple with a single Tensor.
        _, ind, ind_inv = output
        ctx.mark_non_differentiable(ind, ind_inv)
        # Tensors must be saved via ctx.save_for_backward. Please do not
        # assign them directly onto the ctx object.
        ctx.save_for_backward(ind, ind_inv)
        # Non-tensors may be saved by assigning them as attributes on the ctx object.
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        # For the autograd.Function to be arbitrarily composable with function
        # transforms, all staticmethod other than forward and setup_context
        # must be implemented in a "transformable" way; that is, they must
        # only consist of PyTorch operations or autograd.Function.
        #
        # For example, this allows us to do double backwards and/or compute
        # second order gradients.
        #
        # We've written the backward pass of NumpySort in terms of another
        # autograd.Function, NumpyTake.
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None

class NumpyTake(torch.autograd.Function):
    @staticmethod
    def forward(x, ind, ind_inv, dim):
        device = x.device
        x = to_numpy(x)
        ind = to_numpy(ind)
        return torch.tensor(np.take_along_axis(x, ind, dim), device=device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None

現在,為了更方便地使用 NumpySort(以隱藏我們作為輸出返回的中間結果,並允許使用預設引數和關鍵字引數),我們建立一個新函式來呼叫它。

def numpy_sort(x, dim=-1):
    result, _, _ = NumpySort.apply(x, dim)
    return result

以下是一個健全性檢查

x = torch.randn(2, 3)
grad_x = torch.func.grad(lambda x: numpy_sort(x).sum())(x)
assert torch.allclose(grad_x, torch.ones_like(x))

示例 2:autograd.Function 指定自定義梯度規則

另一種常見情況是,torch.autograd.Function 是用 PyTorch 運算元實現的。PyTorch 能夠自動計算 PyTorch 運算元的梯度,但也許我們希望自定義梯度計算方式。我們可能希望自定義與 PyTorch 提供的不同反向傳播的一些原因包括:

  • 提高數值穩定性

  • 改變反向傳播的效能特性

  • 改變邊緣情況的處理方式(例如 NaN、Inf)

  • 修改梯度(例如梯度裁剪)

以下是一個函式 y = x ** 3torch.autograd.Function 示例,其中我們改變了效能特性(一些通常發生在反向傳播中計算 dx 的計算,被放在了前向傳播中)。

class MyCube(torch.autograd.Function):
    @staticmethod
    def forward(x):
        result = x ** 3
        # In regular PyTorch, if we had just run y = x ** 3, then the backward
        # pass computes dx = 3 * x ** 2. In this autograd.Function, we've done
        # that computation here in the forward pass instead.
        dx = 3 * x ** 2
        return result, dx

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        # In order for the autograd.Function to work with higher-order
        # gradients, we must add the gradient contribution of `dx`.
        result = grad_output * dx + grad_dx * 6 * x
        return result

現在,為了更方便地使用 NumpySort(並隱藏我們作為輸出返回的中間結果),我們建立一個新函式來呼叫它。

def my_cube(x):
    result, _ = MyCube.apply(x)
    return result

以下是計算二階梯度的健全性檢查。

x = torch.randn([])
ggx = torch.func.grad(torch.func.grad(my_cube))(x)
assert torch.allclose(ggx, 6 * x)

限制與陷阱

警告

請仔細閱讀 torch.autograd.Function 與 torch.func 變換結合使用時的這些限制。我們無法捕獲許多此類情況並優雅地報錯,因此它們可能導致未定義行為。

請不要將正在進行變換、設定了 requires_grad=True 或屬於 dual tensors 的張量捕獲到 torch.autograd.Function 的方法中。完全安全的方法是確保在 torch.autograd.Function 的任何方法中使用的張量必須直接作為輸入(或透過 ctx 物件)傳遞,而不是來自 torch.autograd.Function 外部。

torch.autograd.Function 不處理 pytrees(任意巢狀的 Python 資料結構,可能包含或不包含張量)中的張量。為了讓這些張量被 autograd 跟蹤,它們必須直接作為引數傳遞給 torch.autograd.Function。這與 jax.{custom_vjp, custom_jvp} 不同,後者接受 pytrees。

請僅使用 save_for_backward()save_for_forward() 來儲存張量。請不要直接將張量或張量集合賦值給 ctx 物件 - 這些張量不會被跟蹤。

torch.vmap() 支援

要將 torch.autograd.Functiontorch.vmap() 一起使用,您必須選擇以下方法之一:

自動生成 vmap 規則

如果您的 torch.autograd.Function 滿足以下附加約束,則我們可以為其生成 vmap 規則。如果它不滿足約束或者您希望在 vmap 下有自定義行為,請手動定義一個 vmap 靜態方法(參見下一節)。

警告

我們無法輕鬆檢查以下約束並優雅地報錯。違反約束可能導致未定義行為。

示例

class MyCube(torch.autograd.Function):
    # Set generate_vmap_rule to True to ask PyTorch to automatically generate
    # a vmap rule.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        result = x ** 3
        dx = 3 * x ** 2
        return result, dx

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        result = grad_output * dx + grad_dx * 6 * x
        return result

def my_cube(x):
    result, dx = MyCube.apply(x)
    return result

x = torch.randn(3)
result = torch.vmap(my_cube)(x)
assert torch.allclose(result, x ** 3)

定義 vmap 靜態方法

如果您的 torch.autograd.Function 呼叫其他系統(如 NumPy、C++、CUDA、triton)中的函式,那麼為了使其與 torch.vmap() 或使用它的變換一起工作,您需要手動定義一個 vmap() 靜態方法。

根據您想要使用的變換和您的用例,您可能不需要為所有的 torch.autograd.Function 新增 vmap() 靜態方法。

不過,我們建議確保您所有的 torch.autograd.Function 都支援 torch.vmap(),特別是如果您正在編寫第三方庫,並且希望您的 torch.autograd.Function 能與 torch.func() 的所有變換組合一起工作。

從概念上講,vmap 靜態方法負責定義 forward()torch.vmap() 下的行為方式。也就是說,它定義瞭如何變換 forward(),使其在具有額外維度(即進行 vmap 的維度)的輸入上執行。這類似於 torch.vmap() 在 PyTorch 運算元上的實現方式:對於每個運算元,我們定義一個 vmap 規則(有時也稱為“批處理規則”)。

以下是如何定義 vmap() 靜態方法:

  • 簽名為 vmap(info, in_dims: Tuple[Optional[int]], *args),其中 *argsforward() 的引數相同。

  • vmap 靜態方法負責定義 forward()torch.vmap() 下的行為方式。也就是說,給定具有額外維度(由 in_dims 指定)的輸入,我們如何計算 forward() 的批處理版本?

  • 對於 args 中的每個引數,in_dims 都有一個對應的 Optional[int]。如果引數不是張量或未進行 vmap 變換,則該值為 None;否則,它是一個整數,指定了張量正在進行 vmap 變換的維度。

  • info 是一個包含可能有用附加元資料的集合:info.batch_size 指定進行 vmap 變換的維度大小,而 info.randomness 是傳遞給 torch.vmap()randomness 選項。

  • vmap 靜態方法的返回是一個包含 (output, out_dims) 的元組。與 in_dims 類似,out_dims 應與 output 具有相同的結構,並且每個輸出包含一個 out_dim,用於指定輸出是否包含 vmap 維度以及該維度的索引位置。

示例

def to_numpy(tensor):
    return tensor.cpu().numpy()

class NumpySort(torch.autograd.Function):
    @staticmethod
    def forward(x, dim):
        device = x.device
        x = to_numpy(x)
        ind = np.argsort(x, axis=dim)
        ind_inv = np.argsort(ind, axis=dim)
        result = np.take_along_axis(x, ind, axis=dim)
        return (
            torch.tensor(result, device=device),
            torch.tensor(ind, device=device),
            torch.tensor(ind_inv, device=device),
        )

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, dim = inputs
        _, ind, ind_inv = output
        ctx.mark_non_differentiable(ind, ind_inv)
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None

    # The signature of the vmap staticmethod is:
    # vmap(info, in_dims: Tuple[Optional[int]], *args)
    # where *args is the same as the arguments to `forward`.
    @staticmethod
    def vmap(info, in_dims, x, dim):
        # For every input (x and dim), in_dims stores an Optional[int]
        # that is:
        # - None if the input is not being vmapped over or if the input
        #   is not a Tensor
        # - an integer if the input is being vmapped over that represents
        #   the index of the dimension being vmapped over.
        x_bdim, _ = in_dims

        # A "vmap rule" is the logic of how to perform the operation given
        # inputs with one additional dimension. In NumpySort, x has an
        # additional dimension (x_bdim). The vmap rule is simply
        # to call NumpySort again but pass it a different `dim`.
        x = x.movedim(x_bdim, 0)
        # Handle negative dims correctly
        dim = dim if dim >= 0 else dim + x.dim() - 1
        result = NumpySort.apply(x, dim + 1)

        # The vmap rule must return a tuple of two things
        # 1. the output. Should be the same amount of things
        #    as returned by the forward().
        # 2. one Optional[int] for each output specifying if each output
        # is being vmapped over, and if so, the index of the
        # dimension being vmapped over.
        #
        # NumpySort.forward returns a Tuple of 3 Tensors. Since we moved the
        # dimension being vmapped over to the front of `x`, that appears at
        # dimension 0 of all outputs.
        # The return is (output, out_dims) -- output is a tuple of 3 Tensors
        # and out_dims is a Tuple of 3 Optional[int]
        return NumpySort.apply(x, dim + 1), (0, 0, 0)

class NumpyTake(torch.autograd.Function):
    @staticmethod
    def forward(x, ind, ind_inv, dim):
        device = x.device
        x = to_numpy(x)
        ind = to_numpy(ind)
        return torch.tensor(np.take_along_axis(x, ind, dim), device=device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None

    @staticmethod
    def vmap(info, in_dims, x, ind, ind_inv, dim):
        x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims

        # The strategy is: expand {x, ind, ind_inv} to all have the dimension
        # being vmapped over.
        # Then, call back into NumpyTake(expanded_x, expanded_ind, expanded_ind_inv, new_dim).

        # Handle negative dims by wrapping them to be positive
        logical_dim = x.dim() if x_bdim is None else x_bdim - 1
        dim = dim if dim >= 0 else dim + logical_dim

        def maybe_expand_bdim_at_front(x, x_bdim):
            if x_bdim is None:
                return x.expand(info.batch_size, *x.shape)
            return x.movedim(x_bdim, 0)

        # If the Tensor doesn't have the dimension being vmapped over,
        # expand it out. Otherwise, move it to the front of the Tensor
        x = maybe_expand_bdim_at_front(x, x_bdim)
        ind = maybe_expand_bdim_at_front(ind, ind_bdim)
        ind_inv = maybe_expand_bdim_at_front(ind_inv, ind_inv_bdim)

        # The return is a tuple (output, out_dims). Since output is a Tensor,
        # then out_dims is an Optional[int] (instead of being a Tuple).
        return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0

def numpy_sort(x, dim=-1):
    result, _, _ = NumpySort.apply(x, dim)
    return result

x = torch.randn(2, 3)
result = torch.vmap(numpy_sort)(x)
assert torch.allclose(result, numpy_sort(result, 1))

注意

vmap 靜態方法應旨在保留整個 Function 的語義。也就是說,(虛擬碼) grad(vmap(MyFunc)) 應該可以替換為 grad(map(MyFunc))

如果您的 autograd.Function 在反向傳播中包含任何自定義行為,請記住這一點。

注意

為一個 Function 編寫自定義 vmap 靜態方法是合法的用例,即使 PyTorch 可以透過 generate_vmap_rule=True 為其生成 vmap 規則。如果生成的 vmap 規則不具備您所需的語義,您可能希望這樣做。

torch.func.jvp() 支援

為了支援前向模式自動微分 (AD),torch.autograd.Function 必須包含一個 jvp() 靜態方法。詳細資訊請參閱前向模式 AD

文件

訪問 PyTorch 完整的開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲取問題解答

檢視資源