• 文件 >
  • 操縱 TensorDict 的形狀
快捷方式

操縱 TensorDict 的形狀

作者: Tom Begley

在本教程中,你將學習如何操縱 TensorDict 及其內容的形狀。

當我們建立一個 TensorDict 時,會指定一個 batch_size,它必須與 TensorDict 中所有條目的前導維度一致。由於我們保證所有條目共享這些公共維度,因此 TensorDict 能夠提供多種方法來操縱 TensorDict 及其內容的形狀。

import torch
from tensordict.tensordict import TensorDict

索引 TensorDict

由於批處理維度保證存在於所有條目上,我們可以隨意對它們進行索引,並且 TensorDict 的每個條目都將以相同的方式被索引。

a = torch.rand(3, 4)
b = torch.rand(3, 4, 5)
tensordict = TensorDict({"a": a, "b": b}, batch_size=[3, 4])

indexed_tensordict = tensordict[:2, 1]
assert indexed_tensordict["a"].shape == torch.Size([2])
assert indexed_tensordict["b"].shape == torch.Size([2, 5])

重塑 TensorDict

TensorDict.reshape 的工作方式與 torch.Tensor.reshape() 完全一樣。它沿批處理維度應用於 TensorDict 的所有內容——注意下面示例中 b 的形狀。它還會更新 batch_size 屬性。

reshaped_tensordict = tensordict.reshape(-1)
assert reshaped_tensordict.batch_size == torch.Size([12])
assert reshaped_tensordict["a"].shape == torch.Size([12])
assert reshaped_tensordict["b"].shape == torch.Size([12, 5])

分割 TensorDict

TensorDict.splittorch.Tensor.split() 類似。它將 TensorDict 分割成塊。每個塊都是一個與原始 TensorDict 結構相同的 TensorDict,但其條目是原始 TensorDict 中相應條目的檢視(view)。

chunks = tensordict.split([3, 1], dim=1)
assert chunks[0].batch_size == torch.Size([3, 3])
assert chunks[1].batch_size == torch.Size([3, 1])
torch.testing.assert_close(chunks[0]["a"], tensordict["a"][:, :-1])

注意

每當函式或方法接受 dim 引數時,負維度會相對於呼叫該函式或方法的 TensorDictbatch_size 進行解釋。特別是,如果存在具有不同批處理大小的巢狀 TensorDict 值,負維度始終相對於根 TensorDict 的批處理維度進行解釋。

>>> tensordict = TensorDict(
...     {
...         "a": torch.rand(3, 4),
...         "nested": TensorDict({"b": torch.rand(3, 4, 5)}, [3, 4, 5])
...     },
...     [3, 4],
... )
>>> # dim = -2 will be interpreted as the first dimension throughout, as the root
>>> # TensorDict has 2 batch dimensions, even though the nested TensorDict has 3
>>> chunks = tensordict.split([2, 1], dim=-2)
>>> assert chunks[0].batch_size == torch.Size([2, 4])
>>> assert chunks[0]["nested"].batch_size == torch.Size([2, 4, 5])

從這個例子可以看出,TensorDict.split 方法的行為與我們在呼叫前將 dim=-2 替換為 dim=tensordict.batch_dims - 2 時完全相同。

解綁

TensorDict.unbindtorch.Tensor.unbind() 類似,概念上與 TensorDict.split 相似。它移除指定的維度並返回沿該維度的所有切片的 tuple

slices = tensordict.unbind(dim=1)
assert len(slices) == 4
assert all(s.batch_size == torch.Size([3]) for s in slices)
torch.testing.assert_close(slices[0]["a"], tensordict["a"][:, 0])

堆疊和拼接

TensorDict 可以與 torch.cattorch.stack 結合使用。

堆疊 TensorDict

堆疊可以懶惰地或連續地完成。懶惰堆疊(lazy stack)只是一個 tensordict 列表,表示為一個 tensordict 堆疊。它允許使用者攜帶內容形狀、裝置或鍵集不同的 tensordict 集合。另一個優點是堆疊操作可能開銷很大,如果只需要一小部分鍵,懶惰堆疊會比真正的堆疊快得多。它依賴於 LazyStackedTensorDict 類。在這種情況下,值只會在訪問時按需堆疊。

