快捷方式

torch.library

torch.library 是一組用於擴充套件 PyTorch 核心運算元庫的 API。它包含用於測試自定義運算元、建立新的自定義運算元以及擴充套件使用 PyTorch C++ 運算元註冊 API(例如 aten 運算元)定義的運算元的實用工具。

有關有效使用這些 API 的詳細指南,請參閱PyTorch 自定義運算元登入頁獲取有關如何有效使用這些 API 的更多詳情。

測試自定義運算元

使用torch.library.opcheck()來測試自定義運算元是否存在 Python torch.library 和/或 C++ TORCH_LIBRARY API 的不正確用法。此外,如果你的運算元支援訓練,請使用torch.autograd.gradcheck()來測試梯度在數學上是否正確。

torch.library.opcheck(op, args, kwargs=None, *, test_utils=('test_schema', 'test_autograd_registration', 'test_faketensor', 'test_aot_dispatch_dynamic'), raise_exception=True, atol=None, rtol=None)[source][source]

給定一個運算元和一些示例引數,測試該運算元是否已正確註冊。

也就是說,當你使用 torch.library/TORCH_LIBRARY API 建立自定義運算元時,你會指定關於該自定義運算元的元資料(例如可變性資訊),並且這些 API 要求你傳遞的函式滿足某些屬性(例如在 fake/meta/abstract kernel 中沒有資料指標訪問)opcheck測試這些元資料和屬性。

具體來說,我們測試以下方面:

  • test_schema: 測試 schema 是否與運算元的實現匹配。例如:如果 schema 指定一個 Tensor 會被改變(mutate),那麼我們會檢查實現是否確實改變了該 Tensor。如果 schema 指定返回一個新的 Tensor,那麼我們會檢查實現是否返回一個新的 Tensor(而不是現有 Tensor 或現有 Tensor 的檢視)。

  • test_autograd_registration: 如果運算元支援訓練(autograd):我們檢查其 autograd 公式是否透過 torch.library.register_autograd 或手動註冊到一個或多個 DispatchKey::Autograd 鍵。任何其他基於 DispatchKey 的註冊可能導致未定義行為。

  • test_faketensor: 測試運算元是否有 FakeTensor kernel(以及它是否正確)。FakeTensor kernel 對於運算元與 PyTorch 編譯 API (torch.compile/export/FX) 協同工作是必要條件(但不是充分條件)。我們檢查運算元是否註冊了 FakeTensor kernel(有時也稱為 meta kernel),以及它是否正確。此測試會比較在真實 tensor 上執行運算元的結果與在 FakeTensor 上執行運算元的結果,並檢查它們是否具有相同的 Tensor 元資料(大小/跨步/資料型別/裝置等)。

  • test_aot_dispatch_dynamic: 測試運算元在使用 PyTorch 編譯 API (torch.compile/export/FX) 時是否表現正確。它檢查在 eager-mode PyTorch 和 torch.compile 下的輸出(如果適用,還有梯度)是否相同。此測試是test_faketensor的超集,並且是一個端到端測試;它測試的其他方面包括運算元是否支援 functionalization 以及(如果存在)反向傳播是否也支援 FakeTensor 和 functionalization。

為了獲得最佳結果,請使用一組代表性輸入多次呼叫opcheck。如果你的運算元支援 autograd,請使用opcheck,並使用帶有requires_grad = True的輸入;如果你的運算元支援多種裝置(例如 CPU 和 CUDA),請使用opcheck並使用所有支援裝置上的輸入。

引數
  • op (Union[OpOverload, OpOverloadPacket, CustomOpDef]) – 運算元。必須是使用torch.library.custom_op()裝飾的函式,或者是 torch.ops.* 中找到的 OpOverload/OpOverloadPacket(例如 torch.ops.aten.sin, torch.ops.mylib.foo)

  • args (tuple[Any, ...]) – 傳遞給運算元的位置引數 (args)

  • kwargs (Optional[dict[str, Any]]) – 傳遞給運算元的關鍵字引數 (kwargs)

  • test_utils (Union[str, Sequence[str]]) – 應該執行的測試。預設:全部。示例:(“test_schema”, “test_faketensor”)

  • raise_exception (bool) – 是否在第一個錯誤時引發異常。如果為 False,將返回一個字典,其中包含每個測試是否透過的資訊。

  • rtol (Optional[float]) – 浮點比較的相對容差。如果指定了atol,則也必須指定。如果省略,則根據dtype選擇預設值(參見torch.testing.assert_close()中的表格)。

  • atol (Optional[float]) – 浮點比較的絕對容差。如果指定了rtol,則也必須指定。如果省略,則根據dtype選擇預設值(參見torch.testing.assert_close()中的表格)。

返回型別

dict[str, str]

警告

opcheck 和torch.autograd.gradcheck()測試不同的內容;opcheck 測試你對 torch.library API 的使用是否正確,而torch.autograd.gradcheck()測試你的 autograd 公式在數學上是否正確。兩者都應該用於測試支援梯度計算的自定義運算元。

