快捷方式

pad_sequence

class tensordict.pad_sequence(list_of_tensordicts: Sequence[T], pad_dim: int = 0, padding_value: float = 0.0, out: Optional[T] = None, device: Optional[Union[device, str,int]] = None, return_mask: bool | tensordict._nestedkey.NestedKey = False)

填充 tensordict 列表,以便將它們堆疊成連續格式。

引數:
  • list_of_tensordicts (List[TensorDictBase]) – 要填充和堆疊的例項列表。

  • pad_dim (int, 可選) – pad_dim 指示要填充 tensordict 中所有鍵的維度。預設為 0

  • padding_value (number, 可選) – 填充值。預設為 0.0

  • out (TensorDictBase, 可選) – 如果提供,資料將寫入的目標位置。

  • return_mask (boolNestedKey, 可選) – 如果為 True,將返回一個“masks”條目。如果 return_mask 是一個巢狀鍵(字串或字串元組),它將返回掩碼並用作掩碼條目的鍵。它包含一個與堆疊的 tensordict 具有相同結構的 tensordict,其中每個條目包含有效值的掩碼,大小為 torch.Size([stack_len, *new_shape]),其中 new_shape[pad_dim] = max_seq_lengthnew_shape 的其餘部分與包含的張量的先前形狀匹配。

示例

>>> list_td = [
...     TensorDict({"a": torch.zeros((3, 8)), "b": torch.zeros((6, 8))}, batch_size=[]),
...     TensorDict({"a": torch.zeros((5, 8)), "b": torch.zeros((6, 8))}, batch_size=[]),
...     ]
>>> padded_td = pad_sequence(list_td, return_mask=True)
>>> print(padded_td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2, 4, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([2, 5, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        masks: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.bool, is_shared=False),
                b: Tensor(shape=torch.Size([2, 6]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([2]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲取問題解答

檢視資源