快捷方式

torch.compiler.substitute_in_graph

torch.compiler.substitute_in_graph(original_fn, *, can_constant_fold_through=False, skip_signature_check=False)[原始碼][原始碼]

註冊一個函式的 polyfill 處理器,通常是一個 C 擴充套件中的 C 函式,用於在圖中內聯原始函式時替代原始函式。

注意

polyfill 處理器僅在內聯原始函式時使用。當直接呼叫原始函式時,它不會被使用。在 eager 模式下,裝飾後的函式會呼叫高效能的 C 函式,而不是 polyfill 處理器。

polyfill 處理器是一個函式,在內聯原始函式時將替代原始函式被呼叫。polyfill 處理器應具有與原始函式相同的簽名和行為。

引數
  • original_fn (可呼叫物件) – 要為其註冊 polyfill 處理器 的原始函式,通常是 C 函式。

  • can_constant_fold_through (bool, 可選) – polyfill 處理器是否可以透過常量摺疊。也就是說,如果 polyfill 處理器是純函式且其引數是常量,那麼在編譯期間可以對 polyfill 處理器 的結果進行常量摺疊。預設為 False

  • skip_signature_check (bool, 可選) – 是否跳過原始函式和 polyfill 處理器之間的簽名檢查。預設為 False

返回

一個用於為原始函式註冊 polyfill 處理器 的裝飾器。

返回型別

Callable[[Callable[[_P], _R]], Callable[[_P], _R]]

示例

>>> import operator
>>> operator.indexOf([1, 2, 3, 4, 5], 3)
2
>>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
... # xdoctest: +SKIP("Long tracebacks")
Traceback (most recent call last):
...
torch._dynamo.exc.Unsupported: ...

>>> @torch.compiler.substitute_in_graph(operator.indexOf)
... def indexOf(a, b, /):
...     for i, item in enumerate(a):
...         if item is b or item == b:
...             return i
...     raise ValueError("sequence.index(x): x not in sequence")
>>>
>>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
2

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得解答

檢視資源