示例

>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, y: float) -> Tensor:
>>>     x_np = x.numpy(force=True)
>>>     z_np = x_np * y
>>>     return torch.from_numpy(z_np).to(x.device)
>>>
>>> @numpy_mul.register_fake
>>> def _(x, y):
>>>     return torch.empty_like(x)
>>>
>>> def setup_context(ctx, inputs, output):
>>>     y, = inputs
>>>     ctx.y = y
>>>
>>> def backward(ctx, grad):
>>>     return grad * ctx.y, None
>>>
>>> numpy_mul.register_autograd(backward, setup_context=setup_context)
>>>
>>> sample_inputs = [
>>>     (torch.randn(3), 3.14),
>>>     (torch.randn(2, 3, device='cuda'), 2.718),
>>>     (torch.randn(1, 10, requires_grad=True), 1.234),
>>>     (torch.randn(64, 64, device='cuda', requires_grad=True), 90.18),
>>> ]
>>>
>>> for args in sample_inputs:
>>>     torch.library.opcheck(numpy_mul, args)

在 Python 中建立新的自定義運算元

使用torch.library.custom_op()來建立新的自定義運算元。

torch.library.custom_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None)[source]

將函式封裝成自定義運算元。

建立自定義運算元的原因包括: - 封裝第三方庫或自定義 kernel,使其與 PyTorch 子系統(如 Autograd)協同工作。 - 阻止 torch.compile/export/FX 追蹤深入函式內部。

此 API 用作函式的裝飾器(請參閱示例)。提供的函式必須有型別提示;這些對於與 PyTorch 的各種子系統進行介面互動是必需的。

引數
  • name (str) – 自定義運算元的名稱,格式為 “{namespace}::{name}”,例如 “mylib::my_linear”。此名稱在 PyTorch 子系統(例如 torch.export, FX graphs)中用作運算元的穩定識別符號。為避免名稱衝突,請使用你的專案名稱作為 namespace;例如,pytorch/fbgemm 中的所有自定義運算元都使用 “fbgemm” 作為 namespace。

  • mutates_args (Iterable[str] or "unknown") – 函式修改(mutate)的引數(args)的名稱。這必須準確,否則行為是未定義的。如果為 “unknown”,則悲觀地假定運算元的所有輸入都被修改。

  • device_types (None | str | Sequence[str]) – 函式有效的裝置型別。如果未提供裝置型別,則該函式將用作所有裝置型別的預設實現。示例:“cpu”, “cuda”。當為不接受 Tensor 的運算元註冊裝置特定實現時,我們要求該運算元有一個 “device: torch.device argument”。

  • schema (None | str) – 運算元的 schema 字串。如果為 None(推薦),我們將從其型別註解推斷運算元的 schema。我們建議讓系統推斷 schema,除非你有特殊原因不這樣做。自己編寫 schema 容易出錯。示例:“(Tensor x, int y) -> (Tensor, Tensor)”。

返回型別

Union[Callable[[Callable[[…], object]], CustomOpDef], CustomOpDef]

注意

我們建議不要傳入schema引數,而是讓我們從型別註解中推斷它。自己編寫 schema 容易出錯。如果系統對型別註解的解釋不是你想要的,你可能希望提供自己的 schema。有關如何編寫 schema 字串的更多資訊,請參閱此處

示例:
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>> import numpy as np
>>>
>>> @custom_op("mylib::numpy_sin", mutates_args=())
>>> def numpy_sin(x: Tensor) -> Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> x = torch.randn(3)
>>> y = numpy_sin(x)
>>> assert torch.allclose(y, x.sin())
>>>
>>> # Example of a custom op that only works for one device type.
>>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu")
>>> def numpy_sin_cpu(x: Tensor) -> Tensor:
>>>     x_np = x.numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np)
>>>
>>> x = torch.randn(3)
>>> y = numpy_sin_cpu(x)
>>> assert torch.allclose(y, x.sin())
>>>
>>> # Example of a custom op that mutates an input
>>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu")
>>> def numpy_sin_inplace(x: Tensor) -> None:
>>>     x_np = x.numpy()
>>>     np.sin(x_np, out=x_np)
>>>
>>> x = torch.randn(3)
>>> expected = x.sin()
>>> numpy_sin_inplace(x)
>>> assert torch.allclose(x, expected)
>>>
>>> # Example of a factory function
>>> @torch.library.custom_op("mylib::bar", mutates_args={}, device_types="cpu")
>>> def bar(device: torch.device) -> Tensor:
>>>     return torch.ones(3)
>>>
>>> bar("cpu")
torch.library.triton_op(name, fn=None, /, *, mutates_args, schema=None)[source]

建立一個由 1 個或多個 triton kernel 支援實現的自定義運算元。

