擴充套件 PyTorch¶
在本說明中,我們將介紹擴充套件 torch.nn、torch.autograd、torch 以及編寫自定義 C++ 擴充套件的方法。
新增新運算元¶
PyTorch 提供了大量可在張量 (Tensor) 上工作的運算元(例如 torch.add()、torch.sum() 等)。但是,您可能希望將新的自定義操作引入 PyTorch,並使其行為與 PyTorch 的內建運算元一樣。為此,您必須透過 Python torch.library 或 C++ TORCH_LIBRARY API 將自定義操作註冊到 PyTorch。
有關更多詳細資訊,請參閱 PyTorch 自定義運算元登陸頁。
擴充套件 torch.autograd¶
將操作新增到 autograd 需要為每個操作實現一個新的 Function 子類。回想一下,Function 是 autograd 用於編碼操作歷史和計算梯度的物件。
本文件的第一部分側重於反向模式自動微分 (AD),因為它是最廣泛使用的特性。末尾的部分討論了前向模式自動微分 (AD) 的擴充套件。
何時使用¶
通常,如果您想在模型中執行不可微分或依賴非 PyTorch 庫(例如 NumPy)的計算,但仍希望您的操作能夠與其他運算元鏈式連線並與 autograd 引擎一起工作,那麼請實現自定義函式。
在某些情況下,自定義函式也可用於提高效能和記憶體使用:如果您使用 C++ 擴充套件實現了前向和反向傳播,您可以將它們封裝在 Function 中以與 autograd 引擎互動。如果您想減少為反向傳播儲存的緩衝區數量,可以使用自定義函式將多個運算元組合在一起。
何時不使用¶
如果您已經可以使用 PyTorch 的內建運算元來編寫您的函式,那麼它的反向圖(很可能)已經可以被 autograd 記錄。在這種情況下,您無需自己實現 backward 函式。考慮使用普通的 Python 函式即可。
如果您需要維護狀態,即可訓練引數,您應該(也)使用自定義模組。有關擴充套件 torch.nn 的更多資訊,請參閱下面的部分。
如果您想在反向傳播期間修改梯度或執行副作用,請考慮註冊一個張量或Module hook。
如何使用¶
請按照以下步驟操作:1. 子類化 Function 並實現 forward()、(可選)setup_context() 和 backward() 方法。2. 呼叫 ctx 引數上的適當方法。3. 宣告您的函式是否支援二次反向傳播 (double backward)。4. 使用 gradcheck 驗證您的梯度是否正確。
步驟 1:子類化 Function 後,您需要定義 3 個方法
forward()是執行操作的程式碼。它可以接受任意數量的引數,如果您指定預設值,其中一些引數是可選的。這裡接受所有型別的 Python 物件。Tensor引數如果跟蹤歷史(即,requires_grad=True),則在呼叫前將被轉換為不跟蹤歷史的 Tensor,並且它們的用法將被註冊到圖中。請注意,此邏輯不會遍歷列表/字典/任何其他資料結構,只會考慮直接作為呼叫引數的 Tensor。您可以返回單個Tensor輸出,或者在有多個輸出時返回一個tuple的 tensors。此外,請參閱Function的文件以查詢只能從forward()中呼叫的有用方法的描述。setup_context()(可選)。您可以編寫一個接受ctx物件的“組合式”forward(),或者(自 PyTorch 2.0 起)編寫一個不接受ctx的單獨forward()和一個setup_context()方法,在其中進行ctx的修改。forward()應該包含計算邏輯,而setup_context()只應負責ctx的修改(不包含任何計算)。通常,分開的forward()和setup_context()更接近 PyTorch 本地操作的工作方式,因此與各種 PyTorch 子系統更具組合性。有關更多詳細資訊,請參閱 組合式或分開式 forward() 和 setup_context()。backward()(或vjp())定義了梯度公式。它將獲得與輸出數量相同的Tensor引數,每個引數代表相對於相應輸出的梯度。切記不要就地修改這些引數。它應該返回與輸入數量相同的 tensors,每個 Tensor 包含相對於其相應輸入的梯度。如果您的輸入不需要梯度(needs_input_grad是一個布林值元組,指示每個輸入是否需要計算梯度),或者是非Tensor物件,您可以返回python:None。此外,如果forward()有可選引數,您可以返回比輸入更多的梯度,只要它們全部為None。
步驟 2:您有責任正確使用 ctx 中的函式,以確保新的 Function 與 autograd 引擎正常工作。
必須使用
save_for_backward()儲存要在反向傳播中使用的任何張量。非張量應直接儲存在 ctx 上。如果儲存了既不是輸入也不是輸出的張量用於反向傳播,您的Function可能不支援二次反向傳播(參見步驟 3)。必須使用
mark_dirty()標記 forward 函式就地修改的任何輸入。必須使用
mark_non_differentiable()告知引擎輸出是否不可微分。預設情況下,所有可微分型別的輸出張量都將設定為需要梯度。不可微分型別(即整型)的張量永遠不會被標記為需要梯度。set_materialize_grads()可用於告知 autograd 引擎在輸出不依賴於輸入的情況下最佳化梯度計算,方法是不具體化 (materializing) 傳遞給 backward 函式的梯度張量。也就是說,如果設定為 False,Python 中的 None 物件或 C++ 中的“未定義張量”(即 x.defined() 為 False 的張量 x)將不會在呼叫 backward 之前被轉換為填充零的張量,因此您的程式碼需要像處理填充零的張量一樣處理這些物件。此設定的預設值為 True。
步驟 3:如果您的 Function 不支援二次反向傳播,您應該透過使用 once_differentiable() 裝飾器裝飾 backward 來明確宣告。使用此裝飾器後,嘗試透過您的函式進行二次反向傳播將產生錯誤。有關二次反向傳播的更多資訊,請參閱我們的二次反向傳播教程。
步驟 4:建議您使用 torch.autograd.gradcheck() 來檢查您的 backward 函式是否透過使用 backward 函式計算雅可比矩陣,並將其值與使用有限差分法數值計算的雅可比矩陣進行逐元素比較,從而正確計算了 forward 的梯度。
示例¶
您可以在下方找到 Linear 函式的程式碼,並附有額外註釋。
# Inherit from Function
class LinearFunction(Function):
# Note that forward, setup_context, and backward are @staticmethods
@staticmethod
def forward(input, weight, bias):
output = input.mm(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
@staticmethod
# inputs is a Tuple of all of the inputs passed to forward.
# output is the output of the forward().
def setup_context(ctx, inputs, output):
input, weight, bias = inputs
ctx.save_for_backward(input, weight, bias)
# This function has only a single output, so it gets only one gradient
@staticmethod
def backward(ctx, grad_output):
# This is a pattern that is very convenient - at the top of backward
# unpack saved_tensors and initialize all gradients w.r.t. inputs to
# None. Thanks to the fact that additional trailing Nones are
# ignored, the return statement is simple even when the function has
# optional inputs.
input, weight, bias = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
# These needs_input_grad checks are optional and there only to
# improve efficiency. If you want to make your code simpler, you can
# skip them. Returning gradients for inputs that don't require it is
# not an error.
if ctx.needs_input_grad[0]:
grad_input = grad_output.mm(weight)
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().mm(input)
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
return grad_input, grad_weight, grad_bias
現在,為了更容易使用這些自定義運算元,我們建議將它們別名化或封裝在一個函式中。封裝在函式中可以讓我們支援預設引數和關鍵字引數。
# Option 1: alias
linear = LinearFunction.apply
# Option 2: wrap in a function, to support default args and keyword args.
def linear(input, weight, bias=None):
return LinearFunction.apply(input, weight, bias)
這裡,我們提供一個由非 Tensor 引數引數化的函式的額外示例。
class MulConstant(Function):
@staticmethod
def forward(tensor, constant):
return tensor * constant
@staticmethod
def setup_context(ctx, inputs, output):
# ctx is a context object that can be used to stash information
# for backward computation
tensor, constant = inputs
ctx.constant = constant
@staticmethod
def backward(ctx, grad_output):
# We return as many input gradients as there were arguments.
# Gradients of non-Tensor arguments to forward must be None.
return grad_output * ctx.constant, None
在這裡,我們透過呼叫 set_materialize_grads(False) 來最佳化上述示例。
class MulConstant(Function):
@staticmethod
def forward(tensor, constant):
return tensor * constant
@staticmethod
def setup_context(ctx, inputs, output):
tensor, constant = inputs
ctx.set_materialize_grads(False)
ctx.constant = constant
@staticmethod
def backward(ctx, grad_output):
# Here we must handle None grad_output tensor. In this case we
# can skip unnecessary computations and just return None.
if grad_output is None:
return None, None
# We return as many input gradients as there were arguments.
# Gradients of non-Tensor arguments to forward must be None.
return grad_output * ctx.constant, None
如果在 forward() 中計算的任何“中間”張量需要被儲存,則它們必須作為輸出返回,或者結合使用 forward 和 setup_context()(參見 組合式或分開式 forward() 和 setup_context())。請注意,這意味著如果您希望梯度流經這些中間值,您需要為它們定義梯度公式(另請參見 二次反向傳播教程)。
class MyCube(torch.autograd.Function):
@staticmethod
def forward(x):
# We wish to save dx for backward. In order to do so, it must
# be returned as an output.
dx = 3 * x ** 2
result = x ** 3
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
# In order for the autograd.Function to work with higher-order
# gradients, we must add the gradient contribution of `dx`,
# which is grad_dx * 6 * x.
result = grad_output * dx + grad_dx * 6 * x
return result
# Wrap MyCube in a function so that it is clearer what the output is
def my_cube(x):
result, dx = MyCube.apply(x)
return result
注意
backward 的輸入,即 grad_output,也可以是跟蹤歷史的張量。因此,如果 backward 是用可微分操作實現的(例如,呼叫另一個自定義 Function),則高階導數將起作用。在這種情況下,使用 save_for_backward 儲存的張量也可以在 backward 中使用並有梯度流回,但儲存在 ctx 中的張量將不會有梯度流回。如果您需要儲存在 ctx 中的 Tensor 有梯度流回,您應該將其作為自定義 Function 的輸出並使用 save_for_backward 儲存。
您可能想檢查您實現的 backward 方法是否實際計算了您的函式的導數。這可以透過與使用小有限差分法的數值逼近進行比較來實現。
from torch.autograd import gradcheck
# gradcheck takes a tuple of tensors as input, check if your gradient
# evaluated with these tensors are close enough to numerical
# approximations and returns True if they all verify this condition.
input = (torch.randn(20,20,dtype=torch.double,requires_grad=True), torch.randn(30,20,dtype=torch.double,requires_grad=True))
test = gradcheck(linear, input, eps=1e-6, atol=1e-4)
print(test)
有關有限差分梯度比較的更多詳細資訊,請參閱 數值梯度檢查。如果您的函式用於高階導數(對 backward 傳播求導),您可以使用同一包中的 gradgradcheck 函式來檢查高階導數。
組合式或分開式 forward() 和 setup_context()¶
定義 Function 有兩種主要方式。即
我們推薦第二種選項(分開的 forward() 和 setup_context()),因為它更接近 PyTorch 本地操作的實現方式,並且與 torch.func 轉換更具組合性。然而,我們計劃將來繼續支援這兩種方法;將 forward() 與 setup_context() 結合使用:可以提供更大的靈活性,因為您可以在不將中間值作為輸出返回的情況下儲存它們。
有關如何使用分開的 forward() 和 setup_context() 定義 Function 的資訊,請參閱上一節。
以下是一個示例,說明如何使用組合式的 forward() 和 setup_context() 定義 Function。
class LinearFunction(Function):
@staticmethod
# ctx is the first argument to forward
def forward(ctx, input, weight, bias=None):
# The forward pass can use ctx.
ctx.save_for_backward(input, weight, bias)
output = input.mm(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
if ctx.needs_input_grad[0]:
grad_input = grad_output.mm(weight)
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().mm(input)
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
return grad_input, grad_weight, grad_bias
前向模式自動微分 (AD)¶
重寫前向模式自動微分 (AD) 公式具有非常相似的 API,但有一些細微差別。您可以實現 jvp() 函式。
它將獲得與輸入數量相同的 Tensor 引數,每個引數代表相對於相應輸入的梯度。它應該返回與輸出數量相同的 tensors,每個 Tensor 包含相對於其相應輸出的梯度。jvp() 將在 forward() 方法之後、apply() 返回之前被呼叫。
jvp() 與 backward() 函式有一些細微差別
您可以使用 ctx 將
forward()中的任何資料傳遞給jvp()函式。如果backward()不需要該狀態,您可以在jvp()函式末尾透過執行del ctx.foo明確釋放它。jvp()的實現必須是 backward 可微分的,或者明確檢查給定的前向模式梯度都沒有設定requires_grad。jvp()函式必須與forward()的 view/就地行為相匹配。例如,如果第i個輸入被就地修改,則第i個梯度也必須就地更新。類似地,如果第j個輸出是第k個輸入的檢視。那麼返回的第j個輸出梯度必須是給定第k個輸入梯度的檢視。由於使用者無法指定需要計算哪個梯度,
jvp()函式應該始終計算所有輸出的梯度。前向模式梯度確實遵守由
set_materialize_grads()設定的標誌,並且當停用此功能時,您可以獲得 None 輸入梯度。
torch.func 轉換和/或 torch.vmap()¶
有關詳細資訊,請參閱 使用 autograd.Function 擴充套件 torch.func。
擴充套件 torch.nn¶
nn 匯出了兩種介面 - 模組 (modules) 及其函式式版本 (functional versions)。你可以透過這兩種方式進行擴充套件,但我們推薦使用模組來實現需要持有引數或緩衝 (buffers) 的各種層,而對於啟用函式、池化等無引數操作,則推薦使用函式形式。
在上面一節中已經全面介紹瞭如何新增一個操作的函式式版本。
新增 Module¶
由於 nn 大量利用了 autograd,新增一個新的 Module 需要實現一個 Function,該函式執行操作並能計算梯度。現在我們假設要實現一個 Linear 模組,並且已經按照上面的程式碼清單實現了該函式。新增這個模組只需要非常少的程式碼。現在,需要實現兩個函式:
__init__(可選) - 接收諸如卷積核大小、特徵數量等引數,並初始化引數和緩衝。
以下是 Linear 模組的實現示例:
class Linear(nn.Module):
def __init__(self, input_features, output_features, bias=True):
super().__init__()
self.input_features = input_features
self.output_features = output_features
# nn.Parameter is a special kind of Tensor, that will get
# automatically registered as Module's parameter once it's assigned
# as an attribute. Parameters and buffers need to be registered, or
# they won't appear in .parameters() (doesn't apply to buffers), and
# won't be converted when e.g. .cuda() is called. You can use
# .register_buffer() to register buffers.
# nn.Parameters require gradients by default.
self.weight = nn.Parameter(torch.empty(output_features, input_features))
if bias:
self.bias = nn.Parameter(torch.empty(output_features))
else:
# You should always register all possible parameters, but the
# optional ones can be None if you want.
self.register_parameter('bias', None)
# Not a very smart way to initialize weights
nn.init.uniform_(self.weight, -0.1, 0.1)
if self.bias is not None:
nn.init.uniform_(self.bias, -0.1, 0.1)
def forward(self, input):
# See the autograd section for explanation of what happens here.
return LinearFunction.apply(input, self.weight, self.bias)
def extra_repr(self):
# (Optional)Set the extra information about this module. You can test
# it by printing an object of this class.
return 'input_features={}, output_features={}, bias={}'.format(
self.input_features, self.output_features, self.bias is not None
)
擴充套件 torch Python API¶
你可以透過定義一個自定義類,使其方法與 Tensor 相匹配,從而建立模仿 Tensor 的自定義型別。但是,如果你想將這些型別傳遞給頂層 torch 名稱空間中接受 Tensor 運算元的函式,例如 torch.add(),該怎麼辦呢?
如果你的自定義 Python 型別定義了一個名為 __torch_function__ 的方法,當將你的自定義類例項傳遞給 torch 名稱空間中的函式時,PyTorch 將呼叫你的 __torch_function__ 實現。這使得你可以為 torch 名稱空間中的任何函式定義自定義實現,你的 __torch_function__ 實現可以呼叫這些函式,從而允許你的使用者在使用 Tensor 時,在他們已有的 PyTorch 工作流中使用你的自定義型別。這既適用於與 Tensor 無關的“鴨子型別” (duck types),也適用於使用者定義的 Tensor 子類。
使用類似 Tensor 的型別擴充套件 torch¶
為了具體說明這一點,我們從一個簡單的示例開始,它演示了 API 分派機制。我們將建立一個自定義型別,表示一個二維標量張量,該張量由階數 N 和對角線元素的值 value 引數化。
class ScalarTensor(object):
def __init__(self, N, value):
self._N = N
self._value = value
def __repr__(self):
return "ScalarTensor(N={}, value={})".format(self._N, self._value)
def tensor(self):
return self._value * torch.eye(self._N)
設計的第一版並不是非常有用。ScalarTensor 的主要功能是提供比基本張量類更緊湊的標量張量字串表示形式。
>>> d = ScalarTensor(5, 2)
>>> d
ScalarTensor(N=5, value=2)
>>> d.tensor()
tensor([[2., 0., 0., 0., 0.],
[0., 2., 0., 0., 0.],
[0., 0., 2., 0., 0.],
[0., 0., 0., 2., 0.],
[0., 0., 0., 0., 2.]])
如果我們嘗試將此物件與 torch API 一起使用,將會遇到問題:
>>> import torch
>>> torch.mean(d)
TypeError: mean(): argument 'input' (position 1) must be Tensor, not ScalarTensor
向 ScalarTensor 新增 __torch_function__ 實現使得上述操作得以成功。讓我們重新編寫實現,這次新增 __torch_function__ 實現:
HANDLED_FUNCTIONS = {}
class ScalarTensor(object):
def __init__(self, N, value):
self._N = N
self._value = value
def __repr__(self):
return "ScalarTensor(N={}, value={})".format(self._N, self._value)
def tensor(self):
return self._value * torch.eye(self._N)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func not in HANDLED_FUNCTIONS or not all(
issubclass(t, (torch.Tensor, ScalarTensor))
for t in types
):
return NotImplemented
return HANDLED_FUNCTIONS[func](*args, **kwargs)
__torch_function__ 方法接受四個引數:func,一個指向正在被覆蓋的 torch API 函式的引用;types,實現 __torch_function__ 的 Tensor-like 型別的列表;args,傳遞給函式的引數元組;kwargs,傳遞給函式的關鍵字引數字典。它使用一個名為 HANDLED_FUNCTIONS 的全域性分派表來儲存自定義實現。這個字典的鍵是 torch 名稱空間中的函式,值是針對 ScalarTensor 的實現。
注意
使用全域性分派表不是 __torch_function__ API 的強制要求,它只是一個有用的設計模式,用於組織你的覆蓋實現。
這個類定義不足以讓 torch.mean 在我們傳遞 ScalarTensor 時執行正確的操作——我們還需要為 torch.mean 定義一個針對 ScalarTensor 運算元的實現,並將該實現新增到 HANDLED_FUNCTIONS 分派表字典中。一種實現方法是定義一個裝飾器:
import functools
def implements(torch_function):
"""Register a torch function override for ScalarTensor"""
def decorator(func):
functools.update_wrapper(func, torch_function)
HANDLED_FUNCTIONS[torch_function] = func
return func
return decorator
然後可以將其應用於我們覆蓋實現的函式:
@implements(torch.mean)
def mean(input):
return float(input._value) / input._N
透過此更改,我們現在可以使用 torch.mean 處理 ScalarTensor:
>>> d = ScalarTensor(5, 2)
>>> torch.mean(d)
0.4
當然,torch.mean 是最簡單的覆蓋函式型別之一,因為它只接受一個運算元。我們可以使用相同的機制來覆蓋接受多個運算元的函式,其中任何一個運算元都可能是定義了 __torch_function__ 的張量或類似張量的型別,例如 torch.add():
def ensure_tensor(data):
if isinstance(data, ScalarTensor):
return data.tensor()
return torch.as_tensor(data)
@implements(torch.add)
def add(input, other):
try:
if input._N == other._N:
return ScalarTensor(input._N, input._value + other._value)
else:
raise ValueError("Shape mismatch!")
except AttributeError:
return torch.add(ensure_tensor(input), ensure_tensor(other))
此版本為兩個運算元都是 ScalarTensor 例項的情況提供了一個快速路徑,還提供了一個較慢的路徑,在任一運算元不是 ScalarTensor 時會退化為將資料轉換為張量。這使得覆蓋函式在任一運算元是 ScalarTensor 或常規 Tensor 時都能正確工作。
>>> s = ScalarTensor(2, 2)
>>> torch.add(s, s)
ScalarTensor(N=2, value=4)
>>> t = torch.tensor([[1, 1,], [1, 1]])
>>> torch.add(s, t)
tensor([[3., 1.],
[1., 3.]])
請注意,我們的 add 實現不像 torch.add() 那樣接受 alpha 或 out 作為關鍵字引數:
>>> torch.add(s, s, alpha=2)
TypeError: add() got an unexpected keyword argument 'alpha'
為了速度和靈活性,__torch_function__ 分派機制不會檢查覆蓋函式的簽名是否與 torch API 中被覆蓋函式的簽名匹配。對於某些應用,忽略可選引數是可以接受的,但為了確保與 Tensor 的完全相容性,使用者實現的 torch API 函式應該注意完全模仿被覆蓋函式的 API。
在 torch API 中沒有顯式覆蓋的函式將從 __torch_function__ 返回 NotImplemented。如果所有定義了 __torch_function__ 的運算元都返回 NotImplemented,PyTorch 將引發 TypeError。這意味著在大多數情況下,當傳遞此類型別的例項時,沒有顯式覆蓋的操作將引發 TypeError。
>>> torch.mul(s, 3)
TypeError: no implementation found for 'torch.mul' on types that
implement __torch_function__: [ScalarTensor]
實際上,這意味著如果你想按照這些思路使用 __torch_function__ 實現來編寫覆蓋,你需要顯式地實現完整的 torch API,或者至少是你用例關心的 API 子集。這可能是一個艱鉅的任務,因為完整的 torch API 相當廣泛。
另一個選擇是對於未處理的操作不返回 NotImplemented,而是在沒有可用覆蓋時將 Tensor 傳遞給原始的 torch 函式。例如,如果我們將 ScalarTensor 的 __torch_function__ 實現更改為以下內容:
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func not in HANDLED_FUNCTIONS or not all(
issubclass(t, (torch.Tensor, ScalarTensor))
for t in types
):
args = [a.tensor() if hasattr(a, 'tensor') else a for a in args]
return func(*args, **kwargs)
return HANDLED_FUNCTIONS[func](*args, **kwargs)
那麼 torch.mul() 將正常工作,儘管返回型別始終是 Tensor 而不是 ScalarTensor,即使兩個運算元都是 ScalarTensor 例項:
>>> s = ScalarTensor(2, 2)
>>> torch.mul(s, s)
tensor([[4., 0.],
[0., 4.]])
另請參見下面的 MetadataTensor 示例,它展示了這種模式的另一種變體,但始終返回 MetadataTensor,以便在 torch API 的操作中傳播元資料。
__torch_function__ 協議旨在覆蓋完整的 API,部分覆蓋可能導致不良結果,特別是某些函式會引發 TypeError。對於子類尤其如此,torch.add、torch.Tensor.__add__ 和 torch.Tensor.add 都必須被覆蓋,即使它們返回完全相同的結果。未能做到這一點也可能導致無限遞迴。如果需要從 torch.Tensor 子類實現函式,則必須在其實現內部使用 super().__torch_function__。
子類化 torch.Tensor¶
自版本 1.7.0 起,應用於 torch.Tensor 子類的 torch.Tensor 方法和公共 torch.* 名稱空間中的函式將返回子類例項而不是 torch.Tensor 例項:
>>> class SubTensor(torch.Tensor):
... pass
>>> type(torch.add(SubTensor([0]), SubTensor([1]))).__name__
'SubTensor'
>>> type(torch.add(SubTensor([0]), torch.tensor([1]))).__name__
'SubTensor'
如果存在多個子類,預設會選擇層級最低的那個。如果無法唯一確定這種情況,則會引發 TypeError。
>>> type(torch.add(SubTensor2([0]), SubTensor([1]))).__name__
'SubTensor2'
>>> type(torch.add(SubTensor2([0]), torch.tensor([1]))).__name__
'SubTensor2'
>>> torch.add(SubTensor([0]), OtherSubTensor([1]))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: no implementation found for 'torch.add' on types that implement __torch_function__: [SubTensor, OtherSubTensor]
如果希望對所有張量方法進行全域性覆蓋,可以使用 __torch_function__。以下是記錄所有函式/方法呼叫的示例:
class LoggingTensor(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
# NOTE: Logging calls Tensor.__repr__, so we can't log __repr__ without infinite recursion
if func is not torch.Tensor.__repr__:
logging.info(f"func: {func.__name__}, args: {args!r}, kwargs: {kwargs!r}")
if kwargs is None:
kwargs = {}
return super().__torch_function__(func, types, args, kwargs)
但是,如果希望覆蓋 Tensor 子類上的方法,可以透過直接覆蓋方法(為子類定義該方法)或使用 __torch_function__ 並與 func 匹配來實現。
對於子類中的 __torch_function__,應注意始終呼叫 super().__torch_function__(func, ...),而不是直接呼叫 func,這與 1.7.0 版本之前的做法不同。未能做到這一點可能導致 func 遞迴呼叫回 __torch_function__,從而導致無限遞迴。
使用 Tensor 包裝器型別擴充套件 torch¶
另一個有用的場景是包裝 Tensor 的型別,無論作為屬性還是透過子類化。下面我們實現這種型別的一種特殊情況:一個 MetadataTensor,它將一個元資料字典附加到 Tensor 上,並在 torch 操作中傳播。由於這是對完整 torch API 的通用包裝,我們無需單獨實現每個覆蓋,因此可以使 __torch_function__ 實現對允許的操作更具包容性:
class MetadataTensor(object):
def __init__(self, data, metadata=None, **kwargs):
self._t = torch.as_tensor(data, **kwargs)
self._metadata = metadata
def __repr__(self):
return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
metadatas = tuple(a._metadata for a in args if hasattr(a, '_metadata'))
args = [getattr(a, '_t', a) for a in args]
assert len(metadatas) > 0
ret = func(*args, **kwargs)
return MetadataTensor(ret, metadata=metadatas[0])
這個簡單的實現不一定適用於 torch API 中的每個函式,但足以涵蓋大多數常用操作。
>>> metadata = {'owner': 'Ministry of Silly Walks'}
>>> m = MetadataTensor([[1, 2], [3, 4]], metadata=metadata)
>>> t = torch.tensor([[1, 2], [1, 2]])
>>> torch.add(t, m)
Metadata:
{'owner': 'Ministry of Silly Walks'}
data:
tensor([[2, 4],
[4, 6]])
>>> torch.mul(t, m)
Metadata:
{'owner': 'Ministry of Silly Walks'}
data:
tensor([[1, 4],
[3, 8]])
操作多個定義了 __torch_function__ 的型別¶
可以使用 torch API 處理多個各自擁有 __torch_function__ 實現的不同型別,但需要特別注意。在這種情況下,規則如下:
分派操作會收集每個運算元上所有不同的
__torch_function__實現,並按順序呼叫它們:子類優先於超類,否則按操作表示式中的從左到右順序。如果返回的值不是
NotImplemented,則該值作為結果返回。實現可以透過返回NotImplemented來表明它們不實現某個操作。如果所有
__torch_function__實現都返回NotImplemented,PyTorch 將引發TypeError。
測試 PyTorch API 覆蓋範圍¶
實現 __torch_function__ 的一個棘手方面是,如果某些操作有覆蓋而另一些沒有,使用者充其量會遇到不一致的體驗,最壞的情況下會在使用沒有覆蓋的函式時遇到執行時錯誤。為了簡化這個過程,PyTorch 提供了一個面向開發者的 API,用於確保對 __torch_function__ 覆蓋的全面支援。此 API 是私有的,未來可能會在沒有警告的情況下進行更改。
首先,要獲取所有可覆蓋函式的列表,請使用 torch.overrides._get_overridable_functions。這會返回一個字典,其鍵是 PyTorch Python API 中的名稱空間,值是該名稱空間中可以被覆蓋的函式列表。例如,讓我們列印 torch.nn.functional 中前 5 個可覆蓋函式的名稱:
>>> from torch.overrides import get_overridable_functions
>>> func_dict = get_overridable_functions()
>>> nn_funcs = func_dict[torch.nn.functional]
>>> print([f.__name__ for f in nn_funcs[:5])
['adaptive_avg_pool1d', 'adaptive_avg_pool2d', 'adaptive_avg_pool3d',
'adaptive_max_pool1d', 'adaptive_max_pool1d_with_indices']
這個函式列表使得可以迭代所有可覆蓋的函式,然而實際上,如果不能費力地手動複製每個函式的簽名進行測試,這不足以編寫針對所有這些函式的測試。為了簡化這個過程,torch.overrides._get_testing_overrides 函式返回一個字典,將 PyTorch API 中可覆蓋的函式對映到具有相同簽名的模擬 lambda 函式,這些函式無條件返回 -1。這些函式最適合與 inspect 一起使用,分析原始 PyTorch 函式的函式簽名:
>>> import inspect
>>> from torch.overrides import get_testing_overrides
>>> override_dict = get_testing_overrides()
>>> dummy_add = override_dict[torch.add]
>>> inspect.signature(dummy_add)
<Signature (input, other, out=None)>
最後,torch.overrides.get_ignored_functions 返回一個函式元組,這些函式明確不能被 __torch_function__ 覆蓋。這個列表對於確認在 get_overridable_functions 返回的字典中不存在的函式無法被覆蓋是很有用的。
擴充套件 torch 原生 API¶
雖然 __torch_function__ 允許有效地擴充套件 PyTorch 純 Python 元件的行為,但它不允許擴充套件 PyTorch 中用 C++ 實現的部分。為此,Tensor 子類也可以定義 __torch_dispatch__,它將能夠覆蓋 C++ 級別的行為。
要有效使用此功能,瞭解 PyTorch 的原生部分是如何實現的非常重要。其中最重要的元件是我們稱之為“分發器” (dispatcher) 的東西(最好的描述可以在這篇部落格文章中找到,儘管它略有 outdated)。正如其名稱所示,它負責為函式的特定呼叫呼叫正確的後端函式。例如,當呼叫 torch.add(a, b) 時,分發器會檢查兩個引數,確定應為此特定呼叫使用哪個“特性”(autograd、autocast、functionalization 等)和哪個“後端”(CPU、CUDA、MPS 等),最後呼叫所有正確的核心。核心經常做的一件事是“重新分派” (redispatch)。例如,當在 GPU 上使用 autocast 執行神經網路時,第一次呼叫將是 autocast 核心,它將處理任何潛在的 autocast 邏輯,然後向下重新分派。佇列中的下一個特性將是 autograd,它將正確建立 autograd 圖,然後向下重新分派。最後,我們到達 CUDA 的後端核心,它將啟動正確的 CUDA 核心並返回最終結果。在返回途中,autograd 將圖附加到輸出,最後,autocast 將有機會在退出時進行任何所需的更新。
分發器的一種配置是所有這些特性和後端鍵的呼叫順序。最新的列表及其順序可以在 DispatchKey.h 中的 DispatchKey 列舉中找到。就擴充套件 torch 而言,本次討論中重要的順序子集是:
vmap -> Autocast -> Autograd -> ZeroTensor -> Neg/Conj -> Functionalize -> Python -> Backends
就本次討論而言,最重要的鍵是 Python,因為所有定義了 __torch_dispatch__ 方法的 Tensor 子類都將呼叫此特性。使用者定義的方法就是從這裡呼叫的,並且可以在這裡任意覆蓋行為。從這裡,再次呼叫提供的 func 將執行“重新分派”。
此實現的一些重要含義是:
這段程式碼執行在“所有特性之下”。因此,它僅負責(像常規後端一樣)生成每個 Tensor 的輸出值(並且可以,也應該,忽略所有高階特性,如 autograd、autocast 等)。
如果任何高階特性在不重新分派的情況下實現了給定函式,它將永遠不會到達
Python鍵,因此__torch_dispatch__回撥將永遠不會被觸發。對於 CompositeImplicitAutograd 函式尤其如此,它們在 Autograd 級別進行評估而不進行重新分派。這是因為 CompositeImplicitAutograd 函式透過隱式呼叫其他原生操作來指定其 autograd 公式,因此在 Autograd 級別,該函式會被分解為其原生操作並對其進行評估。在回撥到 Python 幷包裝結果時,使用的轉換與常規 PyTorch Python/C++ 繫結相同。特別地,有些物件無法在 Python 中表示,需要特殊處理(例如,未定義的 Tensors 會變成 None)。
我們的原生函式被延遲填充為
torch.ops.{namespace}.{func_name}.{overload_name}的可呼叫 Python 物件,以便從 Python 輕鬆與其互動。傳遞給__torch_dispatch__的func物件始終是此名稱空間中的一個條目。此名稱空間可用於直接呼叫原生操作,繞過常規 Python API 和繫結程式碼。
與 __torch_function__ 能夠攔截 torch 的所有 Python API 和 Tensor 方法類似,__torch_dispatch__ 能夠攔截所有對 aten 原生 API 的呼叫。請注意,Tensor 上的所有方法在進入分發器之前都會轉換為函式呼叫,因此會在此處顯示為函式呼叫:torch.add(a, 2) 和 a + 2 將導致完全相同的 aten 呼叫。這些函式大多數定義在 native_functions.yaml 中,其中指定了這些函式的屬性及其後端實現。然後,透過程式碼生成,它們的實現以及指定的特性會被自動註冊。一些更奇特的函式或特性也在 C++ 程式碼庫的其他地方或使用者定義的 C++ 擴充套件中註冊。
也可以使用 torch.library 新增 新的 原生函式。這個 Python 特性允許定義和/或向原生函式新增新的實現。這可用於新增缺失的核心、替換現有核心或定義全新的原生函式。
您可以在 subclass zoo 倉庫中找到許多基於 __torch_dispatch__ 的子類示例。
__torch_dispatch__ 呼叫約定¶
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
pass
當用戶呼叫帶有定義了 __torch_dispatch__ 的輸入的運算子時,該呼叫可能會被轉發到 __torch_dispatch__。在呼叫 __torch_dispatch__ 之前,args 和 kwargs 會被標準化,也就是說:
kwargs由運算子 schema 中的僅關鍵字引數組成。如果某個關鍵字引數等於其預設值 (在 schema 中),則不會傳遞它。args由所有其他引數組成,無論它們如何傳遞給運算子(位置引數 vs 關鍵字引數)。如果某個引數等於其預設值,並且它是最右邊的位置引數,或者其右邊的所有引數都沒有傳遞,則不會傳遞它。
使用模式 (Modes) 擴充套件所有 torch API¶
不幸的是,有些函式不接受 Tensor 輸入。這意味著上面描述的子類方法無法用於覆蓋 PyTorch 所有函式的行為。此外,如果用例需要攔截每個函式呼叫,將每個 Tensor 更改為子類可能會過於侵入性。
為了解決這個用例,我們引入了“模式”(Mode) 的概念。它們存在於 __torch_function__ 和 __torch_dispatch__ 的覆蓋中,分別透過子類化 torch.overrides.TorchFunctionMode 和 torch.utils._python_dispatch.TorchDispatchMode 建立,並作為上下文管理器使用。
為了簡化其與子類和其他模式互動的描述,每當模式的上下文管理器被進入時,每個函式都表現得好像引數列表開頭有一個額外的 Tensor 引數,該 Tensor 的子類就是該模式。這意味著所有模式處理程式將先於任何子類處理程式被呼叫,並且與內部上下文管理器對應的模式將始終首先執行。
同樣重要的是要注意,在給定的模式處理程式內,此特定模式被停用,可以透過執行 with self: 手動重新啟用。
這是一個顯示不同型別模式日誌記錄的示例:
import torch
from torch.overrides import TorchFunctionMode, resolve_name
from torch.utils._python_dispatch import TorchDispatchMode
class FunctionLog(TorchFunctionMode):
def __torch_function__(self, func, types, args, kwargs=None):
print(f"Function Log: {resolve_name(func)}(*{args}, **{kwargs})")
return func(*args, **(kwargs or {}))
class DispatchLog(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args, kwargs=None):
print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
return func(*args, **(kwargs or {}))
def f():
a = torch.rand(10, requires_grad=True)
b = a * 2
b.sum().backward()
print("TorchFunctionMode logging:")
with FunctionLog():
f()
print("TorchDispatchMode logging:")
with DispatchLog():
f()
列印以下內容,附帶額外的註釋:
TorchFunctionMode logging:
Function Log: torch.rand(*(10,), **{'requires_grad': True})
Function Log: torch.Tensor.mul(*(tensor([0.7164, 0.9897, 0.1745, 0.9336, 0.4287, 0.7989, 0.2169, 0.7474, 0.5624,
0.5970], requires_grad=True), 2), **None)
Function Log: torch.Tensor.sum(*(tensor([1.4328, 1.9794, 0.3490, 1.8671, 0.8573, 1.5977, 0.4338, 1.4948, 1.1249,
1.1939], grad_fn=<MulBackward0>),), **None)
# Note that at the python level, we only see the call to backward but not what happens in the autograd engine.
Function Log: torch.Tensor.backward(*(tensor(12.3307, grad_fn=<SumBackward0>),), **{'gradient': None, 'retain_graph': None, 'create_graph': False, 'inputs': None})
TorchDispatchMode logging:
# Here the requires_grad flag from autograd is removed while default arguments were populated.
Dispatch Log: aten.rand.default(*([10],), **{'device': device(type='cpu'), 'pin_memory': False})
Dispatch Log: aten.mul.Tensor(*(tensor([0.2151, 0.6018, 0.8415, 0.9060, 0.2974, 0.7708, 0.6668, 0.0352, 0.7948,
0.6023], requires_grad=True), 2), **{})
Dispatch Log: aten.sum.default(*(tensor([0.4303, 1.2036, 1.6831, 1.8120, 0.5949, 1.5416, 1.3335, 0.0705, 1.5897,
1.2046], grad_fn=<MulBackward0>),), **{})
# Here we don't see the call to backward itself, but its constituents. Starting here with the factory function that creates the initial gradient.
Dispatch Log: aten.ones_like.default(*(tensor(11.4637, grad_fn=<SumBackward0>),), **{'pin_memory': False, 'memory_format': torch.preserve_format})
# This is the backward of the sum
Dispatch Log: aten.expand.default(*(tensor(1.), [10]), **{})
Dispatch Log: aten.mul.Tensor(*(tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), 2), **{})
Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{})
Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{})