快捷方式

torch.cuda.jiterator._create_jit_fn

torch.cuda.jiterator._create_jit_fn(code_string, **kwargs)[source][source]

為逐元素操作建立 jiterator 生成的 CUDA kernel。

程式碼字串必須是一個有效的 CUDA 函式,用於描述單個元素的計算。程式碼字串必須遵循 C++ 模板模式,如下例所示。此函式將被內聯到逐元素 kernel 模板中,並動態編譯。編譯後的 kernel 將快取到記憶體和本地臨時目錄中。

Jiterator 生成的 kernel 接受非連續的 Tensor,並支援廣播和型別提升。

引數
  • code_string (str) – 由 jiterator 編譯的 CUDA 程式碼字串。入口 functor 必須按值返回。

  • kwargs (Dict, optional) – 生成函式的關鍵字引數

返回型別

可呼叫物件

示例

code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"
jitted_fn = create_jit_fn(code_string, alpha=1.0)
a = torch.rand(3, device='cuda')
b = torch.rand(3, device='cuda')
# invoke jitted function like a regular python function
result = jitted_fn(a, b, alpha=3.14)

code_string 也允許定義多個函式,最後一個函式將被視為入口函式。

示例

code_string = "template <typename T> T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }"
code_string += "template <typename T> T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }"
jitted_fn = create_jit_fn(code_string, val=0.0)
a = torch.rand(3, device='cuda')
b = torch.rand(3, device='cuda')
# invoke jitted function like a regular python function
result = jitted_fn(a, b)  # using default val=0.0

Jiterator 可以與 Python 註冊機制結合使用,以覆蓋運算元的 CUDA kernel。以下示例使用 relu 覆蓋 gelu 的 CUDA kernel。

示例

code_string = "template <typename T> T my_gelu(T a) { return a > 0 ? a : 0; }"
my_gelu = create_jit_fn(code_string)
my_lib = torch.library.Library("aten", "IMPL")
my_lib.impl('aten::gelu', my_gelu, "CUDA")
# torch.nn.GELU and torch.nn.function.gelu are now overridden
a = torch.rand(3, device='cuda')
torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a))

警告

此 API 處於 Beta 階段,未來版本中可能會有所更改。

警告

此 API 最多支援 8 個輸入和 1 個輸出

警告

所有輸入 Tensor 都必須位於 CUDA 裝置上

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源