使用 autograd.Function 擴充 torch.func¶
因此,您想要將 torch.autograd.Function 與 torch.func 轉換(例如 torch.vmap()、torch.func.grad() 等)一起使用。
有兩個主要的使用案例
- 您希望呼叫不包含 PyTorch 運算的程式碼,並使其與函式轉換一起使用。也就是說, - torch.autograd.Function的 forward/backward/etc 會呼叫來自其他系統(例如 C++、CUDA、numpy)的函式。
- 您希望指定自訂梯度規則,例如 JAX 的 custom_vjp/custom_jvp 
PyTorch 將這兩個概念結合到 torch.autograd.Function 中。
基本用法¶
本指南假設您熟悉 擴充 torch.autograd,其中說明了如何使用 torch.autograd.Function。
torch.autograd.Function 可以有一個接受 ctx 物件的 forward(),也可以有單獨的 forward()(不接受 ctx)和一個修改 ctx 物件的 setup_context() 靜態方法。
函式轉換僅支援後者
- forward()是執行運算的程式碼,它不應該接受- ctx物件。
- setup_context(ctx, inputs, output)是您可以呼叫- ctx方法的程式碼。在這裡,您應該儲存張量以供反向使用(透過呼叫- ctx.save_for_backward(*tensors)),或儲存非張量(透過將其指派給- ctx物件)。
因為 setup_context() 只接受 inputs 和 output,所以唯一可以儲存的數量是輸入或輸出中的物件(例如張量)或從中衍生的數量(例如 Tensor.shape)。如果您希望從 Function.forward() 儲存非輸入的中間激活以供反向使用,則需要將其作為 forward() 的輸出傳回,以便將其傳遞給 setup_context()。
根據轉換,
- 若要支援反向模式 AD( - torch.func.grad()、- torch.func.vjp()),則- torch.autograd.Function需要- backward()靜態方法。
- 為了支援 - torch.vmap(),- torch.autograd.Function需要一個- vmap()靜態方法。
- 為了支援 - torch.func.jvp(),- torch.autograd.Function需要一個- jvp()靜態方法。
- 為了支援變換的組合(例如 - torch.func.jacrev()、- torch.func.jacfwd()、- torch.func.hessian()),您可能需要以上多個方法。
為了使 torch.autograd.Function 可以與函數變換任意組合,我們建議除了 forward() 和 setup_context() 之外的所有其他靜態方法都必須是可變換的:也就是說,它們必須僅包含 PyTorch 運算符或呼叫其他 torch.autograd.Function(可以呼叫 C++/CUDA/etc)。
讓我們來看一些常見用例的範例。
範例 1:autograd.Function 呼叫另一個系統¶
一個常見的情況是 torch.autograd.Function 的 forward() 和 backward() 都呼叫另一個系統(例如 C++、CUDA、numpy、triton)。
import torch
import numpy as np
def to_numpy(tensor):
    return tensor.cpu().numpy()
class NumpySort(torch.autograd.Function):
    # Note that forward does not take ctx
    @staticmethod
    def forward(x, dim):
        device = x.device
        x = to_numpy(x)
        ind = np.argsort(x, axis=dim)
        ind_inv = np.argsort(ind, axis=dim)
        result = np.take_along_axis(x, ind, axis=dim)
        # Any intermediates to be saved in backward must be returned as
        # outputs.
        return (
            # The desired output
            torch.tensor(result, device=device),
            # intermediate to save for backward
            torch.tensor(ind, device=device),
            # intermediate to save for backward
            torch.tensor(ind_inv, device=device),
        )
    # setup_context is responsible for calling methods and/or assigning to
    # the ctx object. Please do not do additional compute (e.g. add
    # Tensors together) in setup_context.
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, dim = inputs
        # Note that output is whatever you returned from forward.
        # If you returned multiple values, then output is a Tuple of multiple values.
        # If you returned a single Tensor, then output is a Tensor.
        # If you returned a Tuple with a single Tensor, then output is a
        # Tuple with a single Tensor.
        _, ind, ind_inv = output
        ctx.mark_non_differentiable(ind, ind_inv)
        # Tensors must be saved via ctx.save_for_backward. Please do not
        # assign them directly onto the ctx object.
        ctx.save_for_backward(ind, ind_inv)
        # Non-tensors may be saved by assigning them as attributes on the ctx object.
        ctx.dim = dim
    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        # For the autograd.Function to be arbitrarily composable with function
        # transforms, all staticmethod other than forward and setup_context
        # must be implemented in a "transformable" way; that is, they must
        # only consist of PyTorch operations or autograd.Function.
        #
        # For example, this allows us to do double backwards and/or compute
        # second order gradients.
        #
        # We've written the backward pass of NumpySort in terms of another
        # autograd.Function, NumpyTake.
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
class NumpyTake(torch.autograd.Function):
    @staticmethod
    def forward(x, ind, ind_inv, dim):
        device = x.device
        x = to_numpy(x)
        ind = to_numpy(ind)
        return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim
    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None
