from_dataclass¶
- class tensordict.from_dataclass(obj: Any, *, auto_batch_size: bool = False, batch_dims: Optional[int] = None, batch_size: Optional[Size] = None, frozen: bool = False, autocast: bool = False, nocast: bool = False, inplace: bool = False, shadow: bool = False, device: Optional[device] = None)¶
將 dataclass 例項或型別分別轉換為 tensorclass 例項或型別。
此函式接收 dataclass 例項或 dataclass 型別,並將其轉換為張量相容的類,同時可選擇應用自動批處理、不變性和型別轉換等各種配置。
- 引數:
obj (Any) – 要轉換的 dataclass 例項或型別。如果提供的是型別,則返回一個新的類。
- 關鍵字引數:
auto_batch_size (bool, 可選) – 如果為
True,則自動確定並將批次大小應用於結果物件。預設為False。batch_dims (int, 可選) – 如果 auto_batch_size 為
True,則定義輸出 tensordict 應具有的維度數。預設為None(每層全批次大小)。batch_size (torch.Size, 可選) – TensorDict 的批次大小。預設為
None。frozen (bool, 可選) – 如果為
True,則結果類或例項將是不可變的。預設為False。autocast (bool, 可選) – 如果為
True,則為結果類或例項啟用自動型別轉換。預設為False。nocast (bool, 可選) – 如果為
True,則停用結果類或例項的任何型別轉換。預設為False。inplace (bool, 可選) – 如果為
True,則傳入的 dataclass 型別將被原地修改。預設為False。如果提供的是例項,則此引數無效。device (torch.device, 可選) – 建立 TensorDict 的裝置。預設為
None。shadow (bool, 可選) – 停用欄位名與 TensorDict 保留屬性的驗證。請謹慎使用,這可能會導致意外後果。預設為 False。
- 返回:
從提供的 dataclass 派生的張量相容類或例項。
- 丟擲:
TypeError – 如果提供的輸入不是 dataclass 例項或型別。
示例
>>> from dataclasses import dataclass >>> import torch >>> from tensordict.tensorclass import from_dataclass >>> >>> @dataclass >>> class X: ... a: int ... b: torch.Tensor ... >>> x = X(0, 0) >>> x2 = from_dataclass(x) >>> print(x2) X( a=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), batch_size=torch.Size([]), device=None, is_shared=False) >>> X2 = from_dataclass(X, autocast=True) >>> print(X2(a=0, b=0)) X( a=NonTensorData(data=0, batch_size=torch.Size([]), device=None), b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), batch_size=torch.Size([]), device=None, is_shared=False)
警告
儘管
from_dataclass()預設返回一個TensorDict例項,但此方法將返回一個 tensorclass 例項或型別。