pad¶
- class tensordict.pad(tensordict: T, pad_size: Sequence[int], value: _float = 0.0)¶
使用常數值沿批處理維度填充 tensordict 中的所有張量,並返回一個新的 tensordict。
- 引數:
tensordict (TensorDict) – 要填充的 tensordict
pad_size (Sequence[int]) – 用於填充 tensordict 的部分批處理維度的填充大小,從第一個維度開始向前。批處理大小的 [len(pad_size) / 2] 個維度將被填充。例如,僅填充第一個維度時,pad 的形式為 (padding_left, padding_right)。填充兩個維度時,形式為 (padding_left, padding_right, padding_top, padding_bottom),依此類推。pad_size 必須為偶數,且小於或等於批處理維度的兩倍。
value (float, optional) – 用於填充的填充值,預設為 0.0
- 返回:
沿批處理維度填充後的新 TensorDict
示例
>>> from tensordict import TensorDict, pad >>> import torch >>> td = TensorDict({'a': torch.ones(3, 4, 1), ... 'b': torch.ones(3, 4, 1, 1)}, batch_size=[3, 4]) >>> dim0_left, dim0_right, dim1_left, dim1_right = [0, 1, 0, 2] >>> padded_td = pad(td, [dim0_left, dim0_right, dim1_left, dim1_right], value=0.0) >>> print(padded_td.batch_size) torch.Size([4, 6]) >>> print(padded_td.get("a").shape) torch.Size([4, 6, 1]) >>> print(padded_td.get("b").shape) torch.Size([4, 6, 1, 1])