現在,為了更容易使用 NumpySort(隱藏我們作為輸出返回的中間結果,並允許預設參數和關鍵字參數),我們創建一個呼叫它的新函數
def numpy_sort(x, dim=-1):
    result, _, _ = NumpySort.apply(x, dim)
    return result
這是一個健全性檢查
x = torch.randn(2, 3)
grad_x = torch.func.grad(lambda x: numpy_sort(x).sum())(x)
assert torch.allclose(grad_x, torch.ones_like(x))
範例 2:autograd.Function 指定自訂梯度規則¶
另一個常見的情況是用 PyTorch 運算實現的 torch.autograd.Function。PyTorch 可以自動計算 PyTorch 運算的梯度,但我們可能希望自訂梯度的計算方式。我們可能想要一個與 PyTorch 提供的自訂反向傳播不同的原因有一些
- 提高數值穩定性 
- 改變反向傳播的效能特性 
- 改變邊緣情況的處理方式(例如 nan、inf) 
- 修改梯度(例如梯度裁剪) 
這是一個函數 y = x ** 3 的 torch.autograd.Function 範例,我們在其中更改了效能特性(一些通常在反向傳播期間發生的計算,計算 dx,發生在正向傳播中)。
class MyCube(torch.autograd.Function):
    @staticmethod
    def forward(x):
        result = x ** 3
        # In regular PyTorch, if we had just run y = x ** 3, then the backward
        # pass computes dx = 3 * x ** 2. In this autograd.Function, we've done
        # that computation here in the forward pass instead.
        dx = 3 * x ** 2
        return result, dx
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)
    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        # In order for the autograd.Function to work with higher-order
        # gradients, we must add the gradient contribution of `dx`.
        result = grad_output * dx + grad_dx * 6 * x
        return result
現在,為了更容易使用 NumpySort(並隱藏我們作為輸出返回的中間結果),我們創建一個呼叫它的新函數
def my_cube(x):
    result, _ = MyCube.apply(x)
    return result
這是一個計算二階梯度的健全性檢查
x = torch.randn([])
ggx = torch.func.grad(torch.func.grad(my_cube))(x)
assert torch.allclose(ggx, 6 * x)
限制和陷阱¶
警告
請仔細閱讀這些使用 torch.func 變換的 torch.autograd.Function 的限制。我們無法捕捉到許多此類情況並正常地拋出錯誤,因此它們將導致未定義的行為。
請不要將正在變換、具有 requires_grad=True 或雙重張量的張量捕捉到 torch.autograd.Function 的方法中。完全安全的做法是確保在 torch.autograd.Function 的任何方法中使用的張量都必須直接作為輸入傳遞(或通過 ctx 物件傳遞),而不是來自 torch.autograd.Function 之外。
torch.autograd.Function 不處理 pytrees(可能包含或不包含張量的任意嵌套 Python 數據結構)中的張量。為了讓 autograd 跟踪這些張量,必須將它們作為參數直接傳遞給 torch.autograd.Function。這與接受 pytrees 的 jax.{custom_vjp、custom_jvp} 形成對比。
請僅使用 save_for_backward() 或 save_for_forward() 來保存張量。請不要將張量或張量集合直接分配給 ctx 物件 - 這些張量將不會被跟踪
torch.vmap() 支援¶
要在 torch.vmap() 中使用 torch.autograd.Function,您必須
- 提供一個 - vmap()靜態方法,告訴我們- torch.autograd.Function在- torch.vmap()下的行為
- 通過設置 - generate_vmap_rule=True要求我們自動生成它。
自動生成 vmap 規則¶
如果您的 torch.autograd.Function 滿足以下附加約束,那麼我們就可以為它生成一個 vmap 規則。如果它不滿足約束,或者您希望在 vmap 下有自訂行為,請手動定義一個 vmap 靜態方法(請參閱下一節)。
警告
我們不容易檢查以下約束並正常地拋出錯誤。違反約束可能會導致未定義的行為。
- torch.autograd.Function的- forward()、- backward()(如果存在)和- jvp()(如果存在)靜態方法必須可以通過- torch.vmap()進行變換。也就是說,它們必須僅包含 PyTorch 運算(而不是例如 NumPy 或自訂 CUDA 內核)。
範例
class MyCube(torch.autograd.Function):
    # Set generate_vmap_rule to True to ask PyTorch to automatically generate
    # a vmap rule.
    generate_vmap_rule = True
    @staticmethod
    def forward(x):
        result = x ** 3
        dx = 3 * x ** 2
        return result, dx
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)
    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        result = grad_output * dx + grad_dx * 6 * x
        return result
