快捷方式

torchaudio.functional.rnnt_loss

torchaudio.functional.rnnt_loss(logits: Tensor, targets: Tensor, logit_lengths: Tensor, target_lengths: Tensor, blank: int = -1, clamp: float = -1, reduction: str = 'mean', fused_log_softmax: bool = True)[source]

計算來自 Sequence Transduction with Recurrent Neural Networks [Graves, 2012] 的 RNN 換能器損失 (RNN Transducer loss)。

This feature supports the following devices: CPU, CUDA This API supports the following properties: Autograd, TorchScript

RNN 換能器損失透過定義一個涵蓋所有長度的輸出序列上的分佈,並共同建模輸入-輸出和輸出-輸出依賴性,從而擴充套件了 CTC 損失。

引數:
  • logits (Tensor) – 維度為 (batch, max seq length, max target length + 1, class) 的 Tensor,包含來自 joiner 的輸出

  • targets (Tensor) – 維度為 (batch, max target length) 的 Tensor,包含零填充的目標

  • logit_lengths (Tensor) – 維度為 (batch) 的 Tensor,包含編碼器中每個序列的長度

  • target_lengths (Tensor) – 維度為 (batch) 的 Tensor,包含每個序列目標的長度

  • blank (int, optional) – blank 標籤 (預設值: -1)

  • clamp (float, optional) – 梯度截斷 (預設值: -1)

  • reduction (string, optional) – 指定應用於輸出的歸約方式: "none" | "mean" | "sum"。(預設值: "mean")

  • fused_log_softmax (bool) – 如果在損失函式外部呼叫 log_softmax,則設定為 False (預設值: True)

返回:

應用了歸約選項的損失。如果 reduction"none",則返回大小為 (batch) 的 Tensor,否則返回標量。

返回型別:

Tensor

文件

訪問 PyTorch 開發者綜合文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲取問題解答

檢視資源