快捷方式

tensorclass

@tensorclass 裝飾器可幫助您構建自定義類,這些類繼承了 TensorDict 的行為,同時能夠將可能的條目限制為預定義的集合或為其類實現自定義方法。

TensorDict 類似,@tensorclass 支援巢狀、索引、重塑、專案賦值。它還支援 clonesqueezetorch.catsplit 等許多張量操作。@tensorclass 允許非張量條目,但所有張量操作都嚴格限制於張量屬性。

需要為非張量資料實現自定義方法。重要的是要注意 @tensorclass 不強制執行嚴格的型別匹配。

>>> from __future__ import annotations
>>> from tensordict.prototype import tensorclass
>>> import torch
>>> from torch import nn
>>> from typing import Optional
>>>
>>> @tensorclass
... class MyData:
...     floatdata: torch.Tensor
...     intdata: torch.Tensor
...     non_tensordata: str
...     nested: Optional[MyData] = None
...
...     def check_nested(self):
...         assert self.nested is not None
>>>
>>> data = MyData(
...   floatdata=torch.randn(3, 4, 5),
...   intdata=torch.randint(10, (3, 4, 1)),
...   non_tensordata="test",
...   batch_size=[3, 4]
... )
>>> print("data:", data)
data: MyData(
  floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
  intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
  non_tensordata='test',
  nested=None,
  batch_size=torch.Size([3, 4]),
  device=None,
  is_shared=False)
>>> data.nested = MyData(
...     floatdata = torch.randn(3, 4, 5),
...     intdata=torch.randint(10, (3, 4, 1)),
...     non_tensordata="nested_test",
...     batch_size=[3, 4]
... )
>>> print("nested:", data)
nested: MyData(
  floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
  intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
  non_tensordata='test',
  nested=MyData(
      floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
      intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
      non_tensordata='nested_test',
      nested=None,
      batch_size=torch.Size([3, 4]),
      device=None,
      is_shared=False),
  batch_size=torch.Size([3, 4]),
  device=None,
  is_shared=False)

TensorDict 一樣,從 v0.4 開始,如果省略批次大小,則將其視為空。

如果提供了非空批次大小,@tensorclass 支援索引。在內部,張量物件會被索引,但非張量資料保持不變。