from tensordict import LazyStackedTensorDict

cloned_tensordict = tensordict.clone()
stacked_tensordict = LazyStackedTensorDict.lazy_stack(
    [tensordict, cloned_tensordict], dim=0
)
print(stacked_tensordict)

# Previously, torch.stack was always returning a lazy stack. For consistency with
# the regular PyTorch API, this behaviour will soon be adapted to deliver only
# dense tensordicts. To control which behaviour you are relying on, you can use
# the :func:`~tensordict.utils.set_lazy_legacy` decorator/context manager:

from tensordict.utils import set_lazy_legacy

with set_lazy_legacy(True):  # old behaviour
    lazy_stack = torch.stack([tensordict, cloned_tensordict])
assert isinstance(lazy_stack, LazyStackedTensorDict)

with set_lazy_legacy(False):  # new behaviour
    dense_stack = torch.stack([tensordict, cloned_tensordict])
assert isinstance(dense_stack, TensorDict)
LazyStackedTensorDict(
    fields={
        a: Tensor(shape=torch.Size([2, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
    exclusive_fields={
    },
    batch_size=torch.Size([2, 3, 4]),
    device=None,
    is_shared=False,
    stack_dim=0)

如果我們沿堆疊維度索引 LazyStackedTensorDict,我們會恢復原始的 TensorDict

assert stacked_tensordict[0] is tensordict
assert stacked_tensordict[1] is cloned_tensordict

訪問 LazyStackedTensorDict 中的鍵會導致這些值被堆疊。如果鍵對應於巢狀的 TensorDict,那麼我們將恢復另一個 LazyStackedTensorDict

assert stacked_tensordict["a"].shape == torch.Size([2, 3, 4])

注意

由於值是按需堆疊的,多次訪問一個專案意味著它會被多次堆疊,這效率很低。如果你需要多次訪問堆疊的 TensorDict 中的值,你可能需要考慮將 LazyStackedTensorDict 轉換為連續的 TensorDict,這可以透過 LazyStackedTensorDict.to_tensordictLazyStackedTensorDict.contiguous 方法完成。

>>> assert isinstance(stacked_tensordict.contiguous(), TensorDict)
>>> assert isinstance(stacked_tensordict.contiguous(), TensorDict)

呼叫這些方法中的任一個後,我們將得到一個包含堆疊值的常規 TensorDict,並且在訪問值時不會執行額外的計算。

拼接 TensorDict

拼接不是懶惰進行的,而是對 TensorDict 例項列表呼叫 torch.cat() 會直接返回一個 TensorDict,其條目是列表中元素的拼接條目。

concatenated_tensordict = torch.cat([tensordict, cloned_tensordict], dim=0)
assert isinstance(concatenated_tensordict, TensorDict)
assert concatenated_tensordict.batch_size == torch.Size([6, 4])
assert concatenated_tensordict["b"].shape == torch.Size([6, 4, 5])

擴充套件 TensorDict

我們可以使用 TensorDict.expand 擴充套件 TensorDict 的所有條目。

exp_tensordict = tensordict.expand(2, *tensordict.batch_size)
assert exp_tensordict.batch_size == torch.Size([2, 3, 4])
torch.testing.assert_close(exp_tensordict["a"][0], exp_tensordict["a"][1])

壓縮(Squeezing)和解壓縮(Unsqueezing)TensorDict

我們可以使用 squeeze()unsqueeze() 方法壓縮或解壓縮 TensorDict 的內容。

tensordict = TensorDict({"a": torch.rand(3, 1, 4)}, [3, 1, 4])
squeezed_tensordict = tensordict.squeeze()
assert squeezed_tensordict["a"].shape == torch.Size([3, 4])
print(squeezed_tensordict, end="\n\n")

unsqueezed_tensordict = tensordict.unsqueeze(-1)
assert unsqueezed_tensordict["a"].shape == torch.Size([3, 1, 4, 1])
print(unsqueezed_tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 4]),
    device=None,
    is_shared=False)

TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 1, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 1, 4, 1]),
    device=None,
    is_shared=False)