這是使用 triton kernel 與 PyTorch 協同工作的一種更結構化的方式。傾向於直接使用 triton kernel 而不帶torch.library自定義運算元包裝器(例如torch.library.custom_op()torch.library.triton_op())因為這樣更簡單;只有當你想要建立一個行為類似於 PyTorch 內建運算元的運算元時,才使用torch.library.custom_op()/torch.library.triton_op()。例如,你可以使用torch.library包裝器 API 來定義 triton kernel 在接收 tensor subclass 或在 TorchDispatchMode 下的行為。

請注意,當實現由 1 個或多個 triton kernel 組成時,使用torch.library.triton_op()代替torch.library.custom_op()torch.library.custom_op()將自定義運算元視為不透明(torch.compile()torch.export.export()永遠不會追蹤進入它們),但triton_op使實現對這些子系統可見,從而允許它們最佳化 triton kernel。

請注意,fn必須只包含對 PyTorch 可識別的運算元和 triton kernel 的呼叫。在fn內部呼叫的任何 triton kernel 必須包裝在對torch.library.wrap_triton()的呼叫中。

引數
  • name (str) – 自定義運算元的名稱,格式為 “{namespace}::{name}”,例如 “mylib::my_linear”。此名稱在 PyTorch 子系統(例如 torch.export, FX graphs)中用作運算元的穩定識別符號。為避免名稱衝突,請使用你的專案名稱作為 namespace;例如,pytorch/fbgemm 中的所有自定義運算元都使用 “fbgemm” 作為 namespace。

  • mutates_args (Iterable[str] or "unknown") – 函式修改(mutate)的引數(args)的名稱。這必須準確,否則行為是未定義的。如果為 “unknown”,則悲觀地假定運算元的所有輸入都被修改。

  • schema (None | str) – 運算元的 schema 字串。如果為 None(推薦),我們將從其型別註解推斷運算元的 schema。我們建議讓系統推斷 schema,除非你有特殊原因不這樣做。自己編寫 schema 容易出錯。示例:“(Tensor x, int y) -> (Tensor, Tensor)”。

返回型別

可呼叫物件

示例

>>> import torch
>>> from torch.library import triton_op, wrap_triton
>>>
>>> import triton
>>> from triton import language as tl
>>>
>>> @triton.jit
>>> def add_kernel(
>>>     in_ptr0,
>>>     in_ptr1,
>>>     out_ptr,
>>>     n_elements,
>>>     BLOCK_SIZE: "tl.constexpr",
>>> ):
>>>     pid = tl.program_id(axis=0)
>>>     block_start = pid * BLOCK_SIZE
>>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
>>>     mask = offsets < n_elements
>>>     x = tl.load(in_ptr0 + offsets, mask=mask)
>>>     y = tl.load(in_ptr1 + offsets, mask=mask)
>>>     output = x + y
>>>     tl.store(out_ptr + offsets, output, mask=mask)
>>>
>>> @triton_op("mylib::add", mutates_args={})
>>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
>>>     output = torch.empty_like(x)
>>>     n_elements = output.numel()
>>>
>>>     def grid(meta):
>>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
>>>
>>>     # NB: we need to wrap the triton kernel in a call to wrap_triton
>>>     wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16)
>>>     return output
>>>
>>> @torch.compile
>>> def f(x, y):
>>>     return add(x, y)
>>>
>>> x = torch.randn(3, device="cuda")
>>> y = torch.randn(3, device="cuda")
>>>
>>> z = f(x, y)
>>> assert torch.allclose(z, x + y)
torch.library.wrap_triton(triton_kernel, /)[source]

允許透過 make_fx 或非嚴格的torch.export將 triton kernel 捕獲到圖中。

這些技術執行基於 Dispatcher 的追蹤(透過__torch_dispatch__)並且無法看到對原始 triton kernel 的呼叫。wrap_tritonAPI 將 triton kernel 包裝成一個可呼叫物件,該物件實際上可以被追蹤到圖中。

請將此 API 與torch.library.triton_op()一起使用。

示例

>>> import torch
>>> import triton
>>> from triton import language as tl
>>> from torch.fx.experimental.proxy_tensor import make_fx
>>> from torch.library import wrap_triton
>>>
>>> @triton.jit
>>> def add_kernel(
>>>     in_ptr0,
>>>     in_ptr1,
>>>     out_ptr,
>>>     n_elements,
>>>     BLOCK_SIZE: "tl.constexpr",
>>> ):
>>>     pid = tl.program_id(axis=0)
>>>     block_start = pid * BLOCK_SIZE
>>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
>>>     mask = offsets < n_elements
>>>     x = tl.load(in_ptr0 + offsets, mask=mask)
>>>     y = tl.load(in_ptr1 + offsets, mask=mask)
>>>     output = x + y
>>>     tl.store(out_ptr + offsets, output, mask=mask)
>>>
>>> def add(x, y):
>>>     output = torch.empty_like(x)
>>>     n_elements = output.numel()
>>>
>>>     def grid_fn(meta):
>>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
>>>
>>>     wrap_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
>>>     return output
>>>
>>> x = torch.randn(3, device="cuda")
>>> y = torch.randn(3, device="cuda")
>>> gm = make_fx(add)(x, y)
>>> print(gm.code)
>>> # def forward(self, x_1, y_1):
>>> #     empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False)
>>> #     triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation(
>>> #         kernel_idx = 0, constant_args_idx = 0,
>>> #         grid = [(1, 1, 1)], kwargs = {
>>> #             'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like,
>>> #             'n_elements': 3, 'BLOCK_SIZE': 16
>>> #         })
>>> #     return empty_like
返回型別

