擴充 PyTorch¶
在本說明中,我們將介紹擴充 torch.nn、torch.autograd、torch 的方法,以及撰寫自訂 C++ 擴充功能。
新增運算子¶
PyTorch 提供了大量處理張量的運算子(例如 torch.add()、torch.sum() 等)。但是,您可能希望將新的自訂運算帶入 PyTorch,並使其行為類似於 PyTorch 的內建運算子。為此,您必須透過 Python torch.library 或 C++ TORCH_LIBRARY API 向 PyTorch 註冊自訂運算。
如需更多詳細資訊,請參閱 PyTorch 自訂運算子登陸頁面。
擴充 torch.autograd¶
將運算新增至 autograd 需要為每個運算實作新的 Function 子類別。回想一下,函數是 autograd 用於編碼運算歷史記錄和計算梯度的工具。
本文件的第 1 部分重點介紹反向模式 AD,因為它是使用最廣泛的功能。最後一部分討論了正向模式 AD 的擴充功能。
何時使用¶
一般而言,如果您想要在模型中執行不可微分或依賴非 PyTorch 程式庫(例如 NumPy)的計算,但仍然希望您的運算與其他運算鏈接並與 autograd 引擎一起使用,則應實作自訂函數。
在某些情況下,自訂函數也可以用於提高效能和記憶體使用量:如果您使用 C++ 擴充功能 實作了正向和反向傳遞,則可以將它們包裝在 Function 中,以便與 autograd 引擎互動。如果您想減少為反向傳遞儲存的緩衝區數量,則可以使用自訂函數將運算組合在一起。
何時不使用¶
如果您已經可以使用 PyTorch 的內建運算來編寫函數,則其反向圖(很可能)已經可以由 autograd 記錄。在這種情況下,您不需要自己實作反向函數。請考慮使用普通的 Python 函數。
如果您需要維護狀態,即可訓練參數,則您應該(也)使用自訂模組。如需有關擴充 torch.nn 的更多資訊,請參閱以下部分。
使用方法¶
請執行以下步驟: 1. 建立 Function 的子類別,並實作 forward()、(可選)setup_context() 和 backward() 方法。 2. 對 ctx 引數呼叫適當的方法。 3. 宣告您的函數是否支援 雙重反向傳遞。 4. 使用 gradcheck 驗證您的梯度是否正確。
**步驟 1:**建立 Function 的子類別後,您需要定義 3 個方法
- forward()是執行運算的程式碼。它可以接受任意數量的引數,如果您指定預設值,則其中一些引數可以是可選的。這裡接受所有類型的 Python 物件。 追蹤歷史記錄的- Tensor引數(例如,使用- requires_grad=True)將在呼叫之前轉換為不追蹤歷史記錄的引數,並且它們的使用將被註冊在圖表中。 請注意,此邏輯不會遍歷列表/字典/任何其他資料結構,並且只會考慮作為呼叫直接引數的張量。 您可以返回單個- Tensor輸出,或者如果有多個輸出,則返回一個- tuple的張量。 同時,另請參閱- Function的文件,以查找僅能從- forward()呼叫的有用方法的說明。
- setup_context()(可選)。可以編寫一個「組合式」的- forward(),它接受一個- ctx物件,或者(從 PyTorch 2.0 開始)一個不接受- ctx的獨立- forward()和一個發生- ctx修改的- setup_context()方法。- forward()應該進行計算,而- setup_context()應該只負責- ctx的修改(並且不進行任何計算)。一般來說,獨立的- forward()和- setup_context()更接近於 PyTorch 原生操作的工作方式,因此可以更好地與各種 PyTorch 子系統組合。有關更多詳細資訊,請參閱 組合式或獨立的 forward() 和 setup_context()。
- backward()(或- vjp())定義了梯度公式。它將獲得與輸出數量相同的- Tensor引數,每個引數都代表相對於該輸出的梯度。重要的是永遠不要就地修改這些引數。它應該返回與輸入數量相同的張量,每個張量都包含相對於其對應輸入的梯度。如果您的輸入不需要梯度(- needs_input_grad是一個布林值元組,指示每個輸入是否需要梯度計算),或者是非- Tensor物件,則您可以返回- python:None。此外,如果您有- forward()的可選引數,您可以返回比輸入更多的梯度,只要它們都是- None。
**步驟 2:**您有責任正確使用 ctx 中的函數,以確保新的 Function 能與自動微分引擎正常運作。
- 必須使用 - save_for_backward()來儲存要在反向傳遞中使用的任何張量。非張量應直接儲存在 ctx 上。如果將既不是輸入也不是輸出的張量儲存起來以供反向使用,則您的- Function可能不支援雙重反向(請參閱步驟 3)。
- 必須使用 - mark_dirty()來標記任何由正向函數就地修改的輸入。
- 必須使用 - mark_non_differentiable()來告訴引擎輸出是否不可微分。預設情況下,所有可微分類型的輸出張量都將被設定為需要梯度。不可微分類型(即整數類型)的張量永遠不會被標記為需要梯度。
- 可以使用 - set_materialize_grads()來告訴自動微分引擎,在輸出不依賴於輸入的情況下,通過不實例化提供給反向函數的梯度張量來優化梯度計算。也就是說,如果設定為 False,Python 中的 None 物件或 C++ 中的「未定義張量」(張量 x,其 x.defined() 為 False)在呼叫反向之前不會轉換為填充了零的張量,因此您的程式碼需要處理此類物件,就好像它們是填充了零的張量一樣。此設定的預設值為 True。
**步驟 3:**如果您的 Function 不支援雙重反向,您應該使用 once_differentiable() 裝飾器來明確宣告這一點。使用此裝飾器時,嘗試通過您的函數執行雙重反向將會產生錯誤。有關雙重反向的更多資訊,請參閱我們的雙重反向教學。
**步驟 4:**建議您使用 torch.autograd.gradcheck() 來檢查您的反向函數是否正確計算了正向函數的梯度,方法是使用您的反向函數計算雅可比矩陣,並將其值與使用有限差分法數值計算的雅可比矩陣進行逐元素比較。
範例¶
您可以在下方找到一個 Linear 函數的程式碼,以及其他註釋
# Inherit from Function
class LinearFunction(Function):
    # Note that forward, setup_context, and backward are @staticmethods
    @staticmethod
    def forward(input, weight, bias):
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output
    @staticmethod
    # inputs is a Tuple of all of the inputs passed to forward.
    # output is the output of the forward().
    def setup_context(ctx, inputs, output):
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)
    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias
