torch.nn.functional.ctc_loss¶
- torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False)[source][source]¶
應用 Connectionist Temporal Classification (連線時序分類) 損失。
詳見
CTCLoss。注意
在某些情況下,當在 CUDA 裝置上給定張量並使用 CuDNN 時,此運算元可能會選擇非確定性演算法以提高效能。如果不需要這樣,您可以透過設定
torch.backends.cudnn.deterministic = True來嘗試使操作確定性(可能會犧牲效能)。更多資訊請參閱 可復現性。注意
在 CUDA 裝置上給定張量時,此操作可能會產生非確定性梯度。更多資訊請參閱 可復現性。
- 引數
log_probs (Tensor) – 或 ,其中 C = 字母表中包括空白字元的數量,T = 輸入長度,N = 批大小。輸出的對數化機率(例如,透過
torch.nn.functional.log_softmax()獲得)。targets (Tensor) – 或 (sum(target_lengths))。目標不能是空白字元。在第二種形式中,目標被認為是已連線的。
input_lengths (Tensor) – 或 。輸入的長度(每個必須 )
target_lengths (Tensor) – 或 。目標的長度
blank (int, optional) – 空白字元的標籤。預設為 。
reduction (str, optional) – 指定應用於輸出的歸約方法:
'none'|'mean'|'sum'。'none':不應用歸約;'mean':輸出損失將除以目標長度,然後取批次平均值;'sum':輸出將被求和。預設值:'mean'zero_infinity (bool, optional) – 是否將無限損失及其相關梯度清零。預設值:
False。無限損失主要發生在輸入太短而無法與目標對齊時。
- 返回型別
示例
>>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_() >>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long) >>> input_lengths = torch.full((16,), 50, dtype=torch.long) >>> target_lengths = torch.randint(10, 30, (16,), dtype=torch.long) >>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths) >>> loss.backward()