快捷方式

tensorclass

tensordict.tensorclass(cls: 可選[T] = , /, *, autocast: 布林值 = False, frozen: 布林值 = False, nocast: 布林值 = False, shadow: 布林值 = False)

一個用於建立 tensorclass 類的裝飾器。

tensorclass 類是專門化的 dataclasses.dataclass() 例項,它們可以即時執行一些預定義的張量操作,例如索引、項賦值、重塑、轉換為裝置或儲存等許多操作。

關鍵字引數:
  • autocast (布林值, 可選) – 如果為 True,則在設定引數時將強制執行指定的型別。此引數與 nocast 互斥(兩者不能同時為 True)。預設為 False

  • frozen (布林值, 可選) – 如果為 True,則 tensorclass 的內容無法修改。提供此引數是為了與 dataclass 相容,透過類建構函式中的 lock 引數可以獲得類似的行為。預設為 False

  • nocast (布林值, 可選) – 如果為 True,則 Tensor 相容的型別,如 intnp.ndarray 等,將不會被轉換為張量型別。此引數與 autocast 互斥(兩者不能同時為 True)。預設為 False

  • shadow (布林值, 可選) – 停用對欄位名與 TensorDict 保留屬性的驗證。請謹慎使用,因為這可能導致意外後果。預設為 False。

tensorclass 可以帶或不帶引數使用

示例

>>> @tensorclass
... class X:
...     y: int
>>> X(torch.ones(())).y
tensor(1.)
>>> @tensorclass(autocast=False)
... class X:
...     y: int
>>> X(torch.ones(())).y
tensor(1.)
>>> @tensorclass(autocast=True)
... class X:
...     y: int
>>> X(torch.ones(())).y
1
>>> @tensorclass(nocast=True)
... class X:
...     y: Any
>>> X(1).y
1
>>> @tensorclass(nocast=False)
... class X:
...     y: Any
>>> X(1).y
tensor(1)

示例

>>> from tensordict import tensorclass
>>> import torch
>>> from typing import Optional
>>>
>>> @tensorclass
... class MyData:
...     X: torch.Tensor
...     y: torch.Tensor
...     z: str
...     def expand_and_mask(self):
...         X = self.X.unsqueeze(-1).expand_as(self.y)
...         X = X[self.y]
...         return X
...
>>> data = MyData(
...     X=torch.ones(3, 4, 1),
...     y=torch.zeros(3, 4, 2, 2, dtype=torch.bool),
...     z="test"
...     batch_size=[3, 4])
>>> print(data)
MyData(
    X=Tensor(torch.Size([3, 4, 1]), dtype=torch.float32),
    y=Tensor(torch.Size([3, 4, 2, 2]), dtype=torch.bool),
    z="test"
    batch_size=[3, 4],
    device=None,
    is_shared=False)
>>> print(data.expand_and_mask())
tensor([])
也可以將 tensorclass 例項互相巢狀

示例: >>> from tensordict import tensorclass >>> import torch >>> from typing import Optional >>> >>> @tensorclass … class NestingMyData: … nested: MyData … >>> nesting_data = NestingMyData(nested=data, batch_size=[3, 4]) >>> # 儘管資料儲存為 TensorDict,但型別提示有助於我們將資料適當地轉換為正確的型別 >>> assert isinstance(nesting_data.nested, type(data))

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源