現在,為了更容易使用這些自定義運算,我們建議您為它們設定別名或將它們封裝在一個函數中。封裝在一個函數中可以讓我們支援預設引數和關鍵字引數
# Option 1: alias
linear = LinearFunction.apply
# Option 2: wrap in a function, to support default args and keyword args.
def linear(input, weight, bias=None):
    return LinearFunction.apply(input, weight, bias)
在這裡,我們給出一個額外的範例,說明一個由非張量引數參數化的函數
class MulConstant(Function):
    @staticmethod
    def forward(tensor, constant):
        return tensor * constant
    @staticmethod
    def setup_context(ctx, inputs, output):
        # ctx is a context object that can be used to stash information
        # for backward computation
        tensor, constant = inputs
        ctx.constant = constant
    @staticmethod
    def backward(ctx, grad_output):
        # We return as many input gradients as there were arguments.
        # Gradients of non-Tensor arguments to forward must be None.
        return grad_output * ctx.constant, None
在這裡,我們通過呼叫 set_materialize_grads(False) 來優化上述範例
class MulConstant(Function):
    @staticmethod
    def forward(tensor, constant):
        return tensor * constant
    @staticmethod
    def setup_context(ctx, inputs, output):
        tensor, constant = inputs
        ctx.set_materialize_grads(False)
        ctx.constant = constant
    @staticmethod
    def backward(ctx, grad_output):
        # Here we must handle None grad_output tensor. In this case we
        # can skip unnecessary computations and just return None.
        if grad_output is None:
            return None, None
        # We return as many input gradients as there were arguments.
        # Gradients of non-Tensor arguments to forward must be None.
        return grad_output * ctx.constant, None
如果您需要儲存在 forward() 中計算的任何「中間」張量,則必須將它們作為輸出返回,或者組合 forward 和 setup_context()(請參閱 組合式或獨立的 forward() 和 setup_context())。請注意,這意味著如果您希望梯度流經這些中間值,您需要為它們定義梯度公式(另請參閱 雙重反向教學)
class MyCube(torch.autograd.Function):
    @staticmethod
    def forward(x):
        # We wish to save dx for backward. In order to do so, it must
        # be returned as an output.
        dx = 3 * x ** 2
        result = x ** 3
        return result, dx
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)
    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        # In order for the autograd.Function to work with higher-order
        # gradients, we must add the gradient contribution of `dx`,
        # which is grad_dx * 6 * x.
        result = grad_output * dx + grad_dx * 6 * x
        return result
# Wrap MyCube in a function so that it is clearer what the output is
def my_cube(x):
    result, dx = MyCube.apply(x)
    return result
