快捷方式

TensorClass

class tensordict.TensorClass

TensorClass 是 @tensorclass 裝飾器的基於繼承的版本。

TensorClass 允許您編寫型別檢查更佳、更符合 Python 風格的資料類,相比於使用 @tensorclass 裝飾器構建的資料類。

示例

>>> from typing import Any
>>> import torch
>>> from tensordict import TensorClass
>>> class Foo(TensorClass):
...     tensor: torch.Tensor
...     non_tensor: Any
...     nested: Any = None
>>> foo = Foo(tensor=torch.randn(3), non_tensor="a string!", nested=None, batch_size=[3])
>>> print(foo)
Foo(
    non_tensor=NonTensorData(data=a string!, batch_size=torch.Size([3]), device=None),
    tensor=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
    nested=None,
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

您可以透過兩種方式傳遞關鍵字引數:使用方括號或直接使用關鍵字引數。

示例

>>> class Foo(TensorClass["autocast"]):
...     integer: int
>>> Foo(integer=torch.ones(())).integer
1
>>> class Foo(TensorClass, autocast=True):  # equivalent
...     integer: int
>>> Foo(integer=torch.ones(())).integer
1
>>> class Foo(TensorClass["nocast"]):
...     integer: int
>>> Foo(integer=1).integer
1
>>> class Foo(TensorClass["nocast", "frozen"]):  # multiple keywords can be used
...     integer: int
>>> Foo(integer=1).integer
1
>>> class Foo(TensorClass, nocast=True):  # equivalent
...     integer: int
>>> Foo(integer=1).integer
1
>>> class Foo(TensorClass):
...     integer: int
>>> Foo(integer=1).integer
tensor(1)

警告

TensorClass 本身沒有被裝飾為 tensorclass,但其子類會。這是因為我們無法預知 frozen 引數是否會被設定,如果設定了,它可能與父類衝突(如果父類未被凍結,子類也不能被凍結)。

文件

獲取 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源