Any

擴充套件自定義運算元(透過 Python 或 C++ 建立)

使用 register.* 方法,例如torch.library.register_kernel()torch.library.register_fake(),為任何運算元新增實現(它們可能使用torch.library.custom_op()或透過 PyTorch C++ 運算元註冊 API 建立)。

torch.library.register_kernel(op, device_types, func=None, /, *, lib=None)[source][source]

為此運算元的裝置型別註冊一個實現。

一些有效的 device_types 包括:“cpu”、“cuda”、“xla”、“mps”、“ipu”、“xpu”。此 API 可用作裝飾器。

引數
  • op (str | OpOverload) – 要註冊實現的運算元。

  • device_types (None | str | Sequence[str]) – 要註冊實現的 device_types。如果為 None,我們將註冊到所有裝置型別 – 請僅在你的實現真正與裝置型別無關時使用此選項。

  • func (Callable) – 作為給定裝置型別的實現進行註冊的函式。

  • lib (Optional[Library]) – 如果提供,此註冊的生命週期

示例:
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>> import numpy as np
>>>
>>> # Create a custom op that works on cpu
>>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
>>> def numpy_sin(x: Tensor) -> Tensor:
>>>     x_np = x.numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np)
>>>
>>> # Add implementations for the cuda device
>>> @torch.library.register_kernel("mylib::numpy_sin", "cuda")
>>> def _(x):
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> x_cpu = torch.randn(3)
>>> x_cuda = x_cpu.cuda()
>>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
>>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())
torch.library.register_autocast(op, device_type, cast_inputs, /, *, lib=None)[source][source]

為此自定義運算元註冊 autocast 排程規則。

有效的device_type包括:“cpu” 和 “cuda”。

引數
  • op (str | OpOverload) – 要註冊 autocast 排程規則的運算元。

  • device_type (torch.device) – 要使用的裝置型別。“cuda” 或 “cpu”。該型別與torch.devicetype屬性相同。因此,你可以使用Tensor.device.type獲取 tensor 的裝置型別。

  • cast_inputs (torch.dtype) – 當自定義運算元在 autocast 啟用區域內執行時,將輸入的浮點 Tensor 轉換為目標 dtype(非浮點 Tensor 不受影響),然後在停用 autocast 的情況下執行自定義運算元。

  • lib (Optional[Library]) – 如果提供,此註冊的生命週期

示例:
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>>
>>> # Create a custom op that works on cuda
>>> @torch.library.custom_op("mylib::my_sin", mutates_args=())
>>> def my_sin(x: Tensor) -> Tensor:
>>>     return torch.sin(x)
>>>
>>> # Register autocast dispatch rule for the cuda device
>>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16)
>>>
>>> x = torch.randn(3, dtype=torch.float32, device="cuda")
>>> with torch.autocast("cuda", dtype=torch.float16):
>>>     y = torch.ops.mylib.my_sin(x)
>>> assert y.dtype == torch.float16
torch.library.register_autograd(op, backward, /, *, setup_context=None, lib=None)[source][source]

為此自定義運算元註冊反向傳播公式。

為了使運算元與 autograd 協同工作,你需要註冊一個反向傳播公式:1. 你必須透過提供一個“backward”函式來告訴我們如何在反向傳播過程中計算梯度。2. 如果你需要從正向傳播中獲取任何值來計算梯度,可以使用setup_context來儲存用於反向傳播的值。

backward在反向傳播過程中執行。它接受(ctx, *grads):-grads是一個或多個梯度。梯度的數量與運算元的輸出數量相匹配。ctx物件是torch.autograd.Function使用的相同的 ctx 物件backward_fn的語義與torch.autograd.Function.backward()相同。

setup_context(ctx, inputs, output)在正向傳播過程中執行。請將反向傳播所需的資料量儲存到ctx物件上,可以透過torch.autograd.function.FunctionCtx.save_for_backward()或將它們作為ctx的屬性來儲存。如果你的自定義運算元有僅關鍵字引數,我們預期setup_context的簽名是setup_context(ctx, inputs, keyword_only_inputs, output)

Both setup_context_fn and backward_fn must be traceable. That is, they may not directly access torch.Tensor.data_ptr() and they must not depend on or mutate global state. If you need a non-traceable backward, you can make it a separate custom_op that you call inside backward_fn.

If you need different autograd behavior on different devices, then we recommend creating two different custom operators, one for each device that needs different behavior, and switching between them at runtime.

示例

