評價此頁面

Torch Library API#

PyTorch C++ API 提供了擴充套件 PyTorch 核心運算元庫的能力,允許使用者定義運算元和資料型別。使用 Torch Library API 實現的擴充套件在 PyTorch 的 eager API 和 TorchScript 中均可使用。

有關 library API 的教程式介紹,請查閱使用自定義 C++ 運算元擴充套件 TorchScript 教程。

#

TORCH_LIBRARY(ns, m)

用於定義一個函式,該函式將在靜態初始化時執行,以在名稱空間 ns(必須是有效的 C++ 識別符號,不帶引號)中定義一個運算元庫。

當您想定義一組 PyTorch 中尚不存在的新的自定義運算元時,請使用此宏。

示例用法

TORCH_LIBRARY(myops, m) {
  // m is a torch::Library; methods on it will define
  // operators in the myops namespace
  m.def("add", add_impl);
}

引數 m 繫結到一個 torch::Library 物件,該物件用於註冊運算元。對於任何給定的名稱空間,只能有一個 TORCH_LIBRARY()

TORCH_LIBRARY_IMPL(ns, k, m)

用於定義一個函式,該函式將在靜態初始化時執行,以在名稱空間 ns(必須是有效的 C++ 識別符號,不帶引號)中為排程鍵 k(必須是 c10::DispatchKey 的非限定列舉成員)定義運算元覆蓋。

當您想在新的排程鍵上實現一組已存在的自定義運算元時(例如,您想為已存在的運算元提供 CUDA 實現),請使用此宏。一個常見的用法模式是使用 TORCH_LIBRARY() 定義所有要定義的新運算元的 schema,然後使用多個 TORCH_LIBRARY_IMPL() 塊為 CPU、CUDA 和 Autograd 提供運算元的實現。

在某些情況下,您需要定義適用於所有名稱空間(而不僅僅是一個名稱空間)的東西(通常是 fallback)。在這種情況下,請使用保留名稱空間 _,例如:

TORCH_LIBRARY_IMPL(_, XLA, m) {
   m.fallback(xla_fallback);
}

示例用法

TORCH_LIBRARY_IMPL(myops, CPU, m) {
  // m is a torch::Library; methods on it will define
  // CPU implementations of operators in the myops namespace.
  // It is NOT valid to call torch::Library::def()
  // in this context.
  m.impl("add", add_cpu_impl);
}

如果 add_cpu_impl 是一個過載函式,請使用 static_cast 指定您想要的過載(透過提供完整的型別)。

#

class Library

此物件提供了定義運算元和在排程鍵處提供實現的能力的 API。

通常,torch::Library 不是直接分配的;而是由 TORCH_LIBRARY()TORCH_LIBRARY_IMPL() 宏建立的。

torch::Library 上的大多數方法都返回自身的引用,支援方法鏈式呼叫。

// Examples:

TORCH_LIBRARY(torchvision, m) {
   // m is a torch::Library
   m.def("roi_align", ...);
   ...
}

TORCH_LIBRARY_IMPL(aten, XLA, m) {
   // m is a torch::Library
   m.impl("add", ...);
   ...
}

公有函式

inline Library &def(c10::FunctionSchema &&s, const std::vector<at::Tag> &tags = {}, _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) &

宣告一個帶有 schema 的運算元,但不為其提供任何實現。

您應隨後使用 impl() 方法提供實現。所有模板引數都會被推斷出來。

// Example:
TORCH_LIBRARY(myops, m) {
  m.def("add(Tensor self, Tensor other) -> Tensor");
}

引數

raw_schema – 要定義的運算元的 schema。通常,這是一個 const char* 字串字面量,但 torch::schema() 接受的任何型別在此處都可接受。

inline Library &set_python_module(const char *pymodule, const char *context = "")

聲明後續所有 def 的運算元的 fake impls 可以在給定的 Python 模組 (pymodule) 中找到。

