torch.overrides¶
此模組公開了 __torch_function__ 協定的各種輔助函式。如需 __torch_function__ 協定的詳細資訊,請參閱 擴展 torch Python API。
函式¶
- torch.overrides.get_ignored_functions()[原始碼]¶
傳回無法被
__torch_function__覆寫的公開函式。- 傳回值
一個函式元組,這些函式在 torch API 中公開可用,但無法使用
__torch_function__覆寫。這主要是因為這些函式的引數都不是張量或類似張量的物件。- 傳回類型
Set[Callable]
範例
>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions() True >>> torch.add in torch.overrides.get_ignored_functions() False
- torch.overrides.get_overridable_functions()[原始碼]¶
列出可透過 __torch_function__ 覆寫的函式
- 傳回值
一個字典,將包含可覆寫函式的命名空間對應到該命名空間中可覆寫的函式。
- 傳回類型
Dict[Any, List[Callable]]
- torch.overrides.resolve_name(f)[原始碼]¶
取得傳遞給 __torch_function__ 的函式的易讀字串名稱
- 參數
f (Callable) – 要解析名稱的函式。
- 傳回值
函式的名稱;如果經過 eval 後,應該會傳回輸入函式。
- 傳回類型
- torch.overrides.get_testing_overrides()[原始碼]¶
傳回一個字典,其中包含所有可覆寫函式的虛擬覆寫
- 傳回值
一個字典,將 PyTorch API 中可覆寫的函式對應到 lambda 函式,這些 lambda 函式具有與實際函式相同的簽章,並且無條件傳回 -1。這些 lambda 函式可用於測試定義了
__torch_function__的類型的 API 涵蓋範圍。- 傳回類型
Dict[Callable, Callable]
範例
>>> import inspect >>> my_add = torch.overrides.get_testing_overrides()[torch.add] >>> inspect.signature(my_add) <Signature (input, other, out=None)>
- torch.overrides.handle_torch_function(public_api, relevant_args, *args, **kwargs)[原始碼]¶
使用
__torch_function__覆寫的檢查來實作函式。如需 C++ 實作中此函式的等效函式,請參閱 torch::autograd::handle_torch_function。
- 參數
- 傳回值
呼叫
implementation或__torch_function__方法的結果(如果有的話)。- 傳回類型
:raises TypeError : 如果找不到實作。
範例
>>> def func(a): ... if has_torch_function_unary(a): ... return handle_torch_function(func, (a,), a) ... return a + 0
- torch.overrides.has_torch_function()¶
檢查可迭代物件元素中是否有 __torch_function__ 的實現,或者是否啟用了 __torch_function__ 模式。將精確的
Tensor和Parameter視為不可調度的。使用此函數來保護對handle_torch_function()的呼叫;不要使用它來測試某個物件是否類似於 Tensor,請改用is_tensor_like()。 :param relevant_args: 要檢查 __torch_function__ 方法的可迭代物件或參數。 :type relevant_args: 可迭代物件- 傳回值
如果 relevant_args 中的任何元素具有 __torch_function__ 的實現,則為 True,否則為 False。
- 傳回類型
另請參閱
torch.is_tensor_like檢查某個物件是否類似於 Tensor,包括精確的
Tensor。
- torch.overrides.is_tensor_like(inp)[原始碼]¶
如果傳入的輸入是類似於 Tensor 的物件,則返回
True。目前,只要輸入類型的物件上有
__torch_function__屬性,就會發生這種情況。範例
Tensor 的子類別通常是類似於 Tensor 的物件。
>>> class SubTensor(torch.Tensor): ... >>> is_tensor_like(SubTensor([0])) True
內建或使用者定義的類型通常不是類似於 Tensor 的物件。
>>> is_tensor_like(6) False >>> is_tensor_like(None) False >>> class NotATensor: ... >>> is_tensor_like(NotATensor()) False
但是,它們可以通過實現 __torch_function__ 來變成類似於 Tensor 的物件。
>>> class TensorLike: ... @classmethod ... def __torch_function__(cls, func, types, args, kwargs): ... return -1 >>> is_tensor_like(TensorLike()) True
- torch.overrides.is_tensor_method_or_property(func)[原始碼]¶
如果傳入的函數是屬於
torch.Tensor的方法或屬性的處理程序(如傳遞到__torch_function__中),則返回 True。備註
對於屬性,必須傳入它們的
__get__方法。可能需要這樣做,特別是因為以下原因
方法/屬性有時不包含 __module__ 槽。
它們要求第一個傳入的參數是
torch.Tensor的實例。
範例
>>> is_tensor_method_or_property(torch.Tensor.add) True >>> is_tensor_method_or_property(torch.add) False
- 傳回類型
- torch.overrides.wrap_torch_function(dispatcher)[原始碼]¶
使用
__torch_function__相關功能包裝給定的函數。- 參數
dispatcher (可呼叫物件) – 一個可呼叫物件,它返回傳遞到函數中的類似於 Tensor 的物件的可迭代物件。
備註
此裝飾器可能會降低程式碼的效能。一般來說,將程式碼表示為一系列本身支持 __torch_function__ 的函數就足夠了。如果您發現自己處於極少數情況下並非如此,例如,如果您正在包裝一個低階程式庫,並且您還需要它能夠處理類似於 Tensor 的物件,那麼可以使用此函數。
範例
>>> def dispatcher(a): # Must have the same signature as func ... return (a,) >>> @torch.overrides.wrap_torch_function(dispatcher) >>> def func(a): # This will make func dispatchable by __torch_function__ ... return a + 0