>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>>
>>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
>>> def numpy_sin(x: Tensor) -> Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> def setup_context(ctx, inputs, output) -> Tensor:
>>>     x, = inputs
>>>     ctx.save_for_backward(x)
>>>
>>> def backward(ctx, grad):
>>>     x, = ctx.saved_tensors
>>>     return grad * x.cos()
>>>
>>> torch.library.register_autograd(
...     "mylib::numpy_sin", backward, setup_context=setup_context
... )
>>>
>>> x = torch.randn(3, requires_grad=True)
>>> y = numpy_sin(x)
>>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
>>> assert torch.allclose(grad_x, x.cos())
>>>
>>> # Example with a keyword-only arg
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, *, val: float) -> Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = x_np * val
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor:
>>>     ctx.val = keyword_only_inputs["val"]
>>>
>>> def backward(ctx, grad):
>>>     return grad * ctx.val
>>>
>>> torch.library.register_autograd(
...     "mylib::numpy_mul", backward, setup_context=setup_context
... )
>>>
>>> x = torch.randn(3, requires_grad=True)
>>> y = numpy_mul(x, val=3.14)
>>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
>>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
torch.library.register_fake(op, func=None, /, *, lib=None, _stacklevel=1)[source][source]

Register a FakeTensor implementation (“fake impl”) for this operator.

Also sometimes known as a “meta kernel”, “abstract impl”.

An “FakeTensor implementation” specifies the behavior of this operator on Tensors that carry no data (“FakeTensor”). Given some input Tensors with certain properties (sizes/strides/storage_offset/device), it specifies what the properties of the output Tensors are.

The FakeTensor implementation has the same signature as the operator. It is run for both FakeTensors and meta tensors. To write a FakeTensor implementation, assume that all Tensor inputs to the operator are regular CPU/CUDA/Meta tensors, but they do not have storage, and you are trying to return regular CPU/CUDA/Meta tensor(s) as output. The FakeTensor implementation must consist of only PyTorch operations (and may not directly access the storage or data of any input or intermediate Tensors).

This API may be used as a decorator (see examples).

For a detailed guide on custom ops, please see https://pytorch.com.tw/tutorials/advanced/custom_ops_landing_page.html

示例

>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>>
>>> # Example 1: an operator without data-dependent output shape
>>> @torch.library.custom_op("mylib::custom_linear", mutates_args=())
>>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
>>>     raise NotImplementedError("Implementation goes here")
>>>
>>> @torch.library.register_fake("mylib::custom_linear")
>>> def _(x, weight, bias):
>>>     assert x.dim() == 2
>>>     assert weight.dim() == 2
>>>     assert bias.dim() == 1
>>>     assert x.shape[1] == weight.shape[1]
>>>     assert weight.shape[0] == bias.shape[0]
>>>     assert x.device == weight.device
>>>
>>>     return (x @ weight.t()) + bias
>>>
>>> with torch._subclasses.fake_tensor.FakeTensorMode():
>>>     x = torch.randn(2, 3)
>>>     w = torch.randn(3, 3)
>>>     b = torch.randn(3)
>>>     y = torch.ops.mylib.custom_linear(x, w, b)
>>>
>>> assert y.shape == (2, 3)
>>>
>>> # Example 2: an operator with data-dependent output shape
>>> @torch.library.custom_op("mylib::custom_nonzero", mutates_args=())
>>> def custom_nonzero(x: Tensor) -> Tensor:
>>>     x_np = x.numpy(force=True)
>>>     res = np.stack(np.nonzero(x_np), axis=1)
>>>     return torch.tensor(res, device=x.device)
>>>
>>> @torch.library.register_fake("mylib::custom_nonzero")
>>> def _(x):
>>> # Number of nonzero-elements is data-dependent.
>>> # Since we cannot peek at the data in an fake impl,
>>> # we use the ctx object to construct a new symint that
>>> # represents the data-dependent size.
>>>     ctx = torch.library.get_ctx()
>>>     nnz = ctx.new_dynamic_size()
>>>     shape = [nnz, x.dim()]
>>>     result = x.new_empty(shape, dtype=torch.int64)
>>>     return result
>>>
>>> from torch.fx.experimental.proxy_tensor import make_fx
>>>
>>> x = torch.tensor([0, 1, 2, 3, 4, 0])
>>> trace = make_fx(torch.ops.mylib.custom_nonzero, tracing_mode="symbolic")(x)
>>> trace.print_readable()
>>>
>>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))
torch.library.register_vmap(op, func=None, /, *, lib=None)[source][source]

Register a vmap implementation to support torch.vmap() for this custom op.

This API may be used as a decorator (see examples).

In order for an operator to work with torch.vmap(), you may need to register a vmap implementation in the following signature

vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs),

where *args and **kwargs are the arguments and kwargs for op. We do not support kwarg-only Tensor args.

It specifies how do we compute the batched version of op given inputs with an additional dimension (specified by in_dims).

For each arg in args, in_dims has a corresponding Optional[int]. It is None if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer specifying what dimension of the Tensor is being vmapped over.

