注意
點選此處下載完整的示例程式碼
自定義 Python 運算元¶
創建於:2024 年 6 月 18 日 | 最後更新:2025 年 3 月 19 日 | 最後驗證:2024 年 11 月 5 日
如何將用 Python 編寫的自定義運算元與 PyTorch 整合
如何使用
torch.library.opcheck測試自定義運算元
PyTorch 2.4 或更高版本
PyTorch 提供了大量的運算元庫,可用於處理張量(例如 torch.add、torch.sum 等)。但是,您可能希望在 PyTorch 中使用新的定製運算元,例如由第三方庫編寫的運算元。本教程演示瞭如何包裝 Python 函式,使其行為類似於 PyTorch 原生運算元。您可能希望在 PyTorch 中建立自定義運算元的原因包括:
將任意 Python 函式視為
torch.compile的不透明可呼叫物件(即阻止torch.compile跟蹤到該函式內部)。為任意 Python 函式新增訓練支援
使用 torch.library.custom_op() 建立 Python 自定義運算元。使用 C++ TORCH_LIBRARY API 建立 C++ 自定義運算元(這些運算元可在無 Python 環境中工作)。有關更多詳細資訊,請參閱自定義運算元登陸頁。
請注意,如果您的操作可以透過現有 PyTorch 運算元的組合來表達,則通常無需使用自定義運算元 API – 所有功能(例如 torch.compile、訓練支援)都應該能夠正常工作。
示例:將 PIL 的 crop 包裝成自定義運算元¶
假設我們正在使用 PIL 的 crop 操作。
import torch
from torchvision.transforms.functional import to_pil_image, pil_to_tensor
import PIL
import IPython
import matplotlib.pyplot as plt
def crop(pic, box):
img = to_pil_image(pic.cpu())
cropped_img = img.crop(box)
return pil_to_tensor(cropped_img).to(pic.device) / 255.
def display(img):
plt.imshow(img.numpy().transpose((1, 2, 0)))
img = torch.ones(3, 64, 64)
img *= torch.linspace(0, 1, steps=64) * torch.linspace(0, 1, steps=64).unsqueeze(-1)
display(img)

cropped_img = crop(img, (10, 10, 50, 50))
display(cropped_img)

crop 不能被 torch.compile 有效地直接處理:torch.compile 會在其無法處理的函式上產生“圖中斷 (graph break)”,而圖中斷會影響效能。以下程式碼透過引發錯誤來演示這一點(torch.compile 配合 fullgraph=True 會在發生圖中斷時引發錯誤)。
為了將 crop 視為黑盒以配合 torch.compile 使用,我們需要做兩件事:
將函式包裝成 PyTorch 自定義運算元。
為運算元新增一個“
FakeTensor核”(也稱為“元核”)。給定一些FakeTensor輸入(沒有儲存的虛擬張量),此函式應返回您選擇的虛擬張量,並具有正確的張量元資料(形狀/步長/dtype/裝置)。
from typing import Sequence
# Use torch.library.custom_op to define a new custom operator.
# If your operator mutates any input Tensors, their names must be specified
# in the ``mutates_args`` argument.
@torch.library.custom_op("mylib::crop", mutates_args=())
def crop(pic: torch.Tensor, box: Sequence[int]) -> torch.Tensor:
img = to_pil_image(pic.cpu())
cropped_img = img.crop(box)
return (pil_to_tensor(cropped_img) / 255.).to(pic.device, pic.dtype)
# Use register_fake to add a ``FakeTensor`` kernel for the operator
@crop.register_fake
def _(pic, box):
channels = pic.shape[0]
x0, y0, x1, y1 = box
result = pic.new_empty(y1 - y0, x1 - x0, channels).permute(2, 0, 1)
# The result should have the same metadata (shape/strides/``dtype``/device)
# as running the ``crop`` function above.
return result
之後,crop 現在可以在沒有圖中斷的情況下工作

display(cropped_img)

