• 文件 >
  • 操作 TensorDict 的鍵
快捷方式

操作 TensorDict 的鍵

作者: Tom Begley

在本教程中,您將學習如何在 TensorDict 中使用和操作鍵,包括獲取和設定鍵、迭代鍵、操作巢狀值以及展平鍵。

設定和獲取鍵

我們可以使用與 Python dict 相同的語法來設定和獲取鍵

import torch
from tensordict.tensordict import TensorDict

tensordict = TensorDict()

# set a key
a = torch.rand(10)
tensordict["a"] = a

# retrieve the value stored under "a"
assert tensordict["a"] is a

注意

與 Python dict 不同,TensorDict 中的所有鍵都必須是字串。然而,正如我們將看到的,也可以使用字串元組來操作巢狀值。

我們也可以使用方法 .get().set 來實現同樣的功能。

tensordict = TensorDict()

# set a key
a = torch.rand(10)
tensordict.set("a", a)

# retrieve the value stored under "a"
assert tensordict.get("a") is a

dict 類似,我們可以為 get 提供一個預設值,當請求的鍵未找到時返回該值。

assert tensordict.get("banana", a) is a

同樣,與 dict 類似,我們可以使用 TensorDict.setdefault() 來獲取特定鍵的值,如果未找到該鍵則返回預設值,同時也在 TensorDict 中設定該值。

assert tensordict.setdefault("banana", a) is a
# a is now stored under "banana"
assert tensordict["banana"] is a

刪除鍵的方式也與 Python dict 相同,使用 del 語句和選擇的鍵。同樣,我們也可以使用 TensorDict.del_ 方法。

del tensordict["banana"]

此外,使用 .set() 設定鍵時,我們可以使用關鍵字引數 inplace=True 進行原地更新,或者等價地使用 .set_() 方法。

tensordict.set("a", torch.zeros(10), inplace=True)

# all the entries of the "a" tensor are now zero
assert (tensordict.get("a") == 0).all()
# but it's still the same tensor as before
assert tensordict.get("a") is a

# we can achieve the same with set_
tensordict.set_("a", torch.ones(10))
assert (tensordict.get("a") == 1).all()
assert tensordict.get("a") is a

重新命名鍵

要重新命名鍵,只需使用 TensorDict.rename_key_ 方法。儲存在原始鍵下的值將保留在 TensorDict 中,但鍵將更改為指定的新鍵。