這會註冊一些幫助文字,如果找不到 fake impl,就會使用這些文字。

引數

  • pymodule: Python 模組

  • context: 我們可能會在錯誤訊息中包含此資訊。

inline Library &impl_abstract_pystub(const char *pymodule, const char *context = "")

已棄用;請改用 set_python_module。

template<typename NameOrSchema, typename Func>
inline Library &def(NameOrSchema &&raw_name_or_schema, Func &&raw_f, const std::vector<at::Tag> &tags = {}) &

為一個 schema 定義一個運算元,然後為其註冊一個實現。

如果您不打算使用 dispatcher 來構建運算元實現,通常會使用此方法。它大致等同於呼叫 def() 再呼叫 impl(),但如果您省略運算元的 schema,我們將從 C++ 函式的型別中推斷出來。所有模板引數都會被推斷出來。

// Example:
TORCH_LIBRARY(myops, m) {
  m.def("add", add_fn);
}

引數
  • raw_name_or_schema – 要定義的運算元的 schema,或者僅為運算元名稱(如果 schema 從 raw_f 推斷)。通常是一個 const char* 字面量。

  • raw_f – 實現此運算元的 C++ 函式。此處接受 torch::CppFunction 的任何有效建構函式;通常您會提供函式指標或 lambda。

template<typename Name, typename Func>
inline Library &impl(Name name, Func &&raw_f, _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) &

註冊一個運算元的實現。

您可以在不同的排程鍵處為單個運算元註冊多個實現(參見 torch::dispatch())。實現必須有相應的宣告(來自 def()),否則無效。如果您計劃註冊多個實現,在 def() 運算元時不要提供函式實現。

// Example:
TORCH_LIBRARY_IMPL(myops, CUDA, m) {
  m.impl("add", add_cuda);
}

引數
  • name – 要實現的運算元名稱。此處不要提供 schema。

  • raw_f – 實現此運算元的 C++ 函式。此處接受 torch::CppFunction 的任何有效建構函式;通常您會提供函式指標或 lambda。

template<typename Func>
inline Library &fallback(Func &&raw_f) &

為所有運算元註冊一個 fallback 實現,當某個運算元沒有可用的特定實現時,將使用此 fallback 實現。

Fallback 必須關聯一個 DispatchKey;例如,只能從名稱空間為 _TORCH_LIBRARY_IMPL() 中呼叫此函式。

// Example:

TORCH_LIBRARY_IMPL(_, AutogradXLA, m) {
  // If there is not a kernel explicitly registered
  // for AutogradXLA, fallthrough to the next
  // available kernel
  m.fallback(torch::CppFunction::makeFallthrough());
}

// See aten/src/ATen/core/dispatch/backend_fallback_test.cpp
// for a full example of boxed fallback

引數

raw_f – 實現 fallback 的函式。未裝箱的函式通常不能用作 fallback 函式,因為 fallback 函式必須對每個運算元都有效(即使它們的型別簽名不同)。典型的引數是 CppFunction::makeFallthrough()CppFunction::makeFromBoxedFunction()

class CppFunction

表示實現運算元的 C++ 函式。

大多數使用者不會直接與此類互動,除非透過錯誤訊息:此類定義的建構函式定義了您可以透過介面繫結的“函式”類事物的允許集合。

此類抹去了傳入函式的型別,但透過為函式推斷的 schema 持久地記錄了型別。

公有函式

template<typename Func>
inline explicit CppFunction(Func *f, std::enable_if_t<c10::guts::is_function_type<Func>::value, std::nullptr_t> = nullptr)

此過載接受函式指標,例如 CppFunction(&add_impl)

template<typename FuncPtr>
inline explicit CppFunction(FuncPtr f, std::enable_if_t<c10::is_compile_time_function_pointer<FuncPtr>::value, std::nullptr_t> = nullptr)

此過載接受編譯時函式指標,例如 CppFunction(TORCH_FN(add_impl))

