快捷方式

torch.nn.utils.rnn.pack_padded_sequence

torch.nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)[source][source]

打包包含可變長度填充序列的張量。

input 可以是 T x B x * 大小(如果 batch_firstFalse)或 B x T x * 大小(如果 batch_firstTrue),其中 T 是最長序列的長度,B 是批次大小,* 是任意維度數(包括 0)。

對於未排序的序列,請使用 enforce_sorted = False。如果 enforce_sortedTrue,序列應按長度降序排列,即 input[:,0] 應為最長序列,而 input[:,B-1] 應為最短序列。enforce_sorted = True 僅在 ONNX 匯出時是必要的。

它是 pad_packed_sequence() 的逆操作,因此可以使用 pad_packed_sequence() 來恢復打包在 PackedSequence 中的底層張量。

注意

此函式接受至少包含兩個維度的任何輸入。您可以將其用於打包標籤,並使用 RNN 的輸出與標籤一起直接計算損失。可以透過訪問 PackedSequence 物件的 .data 屬性來從中檢索張量。

引數
  • input (Tensor) – 可變長度序列的填充批次。

  • lengths (Tensorlist(int)) – 每個批次元素的序列長度列表(如果以張量形式提供,則必須位於 CPU 上)。

  • batch_first (bool, 可選) – 如果為 True,則輸入應為 B x T x * 格式,否則為 T x B x *。預設值:False

  • enforce_sorted (bool, 可選) – 如果為 True,則輸入應包含按長度降序排列的序列。如果為 False,輸入將無條件地進行排序。預設值:True

返回型別

PackedSequence

警告

如果 input 張量的長度大於 length 中的對應值,則其維度將被截斷。

返回

一個 PackedSequence 物件

返回型別

PackedSequence


© 版權所有 PyTorch 貢獻者。

使用 Sphinx 構建,主題由 Read the Docs 提供。

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源