快捷方式

from_dataclass

class tensordict.from_dataclass(obj: Any, *, auto_batch_size: bool = False, batch_dims: Optional[int] = None, batch_size: Optional[Size] = None, frozen: bool = False, autocast: bool = False, nocast: bool = False, inplace: bool = False, shadow: bool = False, device: Optional[device] = None)

將 dataclass 例項或型別分別轉換為 tensorclass 例項或型別。

此函式接收 dataclass 例項或 dataclass 型別,並將其轉換為張量相容的類,同時可選擇應用自動批處理、不變性和型別轉換等各種配置。

引數:

obj (Any) – 要轉換的 dataclass 例項或型別。如果提供的是型別,則返回一個新的類。

關鍵字引數:
  • auto_batch_size (bool, 可選) – 如果為 True,則自動確定並將批次大小應用於結果物件。預設為 False

  • batch_dims (int, 可選) – 如果 auto_batch_size 為 True,則定義輸出 tensordict 應具有的維度數。預設為 None(每層全批次大小)。

  • batch_size (torch.Size, 可選) – TensorDict 的批次大小。預設為 None

  • frozen (bool, 可選) – 如果為 True,則結果類或例項將是不可變的。預設為 False

  • autocast (bool, 可選) – 如果為 True,則為結果類或例項啟用自動型別轉換。預設為 False

  • nocast (bool, 可選) – 如果為 True,則停用結果類或例項的任何型別轉換。預設為 False

  • inplace (bool, 可選) – 如果為 True,則傳入的 dataclass 型別將被原地修改。預設為 False。如果提供的是例項,則此引數無效。

  • device (torch.device, 可選) – 建立 TensorDict 的裝置。預設為 None

  • shadow (bool, 可選) – 停用欄位名與 TensorDict 保留屬性的驗證。請謹慎使用,這可能會導致意外後果。預設為 False。

返回:

從提供的 dataclass 派生的張量相容類或例項。

丟擲:

TypeError – 如果提供的輸入不是 dataclass 例項或型別。

示例

>>> from dataclasses import dataclass
>>> import torch
>>> from tensordict.tensorclass import from_dataclass
>>>
>>> @dataclass
>>> class X:
...     a: int
...     b: torch.Tensor
...
>>> x = X(0, 0)
>>> x2 = from_dataclass(x)
>>> print(x2)
X(
    a=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
    b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> X2 = from_dataclass(X, autocast=True)
>>> print(X2(a=0, b=0))
X(
    a=NonTensorData(data=0, batch_size=torch.Size([]), device=None),
    b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

警告

儘管 from_dataclass() 預設返回一個 TensorDict 例項,但此方法將返回一個 tensorclass 例項或型別。

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源