為 crop 新增訓練支援¶
使用 torch.library.register_autograd 為運算元新增訓練支援。首選這種方式,而不是直接使用 torch.autograd.Function;autograd.Function 與 PyTorch 運算元註冊 API 的某些組合在與 torch.compile 組合時可能導致(並且已經導致)隱式的錯誤行為。
如果您不需要訓練支援,則無需使用 torch.library.register_autograd。如果您在使用沒有 autograd 註冊的 custom_op 進行訓練時,我們將引發錯誤訊息。
crop 的梯度公式本質上是 PIL.paste(我們將推導留給讀者作為練習)。首先,讓我們將 paste 包裝成一個自定義運算元
@torch.library.custom_op("mylib::paste", mutates_args=())
def paste(im1: torch.Tensor, im2: torch.Tensor, coord: Sequence[int]) -> torch.Tensor:
assert im1.device == im2.device
assert im1.dtype == im2.dtype
im1_pil = to_pil_image(im1.cpu())
im2_pil = to_pil_image(im2.cpu())
PIL.Image.Image.paste(im1_pil, im2_pil, coord)
return (pil_to_tensor(im1_pil) / 255.).to(im1.device, im1.dtype)
@paste.register_fake
def _(im1, im2, coord):
assert im1.device == im2.device
assert im1.dtype == im2.dtype
return torch.empty_like(im1)
現在,讓我們使用 register_autograd 來指定 crop 的梯度公式
def backward(ctx, grad_output):
grad_input = grad_output.new_zeros(ctx.pic_shape)
grad_input = paste(grad_input, grad_output, ctx.coords)
return grad_input, None
def setup_context(ctx, inputs, output):
pic, box = inputs
ctx.coords = box[:2]
ctx.pic_shape = pic.shape
crop.register_autograd(backward, setup_context=setup_context)
請注意,反向傳播必須是 PyTorch 可理解的運算元的組合,這就是我們將 paste 包裝成自定義運算元而不是直接使用 PIL 的 paste 的原因。

這是正確的梯度,裁剪區域為 1(白色),未使用區域為 0(黑色)。
測試 Python 自定義運算元¶
使用 torch.library.opcheck 測試自定義運算元是否正確註冊。這不會測試梯度在數學上是否正確;請為此編寫單獨的測試(手動測試或 torch.autograd.gradcheck)。
要使用 opcheck,請向其傳遞一組示例輸入進行測試。如果您的運算元支援訓練,則示例應包含需要梯度的張量。如果您的運算元支援多種裝置,則示例應包含來自每種裝置的張量。
examples = [
[torch.randn(3, 64, 64), [0, 0, 10, 10]],
[torch.randn(3, 91, 91, requires_grad=True), [10, 0, 20, 10]],
[torch.randn(3, 60, 60, dtype=torch.double), [3, 4, 32, 20]],
[torch.randn(3, 512, 512, requires_grad=True, dtype=torch.double), [3, 4, 32, 45]],
]
for example in examples:
torch.library.opcheck(crop, example)
可變 Python 自定義運算元¶
您還可以將修改輸入的 Python 函式包裝成自定義運算元。修改輸入的函式很常見,因為許多底層核 (kernel) 就是這樣編寫的;例如,計算 sin 的核可能會接收輸入和輸出張量,並將 input.sin() 寫入輸出張量。
我們將使用 numpy.sin 來演示一個可變 Python 自定義運算元的例子。
import numpy as np
@torch.library.custom_op("mylib::numpy_sin", mutates_args={"output"}, device_types="cpu")
def numpy_sin(input: torch.Tensor, output: torch.Tensor) -> None:
assert input.device == output.device
assert input.device.type == "cpu"
input_np = input.numpy()
output_np = output.numpy()
np.sin(input_np, out=output_np)
由於該運算元不返回任何內容,因此無需註冊 FakeTensor 核(元核)即可使其與 torch.compile 一起工作。
@torch.compile(fullgraph=True)
def f(x):
out = torch.empty(3)
numpy_sin(x, out)
return out
x = torch.randn(3)
y = f(x)
assert torch.allclose(y, x.sin())
這是一次 opcheck 執行,告訴我們確實正確註冊了運算元。例如,如果我們忘記將輸出新增到 mutates_args,opcheck 將會報錯。
example_inputs = [
[torch.randn(3), torch.empty(3)],
[torch.randn(0, 3), torch.empty(0, 3)],
[torch.randn(1, 2, 3, 4, dtype=torch.double), torch.empty(1, 2, 3, 4, dtype=torch.double)],
]
for example in example_inputs:
torch.library.opcheck(numpy_sin, example)
結論¶
在本教程中,我們學習瞭如何使用 torch.library.custom_op 在 Python 中建立一個自定義運算元,該運算元可與 PyTorch 的子系統(例如 torch.compile 和 autograd)一起工作。
本教程提供了對自定義運算元的基本介紹。有關更詳細的資訊,請參閱
指令碼總執行時間: ( 0 分鐘 2.402 秒)