概述¶
TensorDict 使組織資料和編寫可重用、通用的 PyTorch 程式碼變得容易。它最初是為 TorchRL 開發的,後來我們將其分離出來成為一個獨立的庫。
TensorDict 主要是一個字典,但也像一個張量類:它支援多種主要與形狀和儲存相關的張量操作。它被設計成可以高效地從節點到節點或從程序到程序進行序列化或傳輸。最後,它帶有自己的 nn 模組,該模組與 torch.func 相容,旨在簡化模型整合和引數操作。
在本頁中,我們將闡述 TensorDict 的動機,並給出它的一些功能示例。
動機¶
TensorDict 允許您編寫可在不同範例中重用的通用程式碼模組。例如,以下迴圈可用於大多數 SL、SSL、UL 和 RL 任務。
>>> for i, tensordict in enumerate(dataset):
... # the model reads and writes tensordicts
... tensordict = model(tensordict)
... loss = loss_module(tensordict)
... loss.backward()
... optimizer.step()
... optimizer.zero_grad()
憑藉其 nn 模組,該包提供了許多工具,可以輕鬆地在程式碼庫中使用 TensorDict。
在多程序或分散式設定中,TensorDict 允許您將資料無縫地分派給每個工作程序
>>> # creates batches of 10 datapoints
>>> splits = torch.arange(tensordict.shape[0]).split(10)
>>> for worker in range(workers):
... idx = splits[worker]
... pipe[worker].send(tensordict[idx])
TensorDict 提供的一些操作也可以透過 tree_map 完成,但這會增加複雜性
>>> td = TensorDict(
... {"a": torch.randn(3, 11), "b": torch.randn(3, 3)}, batch_size=3
... )
>>> regular_dict = {"a": td["a"], "b": td["b"]}
>>> td0, td1, td2 = td.unbind(0)
>>> # similar structure with pytree
>>> regular_dicts = tree_map(lambda x: x.unbind(0))
>>> regular_dict1, regular_dict2, regular_dict3 = [
... {"a": regular_dicts["a"][i], "b": regular_dicts["b"][i]}
... for i in range(3)]
巢狀的情況更加引人注目
>>> td = TensorDict(
... {"a": {"c": torch.randn(3, 11)}, "b": torch.randn(3, 3)}, batch_size=3
... )
>>> regular_dict = {"a": {"c": td["a", "c"]}, "b": td["b"]}
>>> td0, td1, td2 = td.unbind(0)
>>> # similar structure with pytree
>>> regular_dicts = tree_map(lambda x: x.unbind(0))
>>> regular_dict1, regular_dict2, regular_dict3 = [
... {"a": {"c": regular_dicts["a"]["c"][i]}, "b": regular_dicts["b"][i]}
... for i in range(3)
在樸素地使用 pytree 時,應用 unbind 操作後將輸出字典分解為三個結構相似的字典會迅速變得相當麻煩。使用 tensordict,我們為希望分解或分割巢狀結構的使用者提供了一個簡單的 API,而不是計算一個巢狀的分割 / 分解的巢狀結構。
功能特性¶
一個 TensorDict 是一個類似字典的張量容器。要例項化 TensorDict,您可以指定鍵值對以及批大小(可以透過 TensorDict() 建立一個空的 tensordict)。TensorDict 中任何值的前導維度必須與批大小相容。
>>> import torch
>>> from tensordict import TensorDict
>>> tensordict = TensorDict(
... {"zeros": torch.zeros(2, 3, 4), "ones": torch.ones(2, 3, 4, 5)},
... batch_size=[2, 3],
... )
設定或檢索值的語法與常規字典非常相似。
>>> zeros = tensordict["zeros"]
>>> tensordict["twos"] = 2 * torch.ones(2, 3)
還可以沿著 batch_size 索引 tensordict,這樣只需幾個字元就能獲得資料的一致切片(請注意,使用省略號透過 tree_map 索引第 n 個前導維度需要更多的編碼)。
>>> sub_tensordict = tensordict[..., :2]
還可以使用帶有 inplace=True 的 set 方法或 set_() 方法對內容進行原地更新。前者是後者的容錯版本:如果找不到匹配的鍵,它將寫入一個新鍵。
TensorDict 的內容現在可以集體操作。例如,要將所有內容放到特定裝置上,只需執行
>>> tensordict = tensordict.to("cuda:0")
然後您可以斷言 tensordict 的裝置是 “cuda:0”
>>> assert tensordict.device == torch.device("cuda:0")
要重塑批維度,可以執行
>>> tensordict = tensordict.reshape(6)
該類支援許多其他操作,包括 squeeze()、unsqueeze()、view()、permute()、unbind()、stack()、cat() 等等。
如果某項操作不存在,通常可以使用 apply() 方法來解決問題。
避免形狀操作¶
在某些情況下,可能需要將張量儲存在 TensorDict 中,但在形狀操作期間不強制要求批大小一致性。
這可以透過將張量包裝在 UnbatchedTensor 例項中來實現。
UnbatchedTensor 在 TensorDict 上進行形狀操作時會忽略其形狀,從而可以靈活地儲存和操作具有任意形狀的張量。
>>> from tensordict import UnbatchedTensor
>>> tensordict = TensorDict({"zeros": UnbatchedTensor(torch.zeros(10))}, batch_size=[2, 3])
>>> reshaped_td = tensordict.reshape(6)
>>> reshaped_td["zeros"] is tensordict["zeros"]
True
非張量資料¶
Tensordict 是一個用於處理張量資料的強大庫,但也支援非張量資料。本指南將向您展示如何使用 tensordict 處理非張量資料。
使用非張量資料建立 TensorDict¶
您可以使用 NonTensorData 類建立包含非張量資料的 TensorDict。
>>> from tensordict import TensorDict, NonTensorData
>>> import torch
>>> td = TensorDict(
... a=NonTensorData("a string!"),
... b=torch.zeros(()),
... )
>>> print(td)
TensorDict(
fields={
a: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None),
b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
如您所見,NonTensorData 物件像普通張量一樣儲存在 TensorDict 中。
訪問非張量資料¶
您可以使用鍵或 get 方法訪問非張量資料。常規的 getattr 呼叫將返回 NonTensorData 物件的內容,而 get() 將返回 NonTensorData 物件本身。
>>> print(td["a"]) # prints: a string!
>>> print(td.get("a")) # prints: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None)
批處理非張量資料¶
如果您有一批非張量資料,可以將其儲存在指定批大小的 TensorDict 中。
>>> td = TensorDict(
... a=NonTensorData("a string!"),
... b=torch.zeros(3),
... batch_size=[3]
... )
>>> print(td)
TensorDict(
fields={
a: NonTensorData(data=a string!, batch_size=torch.Size([3]), device=None),
b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
在這種情況下,我們假設 tensordict 的所有元素都具有相同的非張量資料。
>>> print(td[0])
TensorDict(
fields={
a: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None),
b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
要為有形狀的 tensordict 中的每個元素分配不同的非張量資料物件,您可以使用非張量資料堆疊。
堆疊非張量資料¶
如果您有一個非張量資料列表想要儲存在 TensorDict 中,可以使用 NonTensorStack 類。
>>> td = TensorDict(
... a=NonTensorStack("a string!", "another string!", "a third string!"),
... b=torch.zeros(3),
... batch_size=[3]
... )
>>> print(td)
TensorDict(
fields={
a: NonTensorStack(
['a string!', 'another string!', 'a third string!'...,
batch_size=torch.Size([3]),
device=None),
b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
您可以訪問第一個元素,然後您將獲得第一個字串
>>> print(td[0])
TensorDict(
fields={
a: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None),
b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
相比之下,將 NonTensorData 與列表一起使用不會產生相同的結果,因為對於碰巧是列表的非張量資料,通常無法確定該如何處理
>>> td = TensorDict(
... a=NonTensorData(["a string!", "another string!", "a third string!"]),
... b=torch.zeros(3),
... batch_size=[3]
... )
>>> print(td[0])
TensorDict(
fields={
a: NonTensorData(data=['a string!', 'another string!', 'a third string!'], batch_size=torch.Size([]), device=None),
b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
堆疊包含非張量資料的 TensorDict¶
要堆疊非張量資料,stack() 將檢查非張量物件的標識,如果它們匹配,則生成單個 NonTensorData;否則,生成 NonTensorStack
>>> td = TensorDict(
... a=NonTensorData("a string!"),
... b = torch.zeros(()),
... )
>>> print(torch.stack([td, td]))
TensorDict(
fields={
a: NonTensorData(data=a string!, batch_size=torch.Size([2]), device=None),
b: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
如果您想確保結果是堆疊,請改用 lazy_stack()。
>>> print(TensorDict.lazy_stack([td, td]))
LazyStackedTensorDict(
fields={
a: NonTensorStack(
['a string!', 'a string!'],
batch_size=torch.Size([2]),
device=None),
b: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
exclusive_fields={
},
batch_size=torch.Size([2]),
device=None,
is_shared=False,
stack_dim=0)
命名維度¶
TensorDict 及相關類也支援維度命名。可以在構建時或稍後指定名稱。其語義類似於 torch.Tensor 的維度命名功能
>>> tensordict = TensorDict({}, batch_size=[3, 4], names=["a", None])
>>> tensordict.refine_names(..., "b")
>>> tensordict.names = ["z", "y"]
>>> tensordict.rename("m", "n")
>>> tensordict.rename(m="h")
巢狀 TensorDict¶
TensorDict 中的值本身也可以是 TensorDict(下面示例中的巢狀字典將被轉換為巢狀 TensorDict)。
>>> tensordict = TensorDict(
... {
... "inputs": {
... "image": torch.rand(100, 28, 28),
... "mask": torch.randint(2, (100, 28, 28), dtype=torch.uint8)
... },
... "outputs": {"logits": torch.randn(100, 10)},
... },
... batch_size=[100],
... )
訪問或設定巢狀鍵可以使用字串元組完成
>>> image = tensordict["inputs", "image"]
>>> logits = tensordict.get(("outputs", "logits")) # alternative way to access
>>> tensordict["outputs", "probabilities"] = torch.sigmoid(logits)
惰性評估¶
對 TensorDict 的一些操作會推遲執行,直到訪問其中的項。例如,堆疊、擠壓 (squeezing)、非擠壓 (unsqueezing)、置換批維度和建立檢視等操作不會立即在 TensorDict 的所有內容上執行。相反,它們在訪問 TensorDict 中的值時惰性執行。如果 TensorDict 包含許多值,這可以節省大量不必要的計算。
>>> tensordicts = [TensorDict({
... "a": torch.rand(10),
... "b": torch.rand(10, 1000, 1000)}, [10])
... for _ in range(3)]
>>> stacked = torch.stack(tensordicts, 0) # no stacking happens here
>>> stacked_a = stacked["a"] # we stack the a values, b values are not stacked
它還有一個優點,就是我們可以操作堆疊中的原始 tensordict
>>> stacked["a"] = torch.zeros_like(stacked["a"])
>>> assert (tensordicts[0]["a"] == 0).all()
需要注意的是,get 方法現在已成為一個昂貴的操作,如果重複多次,可能會導致一些開銷。只需在執行 stack 後呼叫 tensordict.contiguous() 即可避免這種情況。為了進一步緩解此問題,TensorDict 附帶了自己的元資料類 (MetaTensor),該類可以跟蹤字典中每個條目的型別、形狀、dtype 和裝置,而無需執行昂貴的操作。
惰性預分配¶
假設我們有一個函式 foo() -> TensorDict,然後我們執行以下操作
>>> tensordict = TensorDict({}, batch_size=[N])
>>> for i in range(N):
... tensordict[i] = foo()
當 i == 0 時,空的 TensorDict 將自動填充批大小為 N 的空張量。在後續的迴圈迭代中,所有更新都將原地寫入。
TensorDictModule¶
為了方便將 TensorDict 整合到程式碼庫中,我們提供了 tensordict.nn 包,該包允許使用者將 TensorDict 例項傳遞給 Module 物件(或任何可呼叫物件)。
TensorDictModule 包裝了 Module 並接受單個 TensorDict 作為輸入。您可以指定底層模組應從何處獲取輸入以及應將輸出寫入何處。這是我們可以編寫可重用、通用的高階程式碼(例如動機部分中的訓練迴圈)的關鍵原因。
>>> from tensordict.nn import TensorDictModule
>>> class Net(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.LazyLinear(1)
...
... def forward(self, x):
... logits = self.linear(x)
... return logits, torch.sigmoid(logits)
>>> module = TensorDictModule(
... Net(),
... in_keys=["input"],
... out_keys=[("outputs", "logits"), ("outputs", "probabilities")],
... )
>>> tensordict = TensorDict({"input": torch.randn(32, 100)}, [32])
>>> tensordict = module(tensordict)
>>> # outputs can now be retrieved from the tensordict
>>> logits = tensordict["outputs", "logits"]
>>> probabilities = tensordict.get(("outputs", "probabilities"))
為了方便採用此類,還可以將張量作為 kwargs 傳遞
>>> tensordict = module(input=torch.randn(32, 100))
這將返回一個與前一個程式碼框中完全相同的 TensorDict。有關此功能的更多背景資訊,請參閱 匯出教程。
許多 PyTorch 使用者面臨的一個痛點是 nn.Sequential 無法處理具有多個輸入的模組。使用基於鍵的圖可以輕鬆解決此問題,因為序列中的每個節點都知道需要讀取哪些資料以及寫入到何處。
為此,我們提供了 TensorDictSequential 類,該類將資料透過一系列 TensorDictModules 進行傳遞。序列中的每個模組都從原始 TensorDict 獲取輸入,並將輸出寫入其中,這意味著序列中的模組可以忽略前一個模組的輸出,或根據需要從 tensordict 中獲取額外的輸入。以下是一個示例
>>> class Net(nn.Module):
... def __init__(self, input_size=100, hidden_size=50, output_size=10):
... super().__init__()
... self.fc1 = nn.Linear(input_size, hidden_size)
... self.fc2 = nn.Linear(hidden_size, output_size)
...
... def forward(self, x):
... x = torch.relu(self.fc1(x))
... return self.fc2(x)
...
... class Masker(nn.Module):
... def forward(self, x, mask):
... return torch.softmax(x * mask, dim=1)
>>> net = TensorDictModule(
... Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")]
... )
>>> masker = TensorDictModule(
... Masker(),
... in_keys=[("intermediate", "x"), ("input", "mask")],
... out_keys=[("output", "probabilities")],
... )
>>> module = TensorDictSequential(net, masker)
>>> tensordict = TensorDict(
... {
... "input": TensorDict(
... {"x": torch.rand(32, 100), "mask": torch.randint(2, size=(32, 10))},
... batch_size=[32],
... )
... },
... batch_size=[32],
... )
>>> tensordict = module(tensordict)
>>> intermediate_x = tensordict["intermediate", "x"]
>>> probabilities = tensordict["output", "probabilities"]
在此示例中,第二個模組將第一個模組的輸出與儲存在 TensorDict 中 (“inputs”, “mask”) 鍵下的掩碼組合。
TensorDictSequential 提供了一系列其他功能:可以透過查詢 in_keys 和 out_keys 屬性來訪問輸入和輸出鍵列表。還可以透過使用所需的輸入和輸出鍵集查詢 select_subsequence() 來請求子圖。這將返回另一個 TensorDictSequential,其中只包含滿足這些要求必不可少的模組。TensorDictModule 也與 vmap() 和其他 torch.func 功能相容。