注意

到目前為止,像 unsqueeze()squeeze()view()permute()transpose() 等操作都會返回這些操作的懶惰版本(即一個容器,其中儲存了原始 tensordict,並且每次訪問鍵時都會應用這些操作)。此行為將來將被棄用,並且已經可以透過 set_lazy_legacy() 函式進行控制。

>>> with set_lazy_legacy(True):
...     lazy_unsqueeze = tensordict.unsqueeze(0)
>>> with set_lazy_legacy(False):
...     dense_unsqueeze = tensordict.unsqueeze(0)

請記住,與往常一樣,這些方法僅應用於批處理維度。條目的任何非批處理維度都不會受到影響。

tensordict = TensorDict({"a": torch.rand(3, 1, 1, 4)}, [3, 1])
squeezed_tensordict = tensordict.squeeze()
# only one of the singleton dimensions is dropped as the other
# is not a batch dimension
assert squeezed_tensordict["a"].shape == torch.Size([3, 1, 4])

檢視(View)TensorDict

TensorDict 也支援 view。這會建立一個 _ViewedTensorDict,它在內容被訪問時懶惰地建立檢視。

tensordict = TensorDict({"a": torch.arange(12)}, [12])
# no views are created at this step
viewed_tensordict = tensordict.view((2, 3, 2))

# the view of "a" is created on-demand when we access it
assert viewed_tensordict["a"].shape == torch.Size([2, 3, 2])

置換(Permuting)批處理維度

TensorDict.permute 方法可用於置換批處理維度,就像 torch.permute() 一樣。非批處理維度保持不變。

此操作是懶惰的,因此只有在嘗試訪問條目時才會置換批處理維度。與往常一樣,如果你可能需要多次訪問某個特定條目,請考慮將其轉換為 TensorDict

tensordict = TensorDict({"a": torch.rand(3, 4), "b": torch.rand(3, 4, 5)}, [3, 4])
# swap the batch dimensions
permuted_tensordict = tensordict.permute([1, 0])

assert permuted_tensordict["a"].shape == torch.Size([4, 3])
assert permuted_tensordict["b"].shape == torch.Size([4, 3, 5])

將 tensordicts 用作裝飾器

對於一系列可逆操作,tensordicts 可以用作裝飾器。這些操作包括用於函式呼叫的 to_module()unlock_()lock_(),或者形狀操作如 view()permute()transpose()squeeze()unsqueeze()。這是一個使用 transpose 函式的快速示例。

tensordict = TensorDict({"a": torch.rand(3, 4), "b": torch.rand(3, 4, 5)}, [3, 4])

with tensordict.transpose(1, 0) as tdt:
    tdt.set("c", torch.ones(4, 3))  # we have permuted the dims

# the ``"c"`` entry is now in the tensordict we used as decorator:
#

assert (tensordict.get("c") == 1).all()

在 TensorDict 中收集值

TensorDict.gather 方法可用於沿批處理維度進行索引並將結果收集到單個維度中,這與 torch.gather() 非常相似。

index = torch.randint(4, (3, 4))
gathered_tensordict = tensordict.gather(dim=1, index=index)
print("index:\n", index, end="\n\n")
print("tensordict['a']:\n", tensordict["a"], end="\n\n")
print("gathered_tensordict['a']:\n", gathered_tensordict["a"], end="\n\n")
index:
 tensor([[0, 2, 2, 1],
        [1, 0, 2, 2],
        [1, 3, 1, 2]])

tensordict['a']:
 tensor([[0.9580, 0.6498, 0.6842, 0.3068],
        [0.7585, 0.6647, 0.1465, 0.9081],
        [0.0090, 0.7241, 0.4385, 0.4466]])

gathered_tensordict['a']:
 tensor([[0.9580, 0.6842, 0.6842, 0.6498],
        [0.6647, 0.7585, 0.1465, 0.1465],
        [0.7241, 0.4466, 0.7241, 0.4385]])

指令碼總執行時間: (0 分 0.008 秒)

相簿由 Sphinx-Gallery 生成

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源