tensordict.rename_key_("a", "b")
assert tensordict.get("b") is a
print(tensordict)
TensorDict(
    fields={
        b: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

更新多個值

TensorDict.update 方法可用於使用另一個 TensorDict`dict 來更新一個 TensorDict`。已經存在的鍵將被覆蓋,不存在的鍵將被建立。

tensordict = TensorDict({"a": torch.rand(10), "b": torch.rand(10)}, [10])
tensordict.update(TensorDict({"a": torch.zeros(10), "c": torch.zeros(10)}, [10]))
assert (tensordict["a"] == 0).all()
assert (tensordict["b"] != 0).all()
assert (tensordict["c"] == 0).all()
print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

巢狀值

TensorDict 的值本身也可以是 TensorDict。我們可以在例項化時新增巢狀值,可以透過直接新增 TensorDict,或者使用巢狀字典的方式

# creating nested values with a nested dict
nested_tensordict = TensorDict(
    {"a": torch.rand(2, 3), "double_nested": {"a": torch.rand(2, 3)}}, [2, 3]
)
# creating nested values with a TensorDict
tensordict = TensorDict({"a": torch.rand(2), "nested": nested_tensordict}, [2])

print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                double_nested: TensorDict(
                    fields={
                        a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([2, 3]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([2, 3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

要訪問這些巢狀值,我們可以使用字串元組。例如

double_nested_a = tensordict["nested", "double_nested", "a"]
nested_a = tensordict.get(("nested", "a"))

類似地,我們可以使用字串元組來設定巢狀值

tensordict["nested", "double_nested", "b"] = torch.rand(2, 3)
tensordict.set(("nested", "b"), torch.rand(2, 3))

print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                double_nested: TensorDict(
                    fields={
                        a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                        b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([2, 3]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([2, 3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

迭代 TensorDict 的內容

我們可以使用 .keys() 方法迭代 TensorDict 的鍵。

for key in tensordict.keys():
    print(key)
a
nested

預設情況下,這將只迭代 TensorDict 中的頂層鍵,但是可以透過關鍵字引數 include_nested=True 遞迴迭代 TensorDict 中的所有鍵。這將遞迴迭代任何巢狀 TensorDict 中的所有鍵,並以字串元組的形式返回巢狀鍵。

for key in tensordict.keys(include_nested=True):
    print(key)
a
('nested', 'a')
('nested', 'double_nested', 'a')
('nested', 'double_nested', 'b')
('nested', 'double_nested')
('nested', 'b')
nested

如果您只想迭代對應於 Tensor 值(即葉子節點)的鍵,您可以額外指定 leaves_only=True

for key in tensordict.keys(include_nested=True, leaves_only=True):
    print(key)
a
('nested', 'a')
('nested', 'double_nested', 'a')
('nested', 'double_nested', 'b')
('nested', 'b')

dict 非常類似,還有 .values.items 方法,它們接受相同的關鍵字引數。

for key, value in tensordict.items(include_nested=True):
    if isinstance(value, TensorDict):
        print(f"{key} is a TensorDict")
    else:
        print(f"{key} is a Tensor")
a is a Tensor
nested is a TensorDict
('nested', 'a') is a Tensor
('nested', 'double_nested') is a TensorDict
('nested', 'double_nested', 'a') is a Tensor
('nested', 'double_nested', 'b') is a Tensor
('nested', 'b') is a Tensor

檢查鍵是否存在

要檢查鍵是否存在於 TensorDict 中,請結合 .keys() 使用 in 運算子。

注意

執行 key in tensordict.keys() 會進行高效的鍵查詢(在巢狀情況下在每個級別上遞迴查詢),因此當 TensorDict 中存在大量鍵時,效能不會受到負面影響。

assert "a" in tensordict.keys()
# to check for nested keys, set include_nested=True
assert ("nested", "a") in tensordict.keys(include_nested=True)
assert ("nested", "banana") not in tensordict.keys(include_nested=True)

展平與非展平巢狀鍵

我們可以使用 .flatten_keys() 方法展平具有巢狀值(鍵)的 TensorDict

print(tensordict, end="\n\n")
print(tensordict.flatten_keys(separator="."))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                double_nested: TensorDict(
                    fields={
                        a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                        b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([2, 3]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([2, 3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested.double_nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested.double_nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

給定一個已經展平的 TensorDict,可以使用 .unflatten_keys() 方法將其恢復到非展平狀態。

flattened_tensordict = tensordict.flatten_keys(separator=".")
print(flattened_tensordict, end="\n\n")
print(flattened_tensordict.unflatten_keys(separator="."))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested.double_nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested.double_nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                double_nested: TensorDict(
                    fields={
                        a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                        b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([2]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([2]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

當操作 torch.nn.Module 的引數時,這會特別有用,因為最終我們可以得到一個結構模仿模組結構的 TensorDict

import torch.nn as nn

module = nn.Sequential(
    nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 10)),
    nn.Linear(10, 1),
)
params = TensorDict(dict(module.named_parameters()), []).unflatten_keys()

print(params)
TensorDict(
    fields={
        0: TensorDict(
            fields={
                0: TensorDict(
                    fields={
                        bias: Parameter(shape=torch.Size([50]), device=cpu, dtype=torch.float32, is_shared=False),
                        weight: Parameter(shape=torch.Size([50, 100]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                1: TensorDict(
                    fields={
                        bias: Parameter(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                        weight: Parameter(shape=torch.Size([10, 50]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        1: TensorDict(
            fields={
                bias: Parameter(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                weight: Parameter(shape=torch.Size([1, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

選擇和排除鍵

我們可以使用 TensorDict.select 方法獲取一個包含鍵子集的新 TensorDict,該方法返回一個僅包含指定鍵的新 TensorDict,或者使用 :meth: TensorDict.exclude <tensordict.TensorDict.exclude> 方法,該方法返回一個省略指定鍵的新 TensorDict

print("Select:")
print(tensordict.select("a", ("nested", "a")), end="\n\n")
print("Exclude:")
print(tensordict.exclude(("nested", "b"), ("nested", "double_nested")))
Select:
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2, 3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

Exclude:
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2, 3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

指令碼總執行時間: (0 分鐘 0.009 秒)

由 Sphinx-Gallery 生成的畫廊

文件

訪問 PyTorch 的綜合開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源