快捷方式

TVTensors 常見問題解答

注意

Colab 上嘗試或 跳轉至末尾 以下載完整的示例程式碼。

TVTensors 是與 torchvision.transforms.v2 一起引入的 Tensor 子類。本示例展示了 TVTensors 是什麼以及它們的行為方式。

警告

目標讀者 除非您正在編寫自己的變換或 TVTensors,否則您可能不需要閱讀本指南。這是一個相當底層的議題,大多數使用者無需擔心:您無需理解 TVTensors 的內部機制即可有效地使用 torchvision.transforms.v2。然而,對於嘗試實現自己的資料集、變換或直接使用 TVTensors 的高階使用者而言,它可能很有用。

import PIL.Image

import torch
from torchvision import tv_tensors

什麼是 TVTensors?

TVTensors 是零複製 Tensor 子類

tensor = torch.rand(3, 256, 256)
image = tv_tensors.Image(tensor)

assert isinstance(image, torch.Tensor)
assert image.data_ptr() == tensor.data_ptr()

在底層,torchvision.transforms.v2 中需要它們,以便根據輸入資料正確地分派到適當的函式。

torchvision.tv_tensors 支援四種類型的 TVTensors:

TVTensor 可以用來做什麼?

TVTensors 的外觀和使用感覺與普通 tensors 完全一樣 - 它們**就是** tensors。普通 torch.Tensor 上支援的一切,例如 .sum() 或任何 torch.* 運算元,也適用於 TVTensors。請參閱 我有一個 TVTensor,但現在變成 Tensor 了。求助! 瞭解一些需要注意的地方。

如何構造 TVTensor?

使用建構函式

每個 TVTensor 類都接受任何可以轉換為 Tensor 的類 tensor 資料。

image = tv_tensors.Image([[[[0, 1], [1, 0]]]])
print(image)
Image([[[[0, 1],
         [1, 0]]]], )

與其他 PyTorch 建立運算元類似,建構函式也接受 dtypedevicerequires_grad 引數。

float_image = tv_tensors.Image([[[0, 1], [1, 0]]], dtype=torch.float32, requires_grad=True)
print(float_image)
Image([[[0., 1.],
        [1., 0.]]], grad_fn=<AliasBackward0>, )

此外,ImageMask 還可以直接接受 PIL.Image.Image

image = tv_tensors.Image(PIL.Image.open("../assets/astronaut.jpg"))
print(image.shape, image.dtype)
torch.Size([3, 512, 512]) torch.uint8

一些 TVTensors 在構造時需要傳遞額外的元資料。例如,BoundingBoxes 需要座標格式以及相應影像的尺寸(canvas_size)以及實際值。這些元資料是正確變換邊界框所必需的。

bboxes = tv_tensors.BoundingBoxes(
    [[17, 16, 344, 495], [0, 10, 0, 10]],
    format=tv_tensors.BoundingBoxFormat.XYXY,
    canvas_size=image.shape[-2:]
)
print(bboxes)
BoundingBoxes([[ 17,  16, 344, 495],
               [  0,  10,   0,  10]], format=BoundingBoxFormat.XYXY, canvas_size=torch.Size([512, 512]))

使用 tv_tensors.wrap()

您還可以使用 wrap() 函式將 tensor 物件包裝成 TVTensor。當您已經擁有所需型別的物件時,這會很有用,這通常發生在編寫變換時:您只需像處理輸入一樣包裝輸出。

new_bboxes = torch.tensor([0, 20, 30, 40])
new_bboxes = tv_tensors.wrap(new_bboxes, like=bboxes)
assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
assert new_bboxes.canvas_size == bboxes.canvas_size

new_bboxes 的元資料與 bboxes 相同,但您可以將其作為引數傳遞來覆蓋它。

我有一個 TVTensor,但現在變成 Tensor 了。求助!

預設情況下,對 TVTensor 物件的操作將返回一個純 Tensor。

assert isinstance(bboxes, tv_tensors.BoundingBoxes)

# Shift bboxes by 3 pixels in both H and W
new_bboxes = bboxes + 3

assert isinstance(new_bboxes, torch.Tensor)
assert not isinstance(new_bboxes, tv_tensors.BoundingBoxes)

注意

此行為僅影響原生的 torch 操作。如果您使用內建的 torchvision 變換或函式,您將始終獲得與輸入(純 TensorTVTensor)相同的輸出型別。

但我想要回 TVTensor!

您可以透過呼叫 TVTensor 建構函式,或者使用 wrap() 函式將純 tensor 重新包裝成 TVTensor(詳見上文的 如何構造 TVTensor?)。

new_bboxes = bboxes + 3
new_bboxes = tv_tensors.wrap(new_bboxes, like=bboxes)
assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)

或者,您可以使用 set_return_type() 作為整個程式的全域性配置設定,或者作為上下文管理器(閱讀其文件以瞭解更多注意事項)

with tv_tensors.set_return_type("TVTensor"):
    new_bboxes = bboxes + 3
assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)

為什麼會這樣?

出於效能原因。 TVTensor 類是 Tensor 子類,因此任何涉及 TVTensor 物件的操作都將透過 __torch_function__ 協議。這會帶來少量開銷,我們希望在可能的情況下避免。對於內建的 torchvision 變換而言,這並不重要,因為我們可以在那裡避免開銷,但在模型的 forward 中可能會出現問題。

無論如何,替代方案也好不到哪裡去。 對於每個保留 TVTensor 型別有意義的操作,也有同樣多更適合返回純 Tensor 的操作:例如,img.sum() 仍然是一個 Image 嗎?如果我們一直保留 TVTensor 型別,即使是模型的 logits 或損失函式的輸出也會變成 Image 型別,這顯然不是期望的結果。

注意

我們正在積極徵求對此行為的反饋意見。如果您對此感到意外,或者對如何更好地支援您的用例有任何建議,請透過此 issue 與我們聯絡:https://github.com/pytorch/vision/issues/7319

例外情況

此“解包”規則有幾個例外:clone()to()torch.Tensor.detach()requires_grad_() 保留 TVTensor 型別。

對 TVTensors 進行的原地操作,例如 obj.add_(),將保留 obj 的型別。然而,原地操作的**返回值**將是純 tensor。

image = tv_tensors.Image([[[0, 1], [1, 0]]])

new_image = image.add_(1).mul_(2)

# image got transformed in-place and is still a TVTensor Image, but new_image
# is a Tensor. They share the same underlying data and they're equal, just
# different classes.
assert isinstance(image, tv_tensors.Image)
print(image)

assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, tv_tensors.Image)
assert (new_image == image).all()
assert new_image.data_ptr() == image.data_ptr()
Image([[[2, 4],
        [4, 2]]], )

指令碼總執行時間: (0 minutes 0.008 seconds)

畫廊由 Sphinx-Gallery 生成

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源