注意
backward 的輸入,即 grad_output,也可以是追蹤歷史記錄的張量。因此,如果使用可微分運算(例如,呼叫另一個自定義 Function)實作 backward,則高階導數將會起作用。在這種情況下,使用 save_for_backward 儲存的張量也可以在反向中使用,並且梯度會反向流動,但儲存在 ctx 中的張量不會有梯度反向流動。如果您需要梯度反向流動到儲存在 ctx 中的張量,您應該將其設為自定義 Function 的輸出,並使用 save_for_backward 儲存它。
您可能想檢查您實作的反向方法是否真的計算了函數的導數。這可以通過與使用小有限差分的數值近似值進行比較來實現
from torch.autograd import gradcheck
# gradcheck takes a tuple of tensors as input, check if your gradient
# evaluated with these tensors are close enough to numerical
# approximations and returns True if they all verify this condition.
input = (torch.randn(20,20,dtype=torch.double,requires_grad=True), torch.randn(30,20,dtype=torch.double,requires_grad=True))
test = gradcheck(linear, input, eps=1e-6, atol=1e-4)
print(test)
有關有限差分梯度比較的更多詳細資訊,請參閱 數值梯度檢查。如果您的函數用於高階導數(對反向傳遞求導),您可以使用相同套件中的 gradgradcheck 函數來檢查高階導數。
組合式或獨立的 forward() 和 setup_context()¶
定義 Function 主要有兩種方法。可以
我們建議第二個選項(獨立的 forward() 和 setup_context()),因為這更接近於 PyTorch 原生操作的實現方式,並且它可以與 torch.func 變換組合使用。然而,我們計劃在未來同時支援這兩種方法;結合 forward() 和 setup_context():可以帶來更大的靈活性,因為您可以在不將中間結果作為輸出返回的情況下儲存它們。
關於如何使用獨立的 forward() 和 setup_context() 定義 Function,請參閱上一節。
以下是如何使用組合的 forward() 和 setup_context() 定義 Function 的範例
class LinearFunction(Function):
    @staticmethod
    # ctx is the first argument to forward
    def forward(ctx, input, weight, bias=None):
        # The forward pass can use ctx.
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output
    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias
正向模式自動微分¶
覆寫正向模式自動微分公式的 API 非常相似,但有一些細微的差異。您可以實作 jvp() 函數。
它將接收與輸入數量相同的 Tensor 引數,每個引數代表相對於該輸入的梯度。它應該返回與輸出數量相同的張量,每個張量包含相對於其對應輸出的梯度。jvp() 將在 forward() 方法之後、apply() 返回之前被呼叫。
jvp() 與 backward() 函數有一些細微的差異
- 您可以使用 ctx 將任何資料從 - forward()傳遞到- jvp()函數。如果- backward()不需要該狀態,您可以在- jvp()函數的末尾使用- del ctx.foo明確地釋放它。
- jvp()的實作必須是可反向微分的,或者明確檢查沒有任何給定的正向模式梯度設定了- requires_grad。
- jvp()函數必須與- forward()的視圖/原地行為相符。例如,如果第- i個輸入被原地修改,則第- i個梯度必須被原地更新。類似地,如果第- j個輸出是第- k個輸入的視圖,則返回的第- j個輸出梯度必須是給定的第- k個輸入梯度的視圖。
- 因為使用者無法指定需要計算哪個梯度,所以 - jvp()函數應該始終計算所有輸出的梯度。
- 正向模式梯度確實會遵守 - set_materialize_grads()設定的旗標,並且當此旗標被禁用時,您可以獲得 None 輸入梯度。
torch.func 變換和/或 torch.vmap()¶
擴展 torch.nn¶
nn 匯出兩種介面 - 模組及其函數版本。您可以透過這兩種方式擴展它,但我們建議對所有持有任何參數或緩衝區的層使用模組,並建議對激活函數、池化等無參數操作使用函數形式。
在上一節中已經完整涵蓋了添加操作的函數版本。
添加 Module¶
由於 nn 大量利用了 autograd,因此添加新的 Module 需要實作一個執行操作並且可以計算梯度的 Function。從現在開始,假設我們想要實作一個 Linear 模組,並且我們已經像上面列出的那樣實作了該函數。添加這個只需要很少的程式碼。現在,需要實作兩個函數
以下是如何實作 Linear 模組
class Linear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features
        # nn.Parameter is a special kind of Tensor, that will get
        # automatically registered as Module's parameter once it's assigned
        # as an attribute. Parameters and buffers need to be registered, or
        # they won't appear in .parameters() (doesn't apply to buffers), and
        # won't be converted when e.g. .cuda() is called. You can use
        # .register_buffer() to register buffers.
        # nn.Parameters require gradients by default.
        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            # You should always register all possible parameters, but the
            # optional ones can be None if you want.
            self.register_parameter('bias', None)
        # Not a very smart way to initialize weights
        nn.init.uniform_(self.weight, -0.1, 0.1)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -0.1, 0.1)
    def forward(self, input):
        # See the autograd section for explanation of what happens here.
        return LinearFunction.apply(input, self.weight, self.bias)
    def extra_repr(self):
        # (Optional)Set the extra information about this module. You can test
        # it by printing an object of this class.
        return 'input_features={}, output_features={}, bias={}'.format(
            self.input_features, self.output_features, self.bias is not None
        )