template<typename Lambda>
inline explicit CppFunction(Lambda &&f, std::enable_if_t<c10::guts::is_functor<std::decay_t<Lambda>>::value, std::nullptr_t> = nullptr)

此過載接受 lambda,例如 CppFunction([](const Tensor& self) { ...

})

公有靜態函式

static inline CppFunction makeFallthrough()

這會建立一個 fallthrough 函式。

Fallthrough 函式會立即重新排程到下一個可用的排程鍵,但其實現比手動編寫的相同功能的函式更有效率。

template<c10::BoxedKernel::BoxedKernelFunction *func>
static inline CppFunction makeFromBoxedFunction()

從具有簽名 void(const OperatorHandle&, Stack*) 的盒裝核心函式(boxed kernel function)建立函式;也就是說,它們在盒裝呼叫約定中接收引數堆疊,而不是在原生的 C++ 呼叫約定中接收。

盒裝函式通常僅用於透過 torch::Library::fallback() 註冊後端回退。

template<class KernelFunctor>
static inline CppFunction makeFromBoxedFunctor(std::unique_ptr<KernelFunctor> kernelFunctor)

從定義了 operator()(const OperatorHandle&, DispatchKeySet, Stack*)(從盒裝呼叫約定接收引數)並繼承自 c10::OperatorKernel 的盒裝核心 functor 建立函式。

與 makeFromBoxedFunction 不同,透過這種方式註冊的函式還可以攜帶額外的狀態,這些狀態由 functor 管理;如果你正在編寫介面卡以連線到其他實現(例如,與註冊核心動態關聯的 Python 可呼叫物件),這會很有用。

template<typename FuncPtr, std::enable_if_t<c10::guts::is_function_type<FuncPtr>::value, std::nullptr_t> = nullptr>
static inline CppFunction makeFromUnboxedFunction(FuncPtr *f)

從未盒裝核心函式(unboxed kernel function)建立函式。

這通常用於註冊常見的運算子。

template<typename FuncPtr, std::enable_if_t<c10::is_compile_time_function_pointer<FuncPtr>::value, std::nullptr_t> = nullptr>
static inline CppFunction makeFromUnboxedFunction(FuncPtr f)

從編譯時未盒裝核心函式指標建立函式。

這通常用於註冊常見的運算子。編譯時函式指標可以允許編譯器最佳化(例如,內聯)對其的呼叫。

Functions#

template<typename Func>
inline CppFunction dispatch(c10::DispatchKey k, Func &&raw_f)#

建立一個與特定排程鍵(dispatch key)關聯的 torch::CppFunction

標記有 c10::DispatchKey 的 torch::CppFunctions 只有在排程程式確定這個特定的 c10::DispatchKey 是應該被排程的鍵時才會被呼叫。

此函式通常不直接使用,而是推薦使用 TORCH_LIBRARY_IMPL(),後者會隱式地在其主體內的所有註冊呼叫中設定 c10::DispatchKey。

template<typename Func>
inline CppFunction dispatch(c10::DeviceType type, Func &&raw_f)#

接受 c10::DeviceTypedispatch() 便利過載。

inline c10::FunctionSchema schema(const char *str, c10::AliasAnalysisKind k, bool allow_typevars = false)#

從字串構造 c10::FunctionSchema,並顯式指定 c10::AliasAnalysisKind。

通常,schema 直接作為字串傳入,但如果您需要指定自定義別名分析(alias analysis),可以用對此函式的呼叫替換該字串。

// Default alias analysis (FROM_SCHEMA)
m.def("def3(Tensor self) -> Tensor");
// Pure function alias analysis
m.def(torch::schema("def3(Tensor self) -> Tensor",
c10::AliasAnalysisKind::PURE_FUNCTION));

inline c10::FunctionSchema schema(const char *s, bool allow_typevars = false)#

函式 schema 可以直接從字串字面量構造。