def my_cube(x):
    result, dx = MyCube.apply(x)
    return result
x = torch.randn(3)
result = torch.vmap(my_cube)(x)
assert torch.allclose(result, x ** 3)
定義 vmap 靜態方法¶
如果您的 torch.autograd.Function 呼叫另一個系統(例如 NumPy、C++、CUDA、triton),那麼要使其與 torch.vmap() 或使用它的變換一起使用,您需要手動定義一個 vmap() 靜態方法。
根據您要使用的變換和您的用例,您可能不需要將 vmap() 靜態方法添加到您的所有 torch.autograd.Function 中
- 例如, - torch.func.jacrev()在反向傳播過程中執行- vmap()。因此,如果您只對使用- torch.func.jacrev()感興趣,則只需將- backward()靜態方法設置為可 vmap。
我們建議確保您所有的 torch.autograd.Function 都支援 torch.vmap(),特別是如果您正在撰寫第三方函式庫,並且希望您的 torch.autograd.Function 可以與所有 torch.func() 轉換的組合一起使用。
從概念上來說,vmap 靜態方法負責定義 forward() 在 torch.vmap() 下的行為。也就是說,它定義瞭如何轉換 forward() 以便在具有額外維度(正在進行 vmap 的維度)的輸入上執行。這與 torch.vmap() 在 PyTorch 操作上的實作方式類似:對於每個操作,我們都定義了一個 vmap 規則(有時也稱為「批次處理規則」)。
以下是定義 vmap() 靜態方法的方法
- 簽章為 - vmap(info, in_dims: Tuple[Optional[int]], *args),其中- *args與- forward()的參數相同。
- vmap 靜態方法負責定義 - forward()在- torch.vmap()下的行為。也就是說,給定具有額外維度(由- in_dims指定)的輸入,我們如何計算- forward()的批次處理版本?
- 對於 - args中的每個參數,- in_dims都有一個對應的- Optional[int]。如果參數不是張量,或者參數沒有被 vmap 處理,則為- None,否則,它是一個整數,指定張量的哪個維度正在被 vmap 處理。
- info是一個可能會有幫助的額外中繼資料集合:- info.batch_size指定正在進行 vmap 的維度的大小,而- info.randomness是傳遞給- torch.vmap()的- randomness選項。
- vmap 靜態方法的回傳值是一個 - (output, out_dims)元組。與- in_dims類似,- out_dims的結構應與- output相同,並且每個輸出都包含一個- out_dim,指定輸出是否具有 vmap 維度以及其索引。
範例
def to_numpy(tensor):
    return tensor.cpu().numpy()
