tensorclass¶
- 類 tensordict.tensorclass(cls: 可選[T] = 無, /, *, autocast: 布林值 = False, frozen: 布林值 = False, nocast: 布林值 = False, shadow: 布林值 = False)¶
一個用於建立
tensorclass類的裝飾器。tensorclass類是專門化的dataclasses.dataclass()例項,它們可以即時執行一些預定義的張量操作,例如索引、項賦值、重塑、轉換為裝置或儲存等許多操作。- 關鍵字引數:
autocast (布林值, 可選) – 如果為
True,則在設定引數時將強制執行指定的型別。此引數與nocast互斥(兩者不能同時為 True)。預設為False。frozen (布林值, 可選) – 如果為
True,則 tensorclass 的內容無法修改。提供此引數是為了與 dataclass 相容,透過類建構函式中的 lock 引數可以獲得類似的行為。預設為False。nocast (布林值, 可選) – 如果為
True,則 Tensor 相容的型別,如int、np.ndarray等,將不會被轉換為張量型別。此引數與autocast互斥(兩者不能同時為 True)。預設為False。shadow (布林值, 可選) – 停用對欄位名與 TensorDict 保留屬性的驗證。請謹慎使用,因為這可能導致意外後果。預設為 False。
tensorclass 可以帶或不帶引數使用
示例
>>> @tensorclass ... class X: ... y: int >>> X(torch.ones(())).y tensor(1.) >>> @tensorclass(autocast=False) ... class X: ... y: int >>> X(torch.ones(())).y tensor(1.) >>> @tensorclass(autocast=True) ... class X: ... y: int >>> X(torch.ones(())).y 1 >>> @tensorclass(nocast=True) ... class X: ... y: Any >>> X(1).y 1 >>> @tensorclass(nocast=False) ... class X: ... y: Any >>> X(1).y tensor(1)
示例
>>> from tensordict import tensorclass >>> import torch >>> from typing import Optional >>> >>> @tensorclass ... class MyData: ... X: torch.Tensor ... y: torch.Tensor ... z: str ... def expand_and_mask(self): ... X = self.X.unsqueeze(-1).expand_as(self.y) ... X = X[self.y] ... return X ... >>> data = MyData( ... X=torch.ones(3, 4, 1), ... y=torch.zeros(3, 4, 2, 2, dtype=torch.bool), ... z="test" ... batch_size=[3, 4]) >>> print(data) MyData( X=Tensor(torch.Size([3, 4, 1]), dtype=torch.float32), y=Tensor(torch.Size([3, 4, 2, 2]), dtype=torch.bool), z="test" batch_size=[3, 4], device=None, is_shared=False) >>> print(data.expand_and_mask()) tensor([])
- 也可以將 tensorclass 例項互相巢狀
示例: >>> from tensordict import tensorclass >>> import torch >>> from typing import Optional >>> >>> @tensorclass … class NestingMyData: … nested: MyData … >>> nesting_data = NestingMyData(nested=data, batch_size=[3, 4]) >>> # 儘管資料儲存為 TensorDict,但型別提示有助於我們將資料適當地轉換為正確的型別 >>> assert isinstance(nesting_data.nested, type(data))