擴展 torch Python API¶
您可以透過定義一個自訂類別,並使用與 Tensor 相符的方法來建立模擬 Tensor 的自訂類型。但是,如果您希望能夠將這些類型傳遞給頂級 torch 命名空間中接受 Tensor 運算元的函數,例如 torch.add(),該怎麼辦?
如果您的自訂 Python 類型定義了一個名為 __torch_function__ 的方法,當您的自訂類別的實例被傳遞給 torch 命名空間中的函數時,PyTorch 將呼叫您的 __torch_function__ 實作。這使得您可以為 torch 命名空間中的任何函數定義自訂實作,您的 __torch_function__ 實作可以呼叫這些函數,從而允許您的使用者將您的自訂類型與他們已經為 Tensor 編寫的現有 PyTorch 工作流程一起使用。這適用於與 Tensor 無關的“鴨子”類型,以及 Tensor 的使用者定義子類別。
使用類似 Tensor 的類型擴展 torch¶
為了具體說明,讓我們從一個簡單的例子開始,說明 API 調度機制。我們將創建一個自定義類型來表示二維純量張量,由階數 N 和對角線項的值 value 進行參數化。
class ScalarTensor(object):
   def __init__(self, N, value):
       self._N = N
       self._value = value
   def __repr__(self):
       return "ScalarTensor(N={}, value={})".format(self._N, self._value)
   def tensor(self):
       return self._value * torch.eye(self._N)
設計的第一個迭代版本並不是很有用。ScalarTensor 的主要功能是提供比基本張量類別更緊湊的純量張量字串表示法。
>>> d = ScalarTensor(5, 2)
>>> d
ScalarTensor(N=5, value=2)
>>> d.tensor()
tensor([[2., 0., 0., 0., 0.],
        [0., 2., 0., 0., 0.],
        [0., 0., 2., 0., 0.],
        [0., 0., 0., 2., 0.],
        [0., 0., 0., 0., 2.]])
如果我們嘗試將此物件與 torch API 一起使用,我們會遇到問題。
>>> import torch
>>> torch.mean(d)
TypeError: mean(): argument 'input' (position 1) must be Tensor, not ScalarTensor
將 __torch_function__ 實作新增至 ScalarTensor 可以讓上述操作成功。讓我們重新進行實作,這次新增 __torch_function__ 實作。
HANDLED_FUNCTIONS = {}
class ScalarTensor(object):
    def __init__(self, N, value):
        self._N = N
        self._value = value
    def __repr__(self):
        return "ScalarTensor(N={}, value={})".format(self._N, self._value)
    def tensor(self):
        return self._value * torch.eye(self._N)
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func not in HANDLED_FUNCTIONS or not all(
            issubclass(t, (torch.Tensor, ScalarTensor))
            for t in types
        ):
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)
__torch_function__ 方法採用四個參數:func,對正在被覆寫的 torch API 函數的引用,types,實作 __torch_function__ 的類張量類型的清單,args,傳遞給函數的參數元組,以及 kwargs,傳遞給函數的關鍵字參數字典。它使用名為 HANDLED_FUNCTIONS 的全域調度表來儲存自定義實作。此字典的鍵是 torch 命名空間中的函數,而值是 ScalarTensor 的實作。
注意
使用全域調度表並不是 __torch_function__ API 的強制性部分,它只是用於建構覆寫實作的有用設計模式。
這個類別定義還不足以讓 torch.mean 在我們傳遞 ScalarTensor 時做正確的事情——我們還需要為 ScalarTensor 運算元定義 torch.mean 的實作,並將實作新增至 HANDLED_FUNCTIONS 調度表字典。一種方法是定義一個裝飾器
import functools
def implements(torch_function):
    """Register a torch function override for ScalarTensor"""
    def decorator(func):
        functools.update_wrapper(func, torch_function)
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator
可以應用於我們的覆寫實作
@implements(torch.mean)
def mean(input):
    return float(input._value) / input._N
透過此變更,我們現在可以將 torch.mean 與 ScalarTensor 一起使用
>>> d = ScalarTensor(5, 2)
>>> torch.mean(d)
0.4
當然,torch.mean 是最簡單的函數覆寫範例,因為它只採用一個運算元。我們可以使用相同的機制來覆寫採用多個運算元的函數,其中任何一個運算元都可能是定義 __torch_function__ 的張量或類張量,例如 torch.add()
def ensure_tensor(data):
    if isinstance(data, ScalarTensor):
        return data.tensor()
    return torch.as_tensor(data)
