捷徑

torch.overrides

此模組公開了 __torch_function__ 協定的各種輔助函式。如需 __torch_function__ 協定的詳細資訊,請參閱 擴展 torch Python API

函式

torch.overrides.get_ignored_functions()[原始碼]

傳回無法被 __torch_function__ 覆寫的公開函式。

傳回值

一個函式元組,這些函式在 torch API 中公開可用,但無法使用 __torch_function__ 覆寫。這主要是因為這些函式的引數都不是張量或類似張量的物件。

傳回類型

Set[Callable]

範例

>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
True
>>> torch.add in torch.overrides.get_ignored_functions()
False
torch.overrides.get_overridable_functions()[原始碼]

列出可透過 __torch_function__ 覆寫的函式

傳回值

一個字典,將包含可覆寫函式的命名空間對應到該命名空間中可覆寫的函式。

傳回類型

Dict[Any, List[Callable]]

torch.overrides.resolve_name(f)[原始碼]

取得傳遞給 __torch_function__ 的函式的易讀字串名稱

參數

f (Callable) – 要解析名稱的函式。

傳回值

函式的名稱;如果經過 eval 後,應該會傳回輸入函式。

傳回類型

str

torch.overrides.get_testing_overrides()[原始碼]

傳回一個字典,其中包含所有可覆寫函式的虛擬覆寫

傳回值

一個字典,將 PyTorch API 中可覆寫的函式對應到 lambda 函式,這些 lambda 函式具有與實際函式相同的簽章,並且無條件傳回 -1。這些 lambda 函式可用於測試定義了 __torch_function__ 的類型的 API 涵蓋範圍。

傳回類型

Dict[Callable, Callable]

範例

>>> import inspect
>>> my_add = torch.overrides.get_testing_overrides()[torch.add]
>>> inspect.signature(my_add)
<Signature (input, other, out=None)>
torch.overrides.handle_torch_function(public_api, relevant_args, *args, **kwargs)[原始碼]

使用 __torch_function__ 覆寫的檢查來實作函式。

如需 C++ 實作中此函式的等效函式,請參閱 torch::autograd::handle_torch_function。

參數
  • public_api (函式) – 公開 torch API 公開的函式,最初像 public_api(*args, **kwargs) 這樣被呼叫,現在正在檢查其引數。

  • relevant_args (可迭代物件) – 要檢查 __torch_function__ 方法的引數的可迭代物件。

  • args (元組) – 最初傳遞到 public_api 的任意位置引數。

  • kwargs (元組) – 最初傳遞到 public_api 的任意關鍵字引數。

傳回值

呼叫 implementation__torch_function__ 方法的結果(如果有的話)。

傳回類型

物件

:raises TypeError : 如果找不到實作。

範例

>>> def func(a):
...     if has_torch_function_unary(a):
...         return handle_torch_function(func, (a,), a)
...     return a + 0
torch.overrides.has_torch_function()

檢查可迭代物件元素中是否有 __torch_function__ 的實現,或者是否啟用了 __torch_function__ 模式。將精確的 TensorParameter 視為不可調度的。使用此函數來保護對 handle_torch_function() 的呼叫;不要使用它來測試某個物件是否類似於 Tensor,請改用 is_tensor_like()。 :param relevant_args: 要檢查 __torch_function__ 方法的可迭代物件或參數。 :type relevant_args: 可迭代物件

傳回值

如果 relevant_args 中的任何元素具有 __torch_function__ 的實現,則為 True,否則為 False。

傳回類型

布林值

另請參閱

torch.is_tensor_like

檢查某個物件是否類似於 Tensor,包括精確的 Tensor

torch.overrides.is_tensor_like(inp)[原始碼]

如果傳入的輸入是類似於 Tensor 的物件,則返回 True

目前,只要輸入類型的物件上有 __torch_function__ 屬性,就會發生這種情況。

範例

Tensor 的子類別通常是類似於 Tensor 的物件。

>>> class SubTensor(torch.Tensor): ...
>>> is_tensor_like(SubTensor([0]))
True

內建或使用者定義的類型通常不是類似於 Tensor 的物件。

>>> is_tensor_like(6)
False
>>> is_tensor_like(None)
False
>>> class NotATensor: ...
>>> is_tensor_like(NotATensor())
False

但是,它們可以通過實現 __torch_function__ 來變成類似於 Tensor 的物件。

>>> class TensorLike:
...     @classmethod
...     def __torch_function__(cls, func, types, args, kwargs):
...         return -1
>>> is_tensor_like(TensorLike())
True
torch.overrides.is_tensor_method_or_property(func)[原始碼]

如果傳入的函數是屬於 torch.Tensor 的方法或屬性的處理程序(如傳遞到 __torch_function__ 中),則返回 True。

備註

對於屬性,必須傳入它們的 __get__ 方法。

可能需要這樣做,特別是因為以下原因

  1. 方法/屬性有時不包含 __module__ 槽。

  2. 它們要求第一個傳入的參數是 torch.Tensor 的實例。

範例

>>> is_tensor_method_or_property(torch.Tensor.add)
True
>>> is_tensor_method_or_property(torch.add)
False
傳回類型

布林值

torch.overrides.wrap_torch_function(dispatcher)[原始碼]

使用 __torch_function__ 相關功能包裝給定的函數。

參數

dispatcher (可呼叫物件) – 一個可呼叫物件,它返回傳遞到函數中的類似於 Tensor 的物件的可迭代物件。

備註

此裝飾器可能會降低程式碼的效能。一般來說,將程式碼表示為一系列本身支持 __torch_function__ 的函數就足夠了。如果您發現自己處於極少數情況下並非如此,例如,如果您正在包裝一個低階程式庫,並且您還需要它能夠處理類似於 Tensor 的物件,那麼可以使用此函數。

範例

>>> def dispatcher(a): # Must have the same signature as func
...     return (a,)
>>> @torch.overrides.wrap_torch_function(dispatcher)
>>> def func(a): # This will make func dispatchable by __torch_function__
...     return a + 0

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得適合初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源