torchaudio.functional.rnnt_loss¶
- torchaudio.functional.rnnt_loss(logits: Tensor, targets: Tensor, logit_lengths: Tensor, target_lengths: Tensor, blank: int = -1, clamp: float = -1, reduction: str = 'mean', fused_log_softmax: bool = True)[source]¶
計算來自 Sequence Transduction with Recurrent Neural Networks [Graves, 2012] 的 RNN 換能器損失 (RNN Transducer loss)。
RNN 換能器損失透過定義一個涵蓋所有長度的輸出序列上的分佈,並共同建模輸入-輸出和輸出-輸出依賴性,從而擴充套件了 CTC 損失。
- 引數:
logits (Tensor) – 維度為 (batch, max seq length, max target length + 1, class) 的 Tensor,包含來自 joiner 的輸出
targets (Tensor) – 維度為 (batch, max target length) 的 Tensor,包含零填充的目標
logit_lengths (Tensor) – 維度為 (batch) 的 Tensor,包含編碼器中每個序列的長度
target_lengths (Tensor) – 維度為 (batch) 的 Tensor,包含每個序列目標的長度
blank (int, optional) – blank 標籤 (預設值:
-1)clamp (float, optional) – 梯度截斷 (預設值:
-1)reduction (string, optional) – 指定應用於輸出的歸約方式:
"none"|"mean"|"sum"。(預設值:"mean")fused_log_softmax (bool) – 如果在損失函式外部呼叫 log_softmax,則設定為 False (預設值:
True)
- 返回:
應用了歸約選項的損失。如果
reduction為"none",則返回大小為 (batch) 的 Tensor,否則返回標量。- 返回型別:
Tensor