注意
轉到 末尾 下載完整的示例程式碼。
使用 TensorDict 預分配記憶體¶
作者:Tom Begley
在本教程中,您將學習如何在 TensorDict 中利用記憶體預分配。
假設我們有一個返回 TensorDict 的函式
import torch
from tensordict.tensordict import TensorDict
def make_tensordict():
return TensorDict({"a": torch.rand(3), "b": torch.rand(3, 4)}, [3])
我們可能希望多次呼叫此函式,並使用結果來填充一個 TensorDict。
TensorDict(
fields={
a: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([10, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([10, 3]),
device=None,
is_shared=False)
由於我們指定了 tensordict 的 batch_size,在迴圈的第一次迭代中,我們使用第一維大小為 N 的空張量填充 tensordict,其餘維度由 make_tensordict 的返回值確定。在上述示例中,我們為鍵 "a" 預分配了一個大小為 torch.Size([10, 3]) 的零陣列,為鍵 "b" 預分配了一個大小為 torch.Size([10, 3, 4]) 的陣列。隨後的迴圈迭代是就地寫入的。因此,如果不是所有值都被填充,它們將獲得預設值零。
讓我們透過逐步執行上述迴圈來演示正在發生的事情。我們首先初始化一個空的 TensorDict。
N = 10
tensordict = TensorDict({}, batch_size=[N, 3])
print(tensordict)
TensorDict(
fields={
},
batch_size=torch.Size([10, 3]),
device=None,
is_shared=False)
第一次迭代後,tensordict 已經預填充了鍵 "a" 和 "b" 的張量。這些張量包含零,除了我們已賦隨機值的第一行。
random_tensordict = make_tensordict()
tensordict[0] = random_tensordict
assert (tensordict[1:] == 0).all()
assert (tensordict[0] == random_tensordict).all()
print(tensordict)
TensorDict(
fields={
a: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([10, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([10, 3]),
device=None,
is_shared=False)
隨後的迭代,我們就地更新預分配的張量。
指令碼總執行時間: (0 分 0.003 秒)