torch.nn.utils.rnn.pad_sequence¶
- torch.nn.utils.rnn.pad_sequence(sequences, batch_first=False, padding_value=0.0, padding_side='right')[源][源]¶
使用
padding_value填充變長 Tensor 列表。pad_sequence沿新維度堆疊變長 Tensor 列表,並將它們填充到等長。sequences可以是大小為L x *的序列列表,其中 L 是序列長度,*是任意數量的維度(包括0)。如果batch_first為False,輸出大小為T x B x *;否則為B x T x *,其中B是批次大小(sequences中的元素數量),T是最長序列的長度。示例
>>> from torch.nn.utils.rnn import pad_sequence >>> a = torch.ones(25, 300) >>> b = torch.ones(22, 300) >>> c = torch.ones(15, 300) >>> pad_sequence([a, b, c]).size() torch.Size([25, 3, 300])
注意
此函式返回一個大小為
T x B x *或B x T x *的 Tensor,其中 T 是最長序列的長度。此函式假定 sequences 中所有 Tensor 的尾隨維度和型別相同。- 引數
- 返回
如果
batch_first為False,則返回大小為T x B x *的 Tensor。否則返回大小為B x T x *的 Tensor- 返回型別