快捷方式

PyTorch 自定義運算元

創建於:2024 年 6 月 18 日 | 最後更新:2025 年 1 月 6 日 | 最後驗證:2024 年 11 月 5 日

PyTorch 提供了大量的運算元庫,可以在張量 (例如 torch.add, torch.sum 等) 上工作。但是,您可能希望將一個新的自定義操作引入 PyTorch,並使其與 torch.compile、自動求導和 torch.vmap 等子系統協同工作。為此,您必須透過 Python torch.library 文件 或 C++ TORCH_LIBRARY API 向 PyTorch 註冊自定義操作。

從 Python 建立自定義運算元

請參閱 自定義 Python 運算元

您可能希望從 Python (而非 C++) 建立自定義運算元,如果

  • 您有一個 Python 函式,希望 PyTorch 將其視為一個不透明的可呼叫物件,尤其是在 torch.compiletorch.export 方面。

  • 您有一些連線 C++/CUDA 核心的 Python 繫結,並希望這些繫結與 PyTorch 子系統 (如 torch.compiletorch.autograd) 協同工作

  • 您正在使用 Python (而不是像 AOTInductor 這樣的純 C++ 環境)。

將自定義 C++ 和/或 CUDA 程式碼與 PyTorch 整合

請參閱 自定義 C++ 和 CUDA 運算元

您可能希望從 C++ (而非 Python) 建立自定義運算元,如果

  • 您有自定義的 C++ 和/或 CUDA 程式碼。

  • 您計劃將此程式碼與 AOTInductor 一起使用以進行無 Python 推理。

自定義運算元手冊

對於教程和本頁未涵蓋的資訊,請參閱 自定義運算元手冊 (我們正在將這些資訊遷移到我們的文件網站)。我們建議您先閱讀上面的一個教程,然後將自定義運算元手冊作為參考;它不是用來從頭讀到尾的。

何時應該建立自定義運算元?

如果您的操作可以表示為內建 PyTorch 運算元的組合,那麼請將其編寫為一個 Python 函式並呼叫它,而不是建立自定義運算元。如果您正在呼叫 PyTorch 無法理解的某個庫 (例如自定義 C/C++ 程式碼、自定義 CUDA 核心或 C/C++/CUDA 擴充套件的 Python 繫結),請使用運算元註冊 API 建立自定義運算元。

為什麼應該建立自定義運算元?

可以透過獲取張量的資料指標並將其傳遞給 pybind 繫結的核心來使用 C/C++/CUDA 核心。但是,這種方法無法與 PyTorch 子系統 (如自動求導、torch.compile、vmap 等) 協同工作。為了使操作能夠與 PyTorch 子系統協同工作,必須透過運算元註冊 API 進行註冊。

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源