info is a collection of additional metadata that may be helpful: info.batch_size specifies the size of the dimension being vmapped over, while info.randomness is the randomness option that was passed to torch.vmap().

The return of the function func is a tuple of (output, out_dims). Similar to in_dims, out_dims should be of the same structure as output and contain one out_dim per output that specifies if the output has the vmapped dimension and what index it is in.

示例

>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>> from typing import Tuple
>>>
>>> def to_numpy(tensor):
>>>     return tensor.cpu().numpy()
>>>
>>> lib = torch.library.Library("mylib", "FRAGMENT")
>>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
>>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
>>>     x_np = to_numpy(x)
>>>     dx = torch.tensor(3 * x_np ** 2, device=x.device)
>>>     return torch.tensor(x_np ** 3, device=x.device), dx
>>>
>>> def numpy_cube_vmap(info, in_dims, x):
>>>     result = numpy_cube(x)
>>>     return result, (in_dims[0], in_dims[0])
>>>
>>> torch.library.register_vmap(numpy_cube, numpy_cube_vmap)
>>>
>>> x = torch.randn(3)
>>> torch.vmap(numpy_cube)(x)
>>>
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
>>>     return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
>>>
>>> @torch.library.register_vmap("mylib::numpy_mul")
>>> def numpy_mul_vmap(info, in_dims, x, y):
>>>     x_bdim, y_bdim = in_dims
>>>     x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
>>>     y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
>>>     result = x * y
>>>     result = result.movedim(-1, 0)
>>>     return result, 0
>>>
>>>
>>> x = torch.randn(3)
>>> y = torch.randn(3)
>>> torch.vmap(numpy_mul)(x, y)

注意

The vmap function should aim to preserve the semantics of the entire custom operator. That is, grad(vmap(op)) should be replaceable with a grad(map(op)).

If your custom operator has any custom behavior in the backward pass, please keep this in mind.

torch.library.impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1)[source][source]

This API was renamed to torch.library.register_fake() in PyTorch 2.4. Please use that instead.

torch.library.get_ctx()[source][source]

get_ctx() returns the current AbstractImplCtx object.

