torch.overrides¶
本模組公開了 __torch_function__ 協議的各種輔助函式。有關 __torch_function__ 協議的更多詳細資訊,請參閱擴充套件 torch Python API。
函式¶
- torch.overrides.get_ignored_functions()[source][source]¶
返回無法透過
__torch_function__重寫的公共函式。- 返回
一個函式元組,這些函式在 torch API 中公開可用,但無法透過
__torch_function__重寫。這主要是因為這些函式的引數都不是 tensor 或 tensor-like 物件。- 返回型別
set[Callable]
示例
>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions() True >>> torch.add in torch.overrides.get_ignored_functions() False
- torch.overrides.get_overridable_functions()[source][source]¶
列出可以透過 __torch_function__ 重寫的函式
- 返回
一個字典,將包含可重寫函式的名稱空間對映到該名稱空間中可重寫的函式。
- 返回型別
Dict[Any, List[Callable]]
- torch.overrides.resolve_name(f)[source][source]¶
獲取傳遞給 __torch_function__ 的函式的易讀字串名稱
- 引數
f (Callable) – 要解析名稱的函式。
- 返回
函式的名稱;如果 eval 後,應返回輸入的函式。
- 返回型別
- torch.overrides.get_testing_overrides()[source][source]¶
返回一個包含所有可重寫函式的虛擬重寫字典
- 返回
一個字典,將 PyTorch API 中可重寫的函式對映到與實際函式具有相同簽名的 lambda 函式,並無條件返回 -1。這些 lambda 函式對於測試定義了
__torch_function__的型別的 API 覆蓋範圍很有用。- 返回型別
Dict[Callable, Callable]
示例
>>> import inspect >>> my_add = torch.overrides.get_testing_overrides()[torch.add] >>> inspect.signature(my_add) <Signature (input, other, out=None)>
- torch.overrides.handle_torch_function(public_api, relevant_args, *args, **kwargs)[source][source]¶
實現一個帶
__torch_function__重寫檢查的函式。有關此函式在 C++ 實現中的對應部分,請參閱 torch::autograd::handle_torch_function。
- 引數
- 返回
根據需要,呼叫
implementation或__torch_function__方法的結果。- 返回型別
:raises TypeError : 如果找不到實現。
示例
>>> def func(a): ... if has_torch_function_unary(a): ... return handle_torch_function(func, (a,), a) ... return a + 0
- torch.overrides.has_torch_function()¶
檢查迭代器的元素中是否存在 __torch_function__ 實現,或是否啟用了 __torch_function__ 模式。將精確的
Tensor和Parameter視為不可分派。使用此函式來保護對handle_torch_function()的呼叫;不要用它來測試某個物件是否是 Tensor-like,而應使用is_tensor_like()。 :param relevant_args: 要檢查 __torch_function__ 方法的迭代器或引數。 :type relevant_args: iterable- 返回
如果 relevant_args 的任何元素具有 __torch_function__ 實現,則返回 True,否則返回 False。
- 返回型別
另請參見
torch.is_tensor_like檢查某個物件是否是 Tensor-like,包括精確的
Tensor。
- torch.overrides.is_tensor_like(inp)[source][source]¶
如果傳入的輸入是 Tensor-like,則返回
True。目前,當輸入型別的屬性中存在
__torch_function__時,就會發生這種情況。示例
tensor 的子類通常是 Tensor-like。
>>> class SubTensor(torch.Tensor): ... >>> is_tensor_like(SubTensor([0])) True
內建或使用者型別通常不是 Tensor-like。
>>> is_tensor_like(6) False >>> is_tensor_like(None) False >>> class NotATensor: ... >>> is_tensor_like(NotATensor()) False
但是,可以透過實現 __torch_function__ 使它們成為 Tensor-like。
>>> class TensorLike: ... @classmethod ... def __torch_function__(cls, func, types, args, kwargs): ... return -1 >>> is_tensor_like(TensorLike()) True
- torch.overrides.is_tensor_method_or_property(func)[source][source]¶
如果傳入的函式是屬於
torch.Tensor的方法或屬性的 handler,並且被傳遞給__torch_function__,則返回 True。注意
對於屬性,必須傳入它們的
__get__方法。這可能特別需要,原因如下:
方法/屬性有時不包含 __module__ 槽。
它們要求傳入的第一個引數是
torch.Tensor的例項。
示例
>>> is_tensor_method_or_property(torch.Tensor.add) True >>> is_tensor_method_or_property(torch.add) False
- 返回型別
- torch.overrides.wrap_torch_function(dispatcher)[source][source]¶
用
__torch_function__相關功能包裝給定函式。- 引數
dispatcher (Callable) – 一個可呼叫物件,返回傳遞給函式的 Tensor-like 物件的迭代器。
注意
此裝飾器可能會降低程式碼效能。通常,只需將程式碼表達為一系列本身支援 __torch_function__ 的函式即可。如果您發現自己處於極少數情況下,例如您正在包裝一個底層庫並且還需要它支援 Tensor-like 物件,那麼此函式可用。
示例
>>> def dispatcher(a): # Must have the same signature as func ... return (a,) >>> @torch.overrides.wrap_torch_function(dispatcher) >>> def func(a): # This will make func dispatchable by __torch_function__ ... return a + 0