class NumpySort(torch.autograd.Function):
    @staticmethod
    def forward(x, dim):
        device = x.device
        x = to_numpy(x)
        ind = np.argsort(x, axis=dim)
        ind_inv = np.argsort(ind, axis=dim)
        result = np.take_along_axis(x, ind, axis=dim)
        return (
            torch.tensor(result, device=device),
            torch.tensor(ind, device=device),
            torch.tensor(ind_inv, device=device),
        )
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, dim = inputs
        _, ind, ind_inv = output
        ctx.mark_non_differentiable(ind, ind_inv)
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim
    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
    # The signature of the vmap staticmethod is:
    # vmap(info, in_dims: Tuple[Optional[int]], *args)
    # where *args is the same as the arguments to `forward`.
    @staticmethod
    def vmap(info, in_dims, x, dim):
        # For every input (x and dim), in_dims stores an Optional[int]
        # that is:
        # - None if the input is not being vmapped over or if the input
        #   is not a Tensor
        # - an integer if the input is being vmapped over that represents
        #   the index of the dimension being vmapped over.
        x_bdim, _ = in_dims
        # A "vmap rule" is the logic of how to perform the operation given
        # inputs with one additional dimension. In NumpySort, x has an
        # additional dimension (x_bdim). The vmap rule is simply
        # to call NumpySort again but pass it a different `dim`.
        x = x.movedim(x_bdim, 0)
        # Handle negative dims correctly
        dim = dim if dim >= 0 else dim + x.dim() - 1
        result = NumpySort.apply(x, dim + 1)
        # The vmap rule must return a tuple of two things
        # 1. the output. Should be the same amount of things
        #    as returned by the forward().
        # 2. one Optional[int] for each output specifying if each output
        # is being vmapped over, and if so, the index of the
        # dimension being vmapped over.
        #
        # NumpySort.forward returns a Tuple of 3 Tensors. Since we moved the
        # dimension being vmapped over to the front of `x`, that appears at
        # dimension 0 of all outputs.
        # The return is (output, out_dims) -- output is a tuple of 3 Tensors
        # and out_dims is a Tuple of 3 Optional[int]
        return NumpySort.apply(x, dim + 1), (0, 0, 0)
class NumpyTake(torch.autograd.Function):
    @staticmethod
    def forward(x, ind, ind_inv, dim):
        device = x.device
        x = to_numpy(x)
        ind = to_numpy(ind)
        return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim
    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None
    @staticmethod
    def vmap(info, in_dims, x, ind, ind_inv, dim):
        x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims
        # The strategy is: expand {x, ind, ind_inv} to all have the dimension
        # being vmapped over.
        # Then, call back into NumpyTake(expanded_x, expanded_ind, expanded_ind_inv, new_dim).
        # Handle negative dims by wrapping them to be positive
        logical_dim = x.dim() if x_bdim is None else x_bdim - 1
        dim = dim if dim >= 0 else dim + logical_dim
        def maybe_expand_bdim_at_front(x, x_bdim):
            if x_bdim is None:
                return x.expand(info.batch_size, *x.shape)
            return x.movedim(x_bdim, 0)
        # If the Tensor doesn't have the dimension being vmapped over,
        # expand it out. Otherwise, move it to the front of the Tensor
        x = maybe_expand_bdim_at_front(x, x_bdim)
        ind = maybe_expand_bdim_at_front(ind, ind_bdim)
        ind_inv = maybe_expand_bdim_at_front(ind_inv, ind_inv_bdim)
        # The return is a tuple (output, out_dims). Since output is a Tensor,
        # then out_dims is an Optional[int] (instead of being a Tuple).
        return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0
def numpy_sort(x, dim=-1):
    result, _, _ = NumpySort.apply(x, dim)
    return result
x = torch.randn(2, 3)
result = torch.vmap(numpy_sort)(x)
assert torch.allclose(result, numpy_sort(result, 1))
注意
vmap 靜態方法應旨在保留整個 Function 的語義。也就是說,(偽代碼)grad(vmap(MyFunc)) 應該可以用 grad(map(MyFunc)) 代替。
如果您的 autograd.Function 在反向傳播中有任何自訂行為,請牢記這一點。
注意
為 PyTorch 能夠透過 generate_vmap_rule=True 為其產生 vmap 規則的 Function 撰寫自訂 vmap 靜態方法是一個合理的用例。如果您希望產生的 vmap 規則沒有您想要的語義,則可以這樣做。
torch.func.jvp() 支援¶
為了支援前向模式自動微分,torch.autograd.Function 必須有一個 jvp() 靜態方法。詳情請參閱 前向模式自動微分。