RNNTLoss¶
- class torchaudio.transforms.RNNTLoss(blank: int = -1, clamp: float = -1.0, reduction: str = 'mean', fused_log_softmax: bool = True)[source]¶
- 計算 *使用迴圈神經網路進行序列轉導* [id1] 中所述的 RNN 轉導損失(RNN Transducer loss)。 - RNN 轉導損失透過定義所有長度輸出序列的分佈,並聯合建模輸入-輸出和輸出-輸出依賴關係,擴充套件了 CTC 損失。 - 引數:
 - 示例
- >>> # Hypothetical values >>> logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1], >>> [0.1, 0.1, 0.6, 0.1, 0.1], >>> [0.1, 0.1, 0.2, 0.8, 0.1]], >>> [[0.1, 0.6, 0.1, 0.1, 0.1], >>> [0.1, 0.1, 0.2, 0.1, 0.1], >>> [0.7, 0.1, 0.2, 0.1, 0.1]]]], >>> dtype=torch.float32, >>> requires_grad=True) >>> targets = torch.tensor([[1, 2]], dtype=torch.int) >>> logit_lengths = torch.tensor([2], dtype=torch.int) >>> target_lengths = torch.tensor([2], dtype=torch.int) >>> transform = transforms.RNNTLoss(blank=0) >>> loss = transform(logits, targets, logit_lengths, target_lengths) >>> loss.backward() 
 - forward(logits: Tensor, targets: Tensor, logit_lengths: Tensor, target_lengths: Tensor)[source]¶
- 引數:
- logits (Tensor) – 維度為 (batch, 最大序列長度, 最大目標長度 + 1, 類別) 的 Tensor,包含來自聯結器(joiner)的輸出 
- targets (Tensor) – 維度為 (batch, 最大目標長度) 的 Tensor,包含零填充的目標序列 
- logit_lengths (Tensor) – 維度為 (batch) 的 Tensor,包含來自編碼器的每個序列的長度 
- target_lengths (Tensor) – 維度為 (batch) 的 Tensor,包含每個序列的目標長度 
 
- 返回:
- 應用了歸約選項的損失。如果 - reduction為- "none",則形狀為 (batch),否則為標量。
- 返回型別:
- Tensor