• 文件 >
  • 使用 TensorDict 預分配記憶體
快捷方式

使用 TensorDict 預分配記憶體

作者Tom Begley

在本教程中,您將學習如何在 TensorDict 中利用記憶體預分配。

假設我們有一個返回 TensorDict 的函式

import torch
from tensordict.tensordict import TensorDict


def make_tensordict():
    return TensorDict({"a": torch.rand(3), "b": torch.rand(3, 4)}, [3])

我們可能希望多次呼叫此函式,並使用結果來填充一個 TensorDict

N = 10
tensordict = TensorDict({}, batch_size=[N, 3])

for i in range(N):
    tensordict[i] = make_tensordict()

print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([10, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10, 3]),
    device=None,
    is_shared=False)

由於我們指定了 tensordict 的 batch_size,在迴圈的第一次迭代中,我們使用第一維大小為 N 的空張量填充 tensordict,其餘維度由 make_tensordict 的返回值確定。在上述示例中,我們為鍵 "a" 預分配了一個大小為 torch.Size([10, 3]) 的零陣列,為鍵 "b" 預分配了一個大小為 torch.Size([10, 3, 4]) 的陣列。隨後的迴圈迭代是就地寫入的。因此,如果不是所有值都被填充,它們將獲得預設值零。

讓我們透過逐步執行上述迴圈來演示正在發生的事情。我們首先初始化一個空的 TensorDict

N = 10
tensordict = TensorDict({}, batch_size=[N, 3])
print(tensordict)
TensorDict(
    fields={
    },
    batch_size=torch.Size([10, 3]),
    device=None,
    is_shared=False)

第一次迭代後,tensordict 已經預填充了鍵 "a""b" 的張量。這些張量包含零,除了我們已賦隨機值的第一行。

random_tensordict = make_tensordict()
tensordict[0] = random_tensordict

assert (tensordict[1:] == 0).all()
assert (tensordict[0] == random_tensordict).all()

print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([10, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10, 3]),
    device=None,
    is_shared=False)

隨後的迭代,我們就地更新預分配的張量。

a = tensordict["a"]
random_tensordict = make_tensordict()
tensordict[1] = random_tensordict

# the same tensor is stored under "a", but the values have been updated
assert tensordict["a"] is a
assert (tensordict[:2] != 0).all()

指令碼總執行時間: (0 分 0.003 秒)

由 Sphinx-Gallery 生成的相簿

文件

訪問 PyTorch 全面的開發者文件

檢視文件

教程

獲取適合初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得解答

檢視資源