@implements(torch.add)
def add(input, other):
   try:
       if input._N == other._N:
           return ScalarTensor(input._N, input._value + other._value)
       else:
           raise ValueError("Shape mismatch!")
   except AttributeError:
       return torch.add(ensure_tensor(input), ensure_tensor(other))
此版本在兩個運算元都是 ScalarTensor 執行個體時具有快速路徑,並且在任何一個運算元不是 ScalarTensor 時也具有降級為將資料轉換為張量的較慢路徑。這使得覆寫函數在任何一個運算元是 ScalarTensor 或一般 Tensor 時都能正確運作。
>>> s = ScalarTensor(2, 2)
>>> torch.add(s, s)
ScalarTensor(N=2, value=4)
>>> t = torch.tensor([[1, 1,], [1, 1]])
>>> torch.add(s, t)
tensor([[3., 1.],
        [1., 3.]])
請注意,我們的 add 實作不像 torch.add() 那樣將 alpha 或 out 作為關鍵字參數。
>>> torch.add(s, s, alpha=2)
TypeError: add() got an unexpected keyword argument 'alpha'
為了速度和靈活性,__torch_function__ 調度機制不會檢查覆寫函數的簽章是否與 torch API 中正在被覆寫的函數的簽章相符。對於某些應用程式,忽略可選參數是可以的,但為了確保與 Tensor 完全相容,torch API 函數的使用者實作應注意完全模擬正在被覆寫的函數的 API。
torch API 中沒有明確覆寫的函數將從 __torch_function__ 傳回 NotImplemented。如果所有在其上定義了 __torch_function__ 的運算元都傳回 NotImplemented,則 PyTorch 將引發 TypeError。這意味著在大多數情況下,當傳遞此類類型的執行個體時,沒有為類型明確覆寫的操作將引發 TypeError。
>>> torch.mul(s, 3)
TypeError: no implementation found for 'torch.mul' on types that
implement __torch_function__: [ScalarTensor]
在實務中,這意味著如果您想使用 __torch_function__ 實作來實作覆寫,您需要明確地為您的使用案例實作完整的 torch API 或您關心的整個 API 子集。這可能是一項艱鉅的任務,因為完整的 torch API 相當龐大。
另一種選擇是不為未處理的操作傳回 NotImplemented,而是在沒有可用覆寫時將 Tensor 傳遞給原始的 torch 函數。例如,如果我們將 ScalarTensor 的 __torch_function__ 實作更改為以下內容
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    if kwargs is None:
        kwargs = {}
    if func not in HANDLED_FUNCTIONS or not all(
            issubclass(t, (torch.Tensor, ScalarTensor))
            for t in types
        ):
        args = [a.tensor() if hasattr(a, 'tensor') else a for a in args]
        return func(*args, **kwargs)
    return HANDLED_FUNCTIONS[func](*args, **kwargs)
那麼 torch.mul() 將正常運作,儘管傳回類型將始終是 Tensor 而不是 ScalarTensor,即使兩個運算元都是 ScalarTensor 執行個體。
>>> s = ScalarTensor(2, 2)
>>> torch.mul(s, s)
tensor([[4., 0.],
        [0., 4.]])
另請參閱下面的 MetadataTensor 範例,了解此模式的另一個變體,但它始終傳回 MetadataTensor 以透過 torch API 中的操作傳播中繼資料。
__torch_function__ 協定旨在完全涵蓋 API,部分涵蓋可能會導致不良結果,特別是某些函數引發 TypeError。對於子類別尤其如此,其中 torch.add、torch.Tensor.__add__ 和 torch.Tensor.add 這三者都必須涵蓋,即使它們傳回完全相同的結果。不這樣做也可能導致無限遞迴。如果需要來自 torch.Tensor 子類別的函數實作,則必須在其實作內使用 super().__torch_function__。
將 torch.Tensor 子類別化¶
從 1.7.0 版開始,torch.Tensor 上的方法和應用於 torch.Tensor 子類別的公共 torch.* 命名空間中的函數將傳回子類別執行個體,而不是 torch.Tensor 執行個體。
>>> class SubTensor(torch.Tensor):
...     pass
>>> type(torch.add(SubTensor([0]), SubTensor([1]))).__name__
'SubTensor'
>>> type(torch.add(SubTensor([0]), torch.tensor([1]))).__name__
'SubTensor'
如果存在多個子類別,則預設會選擇階層結構中最低的子類別。如果沒有唯一的方法來確定這種情況,則會引發 TypeError。
>>> type(torch.add(SubTensor2([0]), SubTensor([1]))).__name__
'SubTensor2'
>>> type(torch.add(SubTensor2([0]), torch.tensor([1]))).__name__
'SubTensor2'
>>> torch.add(SubTensor([0]), OtherSubTensor([1]))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: no implementation found for 'torch.add' on types that implement __torch_function__: [SubTensor, OtherSubTensor]
如果希望對所有張量方法進行全域覆寫,則可以使用 __torch_function__。以下是一個記錄所有函數/方法呼叫的範例。
class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # NOTE: Logging calls Tensor.__repr__, so we can't log __repr__ without infinite recursion
        if func is not torch.Tensor.__repr__:
            logging.info(f"func: {func.__name__}, args: {args!r}, kwargs: {kwargs!r}")
        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs)
