TensorClass¶
- class tensordict.TensorClass¶
TensorClass 是 @tensorclass 裝飾器的基於繼承的版本。
TensorClass 允許您編寫型別檢查更佳、更符合 Python 風格的資料類,相比於使用 @tensorclass 裝飾器構建的資料類。
示例
>>> from typing import Any >>> import torch >>> from tensordict import TensorClass >>> class Foo(TensorClass): ... tensor: torch.Tensor ... non_tensor: Any ... nested: Any = None >>> foo = Foo(tensor=torch.randn(3), non_tensor="a string!", nested=None, batch_size=[3]) >>> print(foo) Foo( non_tensor=NonTensorData(data=a string!, batch_size=torch.Size([3]), device=None), tensor=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), nested=None, batch_size=torch.Size([3]), device=None, is_shared=False)
您可以透過兩種方式傳遞關鍵字引數:使用方括號或直接使用關鍵字引數。
示例
>>> class Foo(TensorClass["autocast"]): ... integer: int >>> Foo(integer=torch.ones(())).integer 1 >>> class Foo(TensorClass, autocast=True): # equivalent ... integer: int >>> Foo(integer=torch.ones(())).integer 1 >>> class Foo(TensorClass["nocast"]): ... integer: int >>> Foo(integer=1).integer 1 >>> class Foo(TensorClass["nocast", "frozen"]): # multiple keywords can be used ... integer: int >>> Foo(integer=1).integer 1 >>> class Foo(TensorClass, nocast=True): # equivalent ... integer: int >>> Foo(integer=1).integer 1 >>> class Foo(TensorClass): ... integer: int >>> Foo(integer=1).integer tensor(1)
警告
TensorClass 本身沒有被裝飾為 tensorclass,但其子類會。這是因為我們無法預知 frozen 引數是否會被設定,如果設定了,它可能與父類衝突(如果父類未被凍結,子類也不能被凍結)。