快捷方式

torch.overrides

本模組公開了 __torch_function__ 協議的各種輔助函式。有關 __torch_function__ 協議的更多詳細資訊,請參閱擴充套件 torch Python API

函式

torch.overrides.get_ignored_functions()[source][source]

返回無法透過 __torch_function__ 重寫的公共函式。

返回

一個函式元組,這些函式在 torch API 中公開可用,但無法透過 __torch_function__ 重寫。這主要是因為這些函式的引數都不是 tensor 或 tensor-like 物件。

返回型別

set[Callable]

示例

>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
True
>>> torch.add in torch.overrides.get_ignored_functions()
False
torch.overrides.get_overridable_functions()[source][source]

列出可以透過 __torch_function__ 重寫的函式

返回

一個字典,將包含可重寫函式的名稱空間對映到該名稱空間中可重寫的函式。

返回型別

Dict[Any, List[Callable]]

torch.overrides.resolve_name(f)[source][source]

獲取傳遞給 __torch_function__ 的函式的易讀字串名稱

引數

f (Callable) – 要解析名稱的函式。

返回

函式的名稱;如果 eval 後,應返回輸入的函式。

返回型別

str

torch.overrides.get_testing_overrides()[source][source]

返回一個包含所有可重寫函式的虛擬重寫字典

返回

一個字典,將 PyTorch API 中可重寫的函式對映到與實際函式具有相同簽名的 lambda 函式,並無條件返回 -1。這些 lambda 函式對於測試定義了 __torch_function__ 的型別的 API 覆蓋範圍很有用。

返回型別

Dict[Callable, Callable]

示例

>>> import inspect
>>> my_add = torch.overrides.get_testing_overrides()[torch.add]
>>> inspect.signature(my_add)
<Signature (input, other, out=None)>
torch.overrides.handle_torch_function(public_api, relevant_args, *args, **kwargs)[source][source]

實現一個帶 __torch_function__ 重寫檢查的函式。

有關此函式在 C++ 實現中的對應部分,請參閱 torch::autograd::handle_torch_function。

引數
  • public_api (function) – 由 torch 公共 API 暴露的函式,最初呼叫方式為 public_api(*args, **kwargs),現在正在檢查其引數。

  • relevant_args (iterable) – 要檢查 __torch_function__ 方法的引數迭代器。

  • args (tuple) – 最初傳遞給 public_api 的任意位置引數。

  • kwargs (tuple) – 最初傳遞給 public_api 的任意關鍵字引數。

返回

根據需要,呼叫 implementation__torch_function__ 方法的結果。

返回型別

object

:raises TypeError : 如果找不到實現。

示例

>>> def func(a):
...     if has_torch_function_unary(a):
...         return handle_torch_function(func, (a,), a)
...     return a + 0
torch.overrides.has_torch_function()

檢查迭代器的元素中是否存在 __torch_function__ 實現,或是否啟用了 __torch_function__ 模式。將精確的 TensorParameter 視為不可分派。使用此函式來保護對 handle_torch_function() 的呼叫;不要用它來測試某個物件是否是 Tensor-like,而應使用 is_tensor_like()。 :param relevant_args: 要檢查 __torch_function__ 方法的迭代器或引數。 :type relevant_args: iterable

返回

如果 relevant_args 的任何元素具有 __torch_function__ 實現,則返回 True,否則返回 False。

返回型別

bool

另請參見

torch.is_tensor_like

檢查某個物件是否是 Tensor-like,包括精確的 Tensor

torch.overrides.is_tensor_like(inp)[source][source]

如果傳入的輸入是 Tensor-like,則返回 True

目前,當輸入型別的屬性中存在 __torch_function__ 時,就會發生這種情況。

示例

tensor 的子類通常是 Tensor-like。

>>> class SubTensor(torch.Tensor): ...
>>> is_tensor_like(SubTensor([0]))
True

內建或使用者型別通常不是 Tensor-like。

>>> is_tensor_like(6)
False
>>> is_tensor_like(None)
False
>>> class NotATensor: ...
>>> is_tensor_like(NotATensor())
False

但是,可以透過實現 __torch_function__ 使它們成為 Tensor-like。

>>> class TensorLike:
...     @classmethod
...     def __torch_function__(cls, func, types, args, kwargs):
...         return -1
>>> is_tensor_like(TensorLike())
True
torch.overrides.is_tensor_method_or_property(func)[source][source]

如果傳入的函式是屬於 torch.Tensor 的方法或屬性的 handler,並且被傳遞給 __torch_function__,則返回 True。

注意

對於屬性,必須傳入它們的 __get__ 方法。

這可能特別需要,原因如下:

  1. 方法/屬性有時不包含 __module__ 槽。

  2. 它們要求傳入的第一個引數是 torch.Tensor 的例項。

示例

>>> is_tensor_method_or_property(torch.Tensor.add)
True
>>> is_tensor_method_or_property(torch.add)
False
返回型別

bool

torch.overrides.wrap_torch_function(dispatcher)[source][source]

__torch_function__ 相關功能包裝給定函式。

引數

dispatcher (Callable) – 一個可呼叫物件,返回傳遞給函式的 Tensor-like 物件的迭代器。

注意

此裝飾器可能會降低程式碼效能。通常,只需將程式碼表達為一系列本身支援 __torch_function__ 的函式即可。如果您發現自己處於極少數情況下,例如您正在包裝一個底層庫並且還需要它支援 Tensor-like 物件,那麼此函式可用。

示例

>>> def dispatcher(a):  # Must have the same signature as func
...     return (a,)
>>> @torch.overrides.wrap_torch_function(dispatcher)
>>> def func(a):  # This will make func dispatchable by __torch_function__
...     return a + 0

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並解答疑問

檢視資源