Calling get_ctx() is only valid inside of an fake impl (see torch.library.register_fake() for more usage details.

返回型別

FakeImplCtx

torch.library.register_torch_dispatch(op, torch_dispatch_class, func=None, /, *, lib=None)[source][source]

Registers a torch_dispatch rule for the given operator and torch_dispatch_class.

This allows for open registration to specify the behavior between the operator and the torch_dispatch_class without needing to modify the torch_dispatch_class or the operator directly.

The torch_dispatch_class is either a Tensor subclass with __torch_dispatch__ or a TorchDispatchMode.

If it is a Tensor subclass, we expect func to have the following signature: (cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any

If it is a TorchDispatchMode, we expect func to have the following signature: (mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any

args and kwargs will have been normalized the same way they are in __torch_dispatch__ (see __torch_dispatch__ calling convention).

示例

>>> import torch
>>>
>>> @torch.library.custom_op("mylib::foo", mutates_args={})
>>> def foo(x: torch.Tensor) -> torch.Tensor:
>>>     return x.clone()
>>>
>>> class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
>>>     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
>>>         return func(*args, **kwargs)
>>>
>>> @torch.library.register_torch_dispatch("mylib::foo", MyMode)
>>> def _(mode, func, types, args, kwargs):
>>>     x, = args
>>>     return x + 1
>>>
>>> x = torch.randn(3)
>>> y = foo(x)
>>> assert torch.allclose(y, x)
>>>
>>> with MyMode():
>>>     y = foo(x)
>>> assert torch.allclose(y, x + 1)
torch.library.infer_schema(prototype_function, /, *, mutates_args, op_name=None)[source]

Parses the schema of a given function with type hints. The schema is inferred from the function’s type hints, and can be used to define a new operator.

We make the following assumptions

  • None of the outputs alias any of the inputs or each other.

  • String type annotations “device, dtype, Tensor, types” without library specification are
    assumed to be torch.*. Similarly, string type annotations “Optional, List, Sequence, Union”
    without library specification are assumed to be typing.*.
  • Only the args listed in mutates_args are being mutated. If mutates_args is “unknown”,
    it assumes that all inputs to the operator are being mutates.

Callers (e.g. the custom ops API) are responsible for checking these assumptions.

引數
  • prototype_function (Callable) – The function from which to infer a schema for from its type annotations.

  • op_name (Optional[str]) – The name of the operator in the schema. If name is None, then the name is not included in the inferred schema. Note that the input schema to torch.library.Library.define requires a operator name.

  • mutates_args ("unknown" | Iterable[str]) – The arguments that are mutated in the function.

Returns

The inferred schema.

返回型別

str

示例

>>> def foo_impl(x: torch.Tensor) -> torch.Tensor:
>>>     return x.sin()
>>>
>>> infer_schema(foo_impl, op_name="foo", mutates_args={})
foo(Tensor x) -> Tensor
>>>
>>> infer_schema(foo_impl, mutates_args={})
(Tensor x) -> Tensor
class torch._library.custom_ops.CustomOpDef(namespace, name, schema, fn)[source][source]

CustomOpDef is a wrapper around a function that turns it into a custom op.

It has various methods for registering additional behavior for this custom op.

You should not instantiate CustomOpDef directly; instead, use the torch.library.custom_op() API.

set_kernel_enabled(device_type, enabled=True)[source][source]

Disable or re-enable an already registered kernel for this custom operator.

If the kernel is already disabled/enabled, this is a no-op.

注意

If a kernel is first disabled and then registered, it is disabled until enabled again.

引數
  • device_type (str) – The device type to disable/enable the kernel for.

  • enabled (bool) – Whether to disable or enable the kernel.

示例

>>> inp = torch.randn(1)
>>>
>>> # define custom op `f`.
>>> @custom_op("mylib::f", mutates_args=())
>>> def f(x: Tensor) -> Tensor:
>>>     return torch.zeros(1)
>>>
>>> print(f(inp))  # tensor([0.]), default kernel
>>>
>>> @f.register_kernel("cpu")
>>> def _(x):
>>>     return torch.ones(1)
>>>
>>> print(f(inp))  # tensor([1.]), CPU kernel
>>>
>>> # temporarily disable the CPU kernel
>>> with f.set_kernel_enabled("cpu", enabled = False):
>>>     print(f(inp))  # tensor([0.]) with CPU kernel disabled

Low-level APIs

The following APIs are direct bindings to PyTorch’s C++ low-level operator registration APIs.

警告

The low-level operator registration APIs and the PyTorch Dispatcher are a complicated PyTorch concept. We recommend you use the higher level APIs above (that do not require a torch.library.Library object) when possible. This blog post <http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/>`_ is a good starting point to learn about the PyTorch Dispatcher.

A tutorial that walks you through some examples on how to use this API is available on Google Colab.

class torch.library.Library(ns, kind, dispatch_key='')[source][source]

A class to create libraries that can be used to register new operators or override operators in existing libraries from Python. A user can optionally pass in a dispatch keyname if they only want to register kernels corresponding to only one specific dispatch key.

To create a library to override operators in an existing library (with name ns), set the kind to “IMPL”. To create a new library (with name ns) to register new operators, set the kind to “DEF”. To create a fragment of a possibly existing library to register operators (and bypass the limitation that there is only one library for a given namespace), set the kind to “FRAGMENT”.

引數
  • ns – library name

  • kind – “DEF”, “IMPL” (default: “IMPL”), “FRAGMENT”

  • dispatch_key – PyTorch dispatch key (default: “”)

define(schema, alias_analysis='', *, tags=())[source][source]

Defines a new operator and its semantics in the ns namespace.

引數
  • schema – function schema to define a new operator.

  • alias_analysis (optional) – Indicates if the aliasing properties of the operator arguments can be inferred from the schema (default behavior) or not (“CONSERVATIVE”).

  • tags (Tag | Sequence[Tag]) – one or more torch.Tag to apply to this operator. Tagging an operator changes the operator’s behavior under various PyTorch subsystems; please read the docs for the torch.Tag carefully before applying it.

Returns

name of the operator as inferred from the schema.

Example:
>>> my_lib = Library("mylib", "DEF")
>>> my_lib.define("sum(Tensor self) -> Tensor")
fallback(fn, dispatch_key='', *, with_keyset=False)[source][source]

Registers the function implementation as the fallback for the given key.

This function only works for a library with global namespace (“_”).

引數
  • fn – function used as fallback for the given dispatch key or fallthrough_kernel() to register a fallthrough.

  • dispatch_key – dispatch key that the input function should be registered for. By default, it uses the dispatch key that the library was created with.

  • with_keyset – flag controlling if the current dispatcher call keyset should be passed as the first argument to fn when calling. This should be used to create the appropriate keyset for redispatch calls.

Example:
>>> my_lib = Library("_", "IMPL")
>>> def fallback_kernel(op, *args, **kwargs):
>>>     # Handle all autocast ops generically
>>>     # ...
>>> my_lib.fallback(fallback_kernel, "Autocast")
impl(op_name, fn, dispatch_key='', *, with_keyset=False)[source][source]

Registers the function implementation for an operator defined in the library.

引數
  • op_name – operator name (along with the overload) or OpOverload object.

  • fn – function that’s the operator implementation for the input dispatch key or fallthrough_kernel() to register a fallthrough.

  • dispatch_key – dispatch key that the input function should be registered for. By default, it uses the dispatch key that the library was created with.

  • with_keyset – flag controlling if the current dispatcher call keyset should be passed as the first argument to fn when calling. This should be used to create the appropriate keyset for redispatch calls.

Example:
>>> my_lib = Library("aten", "IMPL")
>>> def div_cpu(self, other):
>>>     return self * (1 / other)
>>> my_lib.impl("div.Tensor", div_cpu, "CPU")
torch.library.fallthrough_kernel()[source][source]

A dummy function to pass to Library.impl in order to register a fallthrough.

torch.library.define(qualname, schema, *, lib=None, tags=())[source][source]
torch.library.define(lib, schema, alias_analysis='')

Defines a new operator.

In PyTorch, defining an op (short for “operator”) is a two step-process: - we need to define the op (by providing an operator name and schema) - we need to implement behavior for how the operator interacts with various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.

This entrypoint defines the custom operator (the first step) you must then perform the second step by calling various impl_* APIs, like torch.library.impl() or torch.library.register_fake().

引數
  • qualname (str) – The qualified name for the operator. Should be a string that looks like “namespace::name”, e.g. “aten::sin”. Operators in PyTorch need a namespace to avoid name collisions; a given operator may only be created once. If you are writing a Python library, we recommend the namespace to be the name of your top-level module.

  • schema (str) – The schema of the operator. E.g. “(Tensor x) -> Tensor” for an op that accepts one Tensor and returns one Tensor. It does not contain the operator name (that is passed in qualname).

  • lib (Optional[Library]) – If provided, the lifetime of this operator will be tied to the lifetime of the Library object.

  • tags (Tag | Sequence[Tag]) – one or more torch.Tag to apply to this operator. Tagging an operator changes the operator’s behavior under various PyTorch subsystems; please read the docs for the torch.Tag carefully before applying it.

Example:
>>> import torch
>>> import numpy as np
>>>
>>> # Define the operator
>>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor")
>>>
>>> # Add implementations for the operator
>>> @torch.library.impl("mylib::sin", "cpu")
>>> def f(x):
>>>     return torch.from_numpy(np.sin(x.numpy()))
>>>
>>> # Call the new operator from torch.ops.
>>> x = torch.randn(3)
>>> y = torch.ops.mylib.sin(x)
>>> assert torch.allclose(y, x.sin())
torch.library.impl(lib, name, dispatch_key='')[source][source]
torch.library.impl(qualname: str, types: Union[str, Sequence[str]], func: Literal[None] = None, *, lib: Optional[Library] = None) Callable[[Callable[..., object]], None]
torch.library.impl(qualname: str, types: Union[str, Sequence[str]], func: Callable[..., object], *, lib: Optional[Library] = None) None
torch.library.impl(lib: Library, name: str, dispatch_key: str = '') Callable[[Callable[_P, _T]], Callable[_P, _T]]

為此運算元的裝置型別註冊一個實現。

您可以為 types 傳遞“default”來將此實現註冊為適用於所有裝置型別的預設實現。僅當實現真正支援所有裝置型別時才使用此選項;例如,如果它是內建 PyTorch 運算元的組合,則情況屬實。

這個 API 可以用作裝飾器。你可以使用巢狀的裝飾器,前提是它們返回一個函式並放置在此 API 內部(參見示例 2)。

一些有效的型別包括:“cpu”、“cuda”、“xla”、“mps”、“ipu”、“xpu”。

引數
  • qualname (str) – 應為一個形似“namespace::operator_name”的字串。

  • types (str | Sequence[str]) – 要註冊實現的裝置型別。

  • lib (Optional[Library]) – 如果提供,此註冊的生命週期將與 Library 物件的生命週期繫結。

示例

>>> import torch
>>> import numpy as np
>>> # Example 1: Register function.
>>> # Define the operator
>>> torch.library.define("mylib::mysin", "(Tensor x) -> Tensor")
>>>
>>> # Add implementations for the cpu device
>>> @torch.library.impl("mylib::mysin", "cpu")
>>> def f(x):
>>>     return torch.from_numpy(np.sin(x.numpy()))
>>>
>>> x = torch.randn(3)
>>> y = torch.ops.mylib.mysin(x)
>>> assert torch.allclose(y, x.sin())
>>>
>>> # Example 2: Register function with decorator.
>>> def custom_decorator(func):
>>>     def wrapper(*args, **kwargs):
>>>         return func(*args, **kwargs) + 1
>>>     return wrapper
>>>
>>> # Define the operator
>>> torch.library.define("mylib::sin_plus_one", "(Tensor x) -> Tensor")
>>>
>>> # Add implementations for the operator
>>> @torch.library.impl("mylib::sin_plus_one", "cpu")
>>> @custom_decorator
>>> def f(x):
>>>     return torch.from_numpy(np.sin(x.numpy()))
>>>
>>> # Call the new operator from torch.ops.
>>> x = torch.randn(3)
>>>
>>> y1 = torch.ops.mylib.sin_plus_one(x)
>>> y2 = torch.sin(x) + 1
>>> assert torch.allclose(y1, y2)

文件

訪問 PyTorch 開發者文件大全

檢視文件

教程

獲取面向初學者和高階開發人員的深度教程

檢視教程

資源

查詢開發資源並解答您的疑問

檢視資源