快捷方式

torch.nn.functional.ctc_loss

torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False)[source][source]

應用 Connectionist Temporal Classification (連線時序分類) 損失。

詳見 CTCLoss

注意

在某些情況下,當在 CUDA 裝置上給定張量並使用 CuDNN 時,此運算元可能會選擇非確定性演算法以提高效能。如果不需要這樣,您可以透過設定 torch.backends.cudnn.deterministic = True 來嘗試使操作確定性(可能會犧牲效能)。更多資訊請參閱 可復現性

注意

在 CUDA 裝置上給定張量時,此操作可能會產生非確定性梯度。更多資訊請參閱 可復現性

引數
  • log_probs (Tensor) – (T,N,C)(T, N, C)(T,C)(T, C),其中 C = 字母表中包括空白字元的數量T = 輸入長度N = 批大小。輸出的對數化機率(例如,透過 torch.nn.functional.log_softmax() 獲得)。

  • targets (Tensor) – (N,S)(N, S)(sum(target_lengths))。目標不能是空白字元。在第二種形式中,目標被認為是已連線的。

  • input_lengths (Tensor) – (N)(N)()()。輸入的長度(每個必須 T\leq T

  • target_lengths (Tensor) – (N)(N)()()。目標的長度

  • blank (int, optional) – 空白字元的標籤。預設為 00

  • reduction (str, optional) – 指定應用於輸出的歸約方法:'none' | 'mean' | 'sum''none':不應用歸約;'mean':輸出損失將除以目標長度,然後取批次平均值;'sum':輸出將被求和。預設值:'mean'

  • zero_infinity (bool, optional) – 是否將無限損失及其相關梯度清零。預設值:False。無限損失主要發生在輸入太短而無法與目標對齊時。

返回型別

Tensor

示例

>>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_()
>>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
>>> input_lengths = torch.full((16,), 50, dtype=torch.long)
>>> target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
>>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)
>>> loss.backward()

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得解答

檢視資源