>>> print("indexed:", data[:2])
indexed: MyData(
   floatdata=Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
   intdata=Tensor(shape=torch.Size([2, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
   non_tensordata='test',
   nested=MyData(
      floatdata=Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
      intdata=Tensor(shape=torch.Size([2, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
      non_tensordata='nested_test',
      nested=None,
      batch_size=torch.Size([2, 4]),
      device=None,
      is_shared=False),
   batch_size=torch.Size([2, 4]),
   device=None,
   is_shared=False)

@tensorclass 還支援設定和重置屬性,甚至對於巢狀物件也是如此。

>>> data.non_tensordata = "test_changed"
>>> print("data.non_tensordata: ", repr(data.non_tensordata))
data.non_tensordata: 'test_changed'

>>> data.floatdata = torch.ones(3, 4, 5)
>>> print("data.floatdata:", data.floatdata)
data.floatdata: tensor([[[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]],

      [[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]],

      [[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]]])

>>> # Changing nested tensor data
>>> data.nested.non_tensordata = "nested_test_changed"
>>> print("data.nested.non_tensordata:", repr(data.nested.non_tensordata))
data.nested.non_tensordata: 'nested_test_changed'

@tensorclass 支援對其內容的形狀和裝置進行多種 torch 操作,例如 stackcatreshapeto(device)。要獲取支援的操作的完整列表,請檢視 tensordict 文件。

這是一個例子

>>> data2 = data.clone()
>>> cat_tc = torch.cat([data, data2], 0)
>>> print("Concatenated data:", catted_tc)
Concatenated data: MyData(
   floatdata=Tensor(shape=torch.Size([6, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
   intdata=Tensor(shape=torch.Size([6, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
   non_tensordata='test_changed',
   nested=MyData(
       floatdata=Tensor(shape=torch.Size([6, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
       intdata=Tensor(shape=torch.Size([6, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
       non_tensordata='nested_test_changed',
       nested=None,
       batch_size=torch.Size([6, 4]),
       device=None,
       is_shared=False),
   batch_size=torch.Size([6, 4]),
   device=None,
   is_shared=False)

序列化

儲存 tensorclass 例項可以使用 memmap 方法實現。儲存策略如下:張量資料將使用記憶體對映張量儲存,可以使用 json 格式序列化的非張量資料將以這種方式儲存。其他資料型別將使用 save() 儲存,該方法依賴於 pickle

反序列化 tensorclass 可以透過 load_memmap() 完成。建立的例項將與儲存的例項具有相同的型別,前提是 tensorclass 在工作環境中可用。

>>> data.memmap("path/to/saved/directory")
>>> data_loaded = TensorDict.load_memmap("path/to/saved/directory")
>>> assert isinstance(data_loaded, type(data))

邊界情況

@tensorclass 支援相等和不等運算子,甚至對於巢狀物件也是如此。請注意,非張量/元資料未經驗證。這將返回一個 tensor class 物件,其中張量屬性具有布林值,非張量屬性具有 None。

這是一個例子

>>> print(data == data2)
MyData(
   floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.bool, is_shared=False),
   intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
   non_tensordata=None,
   nested=MyData(
       floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.bool, is_shared=False),
       intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
       non_tensordata=None,
       nested=None,
       batch_size=torch.Size([3, 4]),
       device=None,
       is_shared=False),
   batch_size=torch.Size([3, 4]),
   device=None,
   is_shared=False)

@tensorclass 支援設定專案。但是,在設定專案時,為了避免效能問題,對非張量/元資料進行身份檢查而不是相等性檢查。使用者需要確保專案的非張量資料與物件匹配,以避免差異。

這是一個例子

在設定具有不同 non_tensor 資料的專案時,將引發 UserWarning

>>> data2.non_tensordata = "test_new"
>>> data[0] = data2[0]
UserWarning: Meta data at 'non_tensordata' may or may not be equal, this may result in undefined behaviours

儘管 @tensorclass 支援 cat()stack() 等 torch 函式,但非張量/元資料未經驗證。torch 操作是在張量資料上執行的,並且在返回輸出時,考慮第一個 tensor class 物件的非張量/元資料。使用者需要確保所有 tensor class 物件列表具有相同的非張量資料,以避免差異。

這是一個例子

>>> data2.non_tensordata = "test_new"
>>> stack_tc = torch.cat([data, data2], dim=0)
>>> print(stack_tc)
MyData(
    floatdata=Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
    intdata=Tensor(shape=torch.Size([2, 3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
    non_tensordata='test',
    nested=MyData(
        floatdata=Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        intdata=Tensor(shape=torch.Size([2, 3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        non_tensordata='nested_test',
        nested=None,
        batch_size=torch.Size([2, 3, 4]),
        device=None,
        is_shared=False),
    batch_size=torch.Size([2, 3, 4]),
    device=None,
    is_shared=False)

@tensorclass 還支援預分配,您可以使用屬性為 None 初始化物件,稍後設定它們。請注意,在初始化時,內部 None 屬性將儲存為非張量/元資料,而在重置時,根據屬性值的型別,它將儲存為張量資料或非張量/元資料。

這是一個例子

>>> @tensorclass
... class MyClass:
...   X: Any
...   y: Any

>>> data = MyClass(X=None, y=None, batch_size = [3,4])
>>> data.X = torch.ones(3, 4, 5)
>>> data.y = "testing"
>>> print(data)
MyClass(
   X=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
   y='testing',
   batch_size=torch.Size([3, 4]),
   device=None,
   is_shared=False)

tensorclass([cls, autocast, frozen, nocast, ...])

一個建立 tensorclass 類的裝飾器。

TensorClass()

TensorClass 是 @tensorclass 裝飾器的基於繼承的版本。

NonTensorData(data[, _metadata, ...])

NonTensorStack(*args, **kwargs)

LazyStackedTensorDict 的一個輕量級包裝器,用於使非張量資料上的堆疊易於識別。

from_dataclass(obj, *[, auto_batch_size, ...])

將 dataclass 例項或型別分別轉換為 tensorclass 例項或型別。

自動型別轉換

警告

自動型別轉換是一個實驗性功能,將來可能會有所更改。與 python<=3.9 的相容性有限。

@tensorclass 部分支援自動型別轉換,作為實驗性功能。諸如 __setattr__updateupdate_from_dict 之類的方法將嘗試將型別標註的條目轉換為所需的 TensorDict / tensorclass 例項(除了下述情況外)。例如,以下程式碼會將 td 字典轉換為 TensorDict,將 tc 條目轉換為 MyClass 例項。

>>> @tensorclass
... class MyClass:
...     tensor: torch.Tensor
...     td: TensorDict
...     tc: MyClass
...
>>> obj = MyClass(
...     tensor=torch.randn(()),
...     td={"a": torch.randn(())},
...     tc={"tensor": torch.randn(()), "td": None, "tc": None})
>>> assert isinstance(obj.tensor, torch.Tensor)
>>> assert isinstance(obj.tc, TensorDict)
>>> assert isinstance(obj.td, MyClass)

注意

包含 typing.Optionaltyping.Union 的型別標註專案將不相容自動型別轉換,但 tensorclass 中的其他專案將相容。

>>> @tensorclass
... class MyClass:
...     tensor: torch.Tensor
...     tc_autocast: MyClass = None
...     tc_not_autocast: Optional[MyClass] = None
>>> obj = MyClass(
...     tensor=torch.randn(()),
...     tc_autocast={"tensor": torch.randn(())},
...     tc_not_autocast={"tensor": torch.randn(())},
... )
>>> assert isinstance(obj.tc_autocast, MyClass)
>>> # because the type is Optional or Union, auto-casting is disabled for
>>> # that variable.
>>> assert not isinstance(obj.tc_not_autocast, MyClass)

如果類中至少有一個專案使用 type0 | type1 語義進行標註,則整個類的自動型別轉換功能將被停用。因為 tensorclass 支援非張量葉節點,在這些情況下設定字典會導致將其設定為普通字典,而不是張量集合子類(TensorDicttensorclass)。

>>> @tensorclass
... class MyClass:
...     tensor: torch.Tensor
...     td: TensorDict
...     tc: MyClass | None
...
>>> obj = MyClass(
...     tensor=torch.randn(()),
...     td={"a": torch.randn(())},
...     tc={"tensor": torch.randn(()), "td": None, "tc": None})
>>> assert isinstance(obj.tensor, torch.Tensor)
>>> # tc and td have not been cast
>>> assert isinstance(obj.tc, dict)
>>> assert isinstance(obj.td, dict)

注意

葉節點(張量)未啟用自動型別轉換。原因是此功能與包含 type0 | type1 型別提示語義的型別標註不相容,而這種語義很普遍。如果型別標註僅略有不同,允許自動型別轉換將導致非常相似的程式碼具有截然不同的行為。

文件

訪問 PyTorch 的完整開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源