但是,如果希望覆寫 Tensor 子類別上的方法,則可以透過直接覆寫方法(透過為子類別定義它)或使用 __torch_function__ 並與 func 匹配來實現。
對於子類別,在 __torch_function__ 中應該小心始終呼叫 super().__torch_function__(func, ...) 而不是像 1.7.0 版之前那樣直接呼叫 func。不這樣做可能會導致 func 遞迴回 __torch_function__,從而導致無限遞迴。
使用 Tensor 包裝器類型擴展 torch¶
另一個有用的情況是包裝 Tensor 的類型,作為屬性或透過子類別化。下面我們實現了這種特殊情況的類型,一個 MetadataTensor,它將中繼資料字典附加到透過 torch 操作傳播的 Tensor。由於這是對整個 torch API 的通用包裝,我們不需要單獨實作每個覆寫,因此我們可以使 __torch_function__ 實作對允許的操作更寬鬆。
class MetadataTensor(object):
    def __init__(self, data, metadata=None, **kwargs):
        self._t = torch.as_tensor(data, **kwargs)
        self._metadata = metadata
    def __repr__(self):
        return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t)
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        metadatas = tuple(a._metadata for a in args if hasattr(a, '_metadata'))
        args = [getattr(a, '_t', a) for a in args]
        assert len(metadatas) > 0
        ret = func(*args, **kwargs)
        return MetadataTensor(ret, metadata=metadatas[0])
這個簡單的實作不一定適用於 torch API 中的每個函數,但它足以捕獲最常見的操作。
>>> metadata = {'owner': 'Ministry of Silly Walks'}
>>> m = MetadataTensor([[1, 2], [3, 4]], metadata=metadata)
>>> t = torch.tensor([[1, 2], [1, 2]])
>>> torch.add(t, m)
Metadata:
{'owner': 'Ministry of Silly Walks'}
data:
tensor([[2, 4],
        [4, 6]])
>>> torch.mul(t, m)
Metadata:
{'owner': 'Ministry of Silly Walks'}
data:
tensor([[1, 4],
        [3, 8]])
