torch.overrides¶
此模組公開了 __torch_function__ 協定的各種輔助函式。如需 __torch_function__ 協定的詳細資訊,請參閱 擴展 torch Python API。
函式¶
- torch.overrides.get_ignored_functions()[原始碼]¶
- 傳回無法被 - __torch_function__覆寫的公開函式。- 傳回值
- 一個函式元組,這些函式在 torch API 中公開可用,但無法使用 - __torch_function__覆寫。這主要是因為這些函式的引數都不是張量或類似張量的物件。
- 傳回類型
- Set[Callable] 
 - 範例 - >>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions() True >>> torch.add in torch.overrides.get_ignored_functions() False 
- torch.overrides.get_overridable_functions()[原始碼]¶
- 列出可透過 __torch_function__ 覆寫的函式 - 傳回值
- 一個字典,將包含可覆寫函式的命名空間對應到該命名空間中可覆寫的函式。 
- 傳回類型
- Dict[Any, List[Callable]] 
 
- torch.overrides.resolve_name(f)[原始碼]¶
- 取得傳遞給 __torch_function__ 的函式的易讀字串名稱 - 參數
- f (Callable) – 要解析名稱的函式。 
- 傳回值
- 函式的名稱;如果經過 eval 後,應該會傳回輸入函式。 
- 傳回類型
 
- torch.overrides.get_testing_overrides()[原始碼]¶
- 傳回一個字典,其中包含所有可覆寫函式的虛擬覆寫 - 傳回值
- 一個字典,將 PyTorch API 中可覆寫的函式對應到 lambda 函式,這些 lambda 函式具有與實際函式相同的簽章,並且無條件傳回 -1。這些 lambda 函式可用於測試定義了 - __torch_function__的類型的 API 涵蓋範圍。
- 傳回類型
- Dict[Callable, Callable] 
 - 範例 - >>> import inspect >>> my_add = torch.overrides.get_testing_overrides()[torch.add] >>> inspect.signature(my_add) <Signature (input, other, out=None)> 
- torch.overrides.handle_torch_function(public_api, relevant_args, *args, **kwargs)[原始碼]¶
- 使用 - __torch_function__覆寫的檢查來實作函式。- 如需 C++ 實作中此函式的等效函式,請參閱 torch::autograd::handle_torch_function。 - 參數
- 傳回值
- 呼叫 - implementation或- __torch_function__方法的結果(如果有的話)。
- 傳回類型
 - :raises TypeError : 如果找不到實作。 - 範例 - >>> def func(a): ... if has_torch_function_unary(a): ... return handle_torch_function(func, (a,), a) ... return a + 0 
- torch.overrides.has_torch_function()¶
- 檢查可迭代物件元素中是否有 __torch_function__ 的實現,或者是否啟用了 __torch_function__ 模式。將精確的 - Tensor和- Parameter視為不可調度的。使用此函數來保護對- handle_torch_function()的呼叫;不要使用它來測試某個物件是否類似於 Tensor,請改用- is_tensor_like()。 :param relevant_args: 要檢查 __torch_function__ 方法的可迭代物件或參數。 :type relevant_args: 可迭代物件- 傳回值
- 如果 relevant_args 中的任何元素具有 __torch_function__ 的實現,則為 True,否則為 False。 
- 傳回類型
 - 另請參閱 - torch.is_tensor_like
- 檢查某個物件是否類似於 Tensor,包括精確的 - Tensor。
 
- torch.overrides.is_tensor_like(inp)[原始碼]¶
- 如果傳入的輸入是類似於 Tensor 的物件,則返回 - True。- 目前,只要輸入類型的物件上有 - __torch_function__屬性,就會發生這種情況。- 範例 - Tensor 的子類別通常是類似於 Tensor 的物件。 - >>> class SubTensor(torch.Tensor): ... >>> is_tensor_like(SubTensor([0])) True - 內建或使用者定義的類型通常不是類似於 Tensor 的物件。 - >>> is_tensor_like(6) False >>> is_tensor_like(None) False >>> class NotATensor: ... >>> is_tensor_like(NotATensor()) False - 但是,它們可以通過實現 __torch_function__ 來變成類似於 Tensor 的物件。 - >>> class TensorLike: ... @classmethod ... def __torch_function__(cls, func, types, args, kwargs): ... return -1 >>> is_tensor_like(TensorLike()) True 
- torch.overrides.is_tensor_method_or_property(func)[原始碼]¶
- 如果傳入的函數是屬於 - torch.Tensor的方法或屬性的處理程序(如傳遞到- __torch_function__中),則返回 True。- 備註 - 對於屬性,必須傳入它們的 - __get__方法。- 可能需要這樣做,特別是因為以下原因 - 方法/屬性有時不包含 __module__ 槽。 
- 它們要求第一個傳入的參數是 - torch.Tensor的實例。
 - 範例 - >>> is_tensor_method_or_property(torch.Tensor.add) True >>> is_tensor_method_or_property(torch.add) False - 傳回類型
 
- torch.overrides.wrap_torch_function(dispatcher)[原始碼]¶
- 使用 - __torch_function__相關功能包裝給定的函數。- 參數
- dispatcher (可呼叫物件) – 一個可呼叫物件,它返回傳遞到函數中的類似於 Tensor 的物件的可迭代物件。 
 - 備註 - 此裝飾器可能會降低程式碼的效能。一般來說,將程式碼表示為一系列本身支持 __torch_function__ 的函數就足夠了。如果您發現自己處於極少數情況下並非如此,例如,如果您正在包裝一個低階程式庫,並且您還需要它能夠處理類似於 Tensor 的物件,那麼可以使用此函數。 - 範例 - >>> def dispatcher(a): # Must have the same signature as func ... return (a,) >>> @torch.overrides.wrap_torch_function(dispatcher) >>> def func(a): # This will make func dispatchable by __torch_function__ ... return a + 0