如何編寫自己的 TVTensor 類¶
本指南適用於高階使用者和下游庫維護者。我們將解釋如何編寫自己的 TVTensor 類,以及如何使其與內建的 Torchvision v2 變換相容。在繼續之前,請確保您已閱讀 TVTensors FAQ。
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2
我們將建立一個非常簡單的類,它只繼承自基礎 TVTensor 類。這足以涵蓋實現更復雜用例所需的知識。如果需要建立攜帶元資料的類,可以參考 BoundingBoxes 類的實現。
class MyTVTensor(tv_tensors.TVTensor):
pass
my_dp = MyTVTensor([1, 2, 3])
my_dp
MyTVTensor([1., 2., 3.])
現在我們已經定義了自定義的 TVTensor 類,我們希望它能與內建的 torchvision 變換和函式式 API 相容。為此,我們需要實現一個執行變換核心邏輯的 kernel,然後透過 register_kernel() 將其“掛鉤”到我們想要支援的 functional 上。
我們將在下方說明此過程:為 MyTVTensor 類的“水平翻轉”操作建立 kernel,並將其註冊到函式式 API。
from torchvision.transforms.v2 import functional as F
@F.register_kernel(functional="hflip", tv_tensor_cls=MyTVTensor)
def hflip_my_tv_tensor(my_dp, *args, **kwargs):
print("Flipping!")
out = my_dp.flip(-1)
return tv_tensors.wrap(out, like=my_dp)
要理解為何使用 wrap(),請參閱 我有一個 TVTensor,但現在它變成了 Tensor。救命!。請暫時忽略 *args, **kwargs,我們將在下方 引數轉發以及確保 kernel 的未來相容性 中解釋它。
注意
在我們上面呼叫 register_kernel 時,我們使用字串 functional="hflip" 來指代我們想要掛鉤的 functional。我們也可以直接使用 functional 本身,即 @register_kernel(functional=F.hflip, ...)。
現在我們已經註冊了 kernel,可以在 MyTVTensor 例項上呼叫函式式 API。
my_dp = MyTVTensor(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)
Flipping!
我們還可以使用 RandomHorizontalFlip 變換,因為它內部依賴於 hflip()。
t = v2.RandomHorizontalFlip(p=1)
_ = t(my_dp)
Flipping!
注意
我們不能為變換類註冊 kernel,只能為 functional 註冊 kernel。不能註冊變換類的原因是,一個變換內部可能依賴多個 functional,所以通常情況下我們無法為給定類註冊單個 kernel。
引數轉發以及確保 kernel 的未來相容性¶
您掛鉤的函式式 API 是公開的,因此具有向後相容性:我們保證這些 functional 的引數不會在沒有適當棄用週期的情況下被移除或重新命名。然而,我們不保證向前相容性,將來我們可能會新增新引數。
假設在將來的版本中,Torchvision 的 hflip() functional 中添加了一個新的 inplace 引數。如果您已經將自己的 kernel 定義並註冊為
def hflip_my_tv_tensor(my_dp): # noqa
print("Flipping!")
out = my_dp.flip(-1)
return tv_tensors.wrap(out, like=my_dp)
那麼呼叫 F.hflip(my_dp) 將會失敗,因為 hflip 會嘗試將新的 inplace 引數傳遞給您的 kernel,但您的 kernel 不接受它。
因此,我們建議您始終按照上述示例在 kernel 的簽名中包含 *args, **kwargs。這樣,您的 kernel 就能接受未來可能新增的任何新引數。(技術上講,只新增 **kwargs 應該就足夠了)。
指令碼總執行時間: (0 分鐘 0.004 秒)