對定義了 __torch_function__ 的多種類型的操作¶
可以使用 torch API 搭配多種不同的類型,每個類型都具有 __torch_function__ 的實作,但必須格外小心。在這種情況下,規則如下
- 派送操作會收集每個運算元所有不同的 - __torch_function__實作,並依序呼叫它們:子類別先於父類別,否則在運算子運算式中從左到右。
- 如果傳回任何非 - NotImplemented的值,則該值將作為結果傳回。實作可以透過傳回- NotImplemented來註冊它們沒有實作某個操作。
- 如果所有 - __torch_function__實作都傳回- NotImplemented,PyTorch 會引發- TypeError。
測試 PyTorch API 覆寫的涵蓋範圍¶
實作 __torch_function__ 的一個麻煩之處在於,如果某些操作有覆寫而其他操作沒有,使用者頂多會看到不一致的體驗,或者最糟的情況是在執行時使用沒有覆寫的函式時看到引發錯誤。為了簡化此過程,PyTorch 提供了一個面向開發人員的 API,用於確保完全支援 __torch_function__ 覆寫。此 API 是私有的,未來可能會在沒有警告的情況下進行變更。
首先,若要取得所有可覆寫函式的清單,請使用 torch.overrides._get_overridable_functions。這會傳回一個字典,其鍵是 PyTorch Python API 中的命名空間,其值是可以覆寫的該命名空間中的函式清單。例如,讓我們列印 torch.nn.functional 中可以覆寫的前 5 個函式的名稱
>>> from torch.overrides import get_overridable_functions
>>> func_dict = get_overridable_functions()
>>> nn_funcs = func_dict[torch.nn.functional]
>>> print([f.__name__ for f in nn_funcs[:5])
['adaptive_avg_pool1d', 'adaptive_avg_pool2d', 'adaptive_avg_pool3d',
 'adaptive_max_pool1d', 'adaptive_max_pool1d_with_indices']
此函式清單可以迭代所有可覆寫的函式,然而在實務上,這還不足以針對所有這些函式編寫測試,而無需費力地手動複製每個測試的每個函式簽章。為了簡化此過程,torch.overrides._get_testing_overrides 函式會傳回一個字典,將 PyTorch API 中的可覆寫函式映射到虛擬 lambda 函式,這些函式具有與原始函式相同的簽章,但無條件地傳回 -1。這些函式最適合與 inspect 一起使用,以分析原始 PyTorch 函式的函式簽章
>>> import inspect
>>> from torch.overrides import get_testing_overrides
>>> override_dict = get_testing_overrides()
>>> dummy_add = override_dict[torch.add]
>>> inspect.signature(dummy_add)
<Signature (input, other, out=None)>
最後,torch.overrides.get_ignored_functions 會傳回一個函式元組,這些函式明確地無法被 __torch_function__ 覆寫。此清單可用於確認 get_overridable_functions 傳回的字典中不存在的函式無法被覆寫。
擴充 torch 原生 API¶
雖然 __torch_function__ 允許有效地擴充 PyTorch 純 Python 元件的行為,但它不允許擴充以 C++ 實作的 PyTorch 部分。為此,Tensor 子類別也可以定義 __torch_dispatch__,它將能夠在 C++ 層級覆寫行為。
為了有效地使用此功能,瞭解 PyTorch 原生部分的實作方式非常重要。其中最重要的元件是我們所謂的「派送器」(最佳描述可以在這篇 部落格文章 中找到,即使它有點過時了)。顧名思義,它負責為特定函式呼叫呼叫正確的後端函式。例如,當呼叫 torch.add(a, b) 時,派送器將檢查兩個參數,找出應該為此特定呼叫使用哪個「功能」(自動梯度、自動轉換、函式化等)以及哪個「後端」(CPU、CUDA、MPS 等),最後呼叫所有正確的核心。核心非常常見的一件事是「重新派送」。例如,當使用自動轉換在 GPU 上執行神經網路時,第一個呼叫將是自動轉換核心,它將處理任何潛在的自動轉換邏輯並向下重新派送。下一個功能將是自動梯度,它將正確建立自動梯度圖,然後向下重新派送。最後,我們到達 CUDA 的後端核心,它將啟動正確的 CUDA 核心並傳回最終結果。在輸出的過程中,自動梯度會將圖附加到輸出,最後,自動轉換將有機會進行任何需要的更新。
派送器的一種配置是呼叫所有這些功能和後端鍵的順序。最新的清單及其順序可以在 DispatchKey.h 內部的 DispatchKey 列舉中找到。為了擴充 torch,此討論的重要順序子集是
vmap -> Autocast -> Autograd -> ZeroTensor -> Neg/Conj -> Functionalize -> Python -> 後端
為了本次討論,最重要的鍵是 Python,因為每個定義了 __torch_dispatch__ 方法的 Tensor 子類別都會呼叫此功能。使用者定義的方法就是從這裡被呼叫的,並且可以在這裡任意覆寫行為。從那裡開始,再次呼叫提供的 func 將執行「重新派送」。
此實作的一些重要含義是
- 此程式碼在「所有功能之下」執行。因此,它僅像常規後端一樣,負責產生每個 Tensor 的輸出值(並且可以而且應該忽略所有進階功能,例如自動梯度、自動轉換等)。 
- 如果任何高階功能在沒有重新派送的情況下實作了給定函式,它將永遠不會到達 - Python鍵,因此- __torch_dispatch__回呼將永遠不會被觸發。這尤其發生在 CompositeImplicitAutograd 函式中,這些函式在 Autograd 層級進行評估而無需重新派送。這是因為 CompositeImplicitAutograd 函式透過隱式呼叫其他原生操作來指定其自動梯度公式,因此在 Autograd 層級,該函式會被分解成其原生操作,並對這些操作進行評估。
- 當回呼 Python 和包裝結果時,會使用與常規 PyTorch Python/C++ 繫結相同的轉換。特別是,某些物件無法在 Python 中表示,需要特殊處理(例如,未定義的 Tensor 會變成 None)。 
- 我們的原生函式會被延遲地填充為可呼叫的 Python 物件,例如 - torch.ops.{namespace}.{func_name}.{overload_name},以便從 Python 輕鬆地與它們互動。提供給- __torch_dispatch__的- func物件始終是此命名空間中的一個條目。此命名空間可用於直接呼叫原生操作並繞過通常的 Python API 和繫結程式碼。
與 __torch_function__ 能夠插入所有 torch 的 Python API 和 Tensor 方法的方式類似,__torch_dispatch__ 能夠攔截對 aten 原生 API 的所有呼叫。請注意,Tensor 上的所有方法在進入派送器之前都會被轉換為函式呼叫,因此在這裡將顯示為函式呼叫:torch.add(a, 2) 和 a + 2 將導致完全相同的 aten 呼叫。這些函式大多數定義在 native_functions.yaml 中,該檔案指定了這些函式的屬性以及它們的後端實作。然後,它們的實作以及指定的功能會透過程式碼產生自動註冊。一些更奇特的函式或功能也會註冊在 C++ 程式碼庫的其他地方或使用者定義的 C++ 擴充中。
也可以使用 torch.library 新增新的原生函式。此 Python 功能允許定義和/或將新的實作新增到原生函式中。這可用於新增遺漏的核心、替換現有的核心或定義全新的原生函式。
您可以在 子類別動物園 儲存庫中找到許多基於 __torch_dispatch__ 的子類別範例。
使用模式擴充所有 torch API¶
遺憾的是,有些函式不接受 Tensor 輸入。這表示上述子類別方法無法用於覆寫所有 PyTorch 函式的行為。此外,如果使用案例需要攔截每個函式呼叫,則將每個 Tensor 都更改為子類別可能會過於侵入。
為了滿足這種使用案例,我們引入了「模式」的概念。這些模式存在於 __torch_function__ 和 __torch_dispatch__ 覆寫中,分別透過繼承 torch.overrides.TorchFunctionMode 和 torch.utils._python_dispatch.TorchDispatchMode 建立,並用作上下文管理器。
為了簡化它與子類別和其他模式交互方式的描述,每當進入模式的上下文管理器時,每個函式的行為都好像在參數清單的開頭有一個額外的 Tensor 參數,其中模式作為子類別。這尤其意味著所有模式處理常式都將在任何子類別處理常式之前被呼叫,並且對應於內部上下文管理器的模式將始終先執行。
同樣重要的是要注意,在給定的模式處理常式中,此特定模式是被禁用的,並且可以透過執行 with self: 來手動重新啟用。
以下是一個顯示每種類型的日誌記錄模式的範例
import torch
from torch.overrides import TorchFunctionMode, resolve_name
from torch.utils._python_dispatch import TorchDispatchMode
class FunctionLog(TorchFunctionMode):
    def __torch_function__(self, func, types, args, kwargs=None):
        print(f"Function Log: {resolve_name(func)}(*{args}, **{kwargs})")
        return func(*args, **(kwargs or {}))
class DispatchLog(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args, kwargs=None):
        print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
        return func(*args, **(kwargs or {}))
def f():
    a = torch.rand(10, requires_grad=True)
    b = a * 2
    b.sum().backward()
print("TorchFunctionMode logging:")
with FunctionLog():
    f()
print("TorchDispatchMode logging:")
with DispatchLog():
    f()
它會列印以下內容,並附帶額外註釋
TorchFunctionMode logging:
Function Log: torch.rand(*(10,), **{'requires_grad': True})
Function Log: torch.Tensor.mul(*(tensor([0.7164, 0.9897, 0.1745, 0.9336, 0.4287, 0.7989, 0.2169, 0.7474, 0.5624,
        0.5970], requires_grad=True), 2), **None)
Function Log: torch.Tensor.sum(*(tensor([1.4328, 1.9794, 0.3490, 1.8671, 0.8573, 1.5977, 0.4338, 1.4948, 1.1249,
        1.1939], grad_fn=<MulBackward0>),), **None)
# Note that at the python level, we only see the call to backward but not what happens in the autograd engine.
Function Log: torch.Tensor.backward(*(tensor(12.3307, grad_fn=<SumBackward0>),), **{'gradient': None, 'retain_graph': None, 'create_graph': False, 'inputs': None})
TorchDispatchMode logging:
# Here the requires_grad flag from autograd is removed while default arguments were populated.
Dispatch Log: aten.rand.default(*([10],), **{'device': device(type='cpu'), 'pin_memory': False})
Dispatch Log: aten.mul.Tensor(*(tensor([0.2151, 0.6018, 0.8415, 0.9060, 0.2974, 0.7708, 0.6668, 0.0352, 0.7948,
        0.6023], requires_grad=True), 2), **{})
Dispatch Log: aten.sum.default(*(tensor([0.4303, 1.2036, 1.6831, 1.8120, 0.5949, 1.5416, 1.3335, 0.0705, 1.5897,
        1.2046], grad_fn=<MulBackward0>),), **{})
# Here we don't see the call to backward itself, but its constituents. Starting here with the factory function that creates the initial gradient.
Dispatch Log: aten.ones_like.default(*(tensor(11.4637, grad_fn=<SumBackward0>),), **{'pin_memory': False, 'memory_format': torch.preserve_format})
# This is the backward of the sum
Dispatch Log: aten.expand.default(*(tensor(1.), [10]), **{})
Dispatch Log: aten.mul.Tensor(*(tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), 2), **{})
Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{})
Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{})