快捷方式

如何編寫自己的 TVTensor 類

注意

Colab 上嘗試 或 轉到末尾 下載完整示例程式碼。

本指南適用於高階使用者和下游庫維護者。我們將解釋如何編寫自己的 TVTensor 類,以及如何使其與內建的 Torchvision v2 變換相容。在繼續之前,請確保您已閱讀 TVTensors FAQ

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

我們將建立一個非常簡單的類,它只繼承自基礎 TVTensor 類。這足以涵蓋實現更復雜用例所需的知識。如果需要建立攜帶元資料的類,可以參考 BoundingBoxes 類的實現

class MyTVTensor(tv_tensors.TVTensor):
    pass


my_dp = MyTVTensor([1, 2, 3])
my_dp
MyTVTensor([1., 2., 3.])

現在我們已經定義了自定義的 TVTensor 類,我們希望它能與內建的 torchvision 變換和函式式 API 相容。為此,我們需要實現一個執行變換核心邏輯的 kernel,然後透過 register_kernel() 將其“掛鉤”到我們想要支援的 functional 上。

我們將在下方說明此過程:為 MyTVTensor 類的“水平翻轉”操作建立 kernel,並將其註冊到函式式 API。

from torchvision.transforms.v2 import functional as F


@F.register_kernel(functional="hflip", tv_tensor_cls=MyTVTensor)
def hflip_my_tv_tensor(my_dp, *args, **kwargs):
    print("Flipping!")
    out = my_dp.flip(-1)
    return tv_tensors.wrap(out, like=my_dp)

要理解為何使用 wrap(),請參閱 我有一個 TVTensor,但現在它變成了 Tensor。救命!。請暫時忽略 *args, **kwargs,我們將在下方 引數轉發以及確保 kernel 的未來相容性 中解釋它。

注意

在我們上面呼叫 register_kernel 時,我們使用字串 functional="hflip" 來指代我們想要掛鉤的 functional。我們也可以直接使用 functional 本身,即 @register_kernel(functional=F.hflip, ...)

現在我們已經註冊了 kernel,可以在 MyTVTensor 例項上呼叫函式式 API。

my_dp = MyTVTensor(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)
Flipping!

我們還可以使用 RandomHorizontalFlip 變換,因為它內部依賴於 hflip()

t = v2.RandomHorizontalFlip(p=1)
_ = t(my_dp)
Flipping!

注意

我們不能為變換類註冊 kernel,只能為 functional 註冊 kernel。不能註冊變換類的原因是,一個變換內部可能依賴多個 functional,所以通常情況下我們無法為給定類註冊單個 kernel。

引數轉發以及確保 kernel 的未來相容性

您掛鉤的函式式 API 是公開的,因此具有向後相容性:我們保證這些 functional 的引數不會在沒有適當棄用週期的情況下被移除或重新命名。然而,我們不保證向前相容性,將來我們可能會新增新引數。

假設在將來的版本中,Torchvision 的 hflip() functional 中添加了一個新的 inplace 引數。如果您已經將自己的 kernel 定義並註冊為

def hflip_my_tv_tensor(my_dp):  # noqa
    print("Flipping!")
    out = my_dp.flip(-1)
    return tv_tensors.wrap(out, like=my_dp)

那麼呼叫 F.hflip(my_dp) 將會失敗,因為 hflip 會嘗試將新的 inplace 引數傳遞給您的 kernel,但您的 kernel 不接受它。

因此,我們建議您始終按照上述示例在 kernel 的簽名中包含 *args, **kwargs。這樣,您的 kernel 就能接受未來可能新增的任何新引數。(技術上講,只新增 **kwargs 應該就足夠了)。

指令碼總執行時間: (0 分鐘 0.004 秒)

由 Sphinx-Gallery 生成的相簿

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源