快捷方式

CTCLoss

torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)[原始碼][原始碼]

連線主義時間分類損失 (Connectionist Temporal Classification loss)。

計算連續(未分段)時間序列與目標序列之間的損失。CTCLoss 對輸入與目標序列所有可能的對齊方式的機率求和,從而生成一個相對於每個輸入節點都可微分的損失值。輸入與目標的對齊被假定為“多對一”關係,這將目標序列的長度限制為必須 \leq 輸入長度。

引數
  • blank (int, optional) – 空格標籤。預設值 00

  • reduction (str, optional) – 指定要應用於輸出的歸約方式:'none' | 'mean' | 'sum''none':不應用任何歸約;'mean':輸出損失將除以目標長度,然後取批次平均值;'sum':輸出損失將被求和。預設值:'mean'

  • zero_infinity (bool, optional) – 是否將無限損失及其相關梯度清零。預設值:False 當輸入序列過短無法與目標序列對齊時,主要會出現無限損失。

形狀
  • Log_probs:形狀為 (T,N,C)(T, N, C)(T,C)(T, C) 的 Tensor,其中 T=輸入長度T = \text{輸入長度}N=批次大小N = \text{批次大小},而 C=類別數量(包括空格)C = \text{類別數量(包括空格)}。輸出的對數化機率(例如,使用 torch.nn.functional.log_softmax() 獲得)。

  • Targets:形狀為 (N,S)(N, S)(sum(target_lengths))(\operatorname{sum}(\text{target\_lengths})) 的 Tensor,其中 N=批次大小N = \text{批次大小};如果形狀是 (N,S)(N, S),則 S=最大目標長度S = \text{最大目標長度}。它表示目標序列。目標序列中的每個元素都是一個類別索引。並且目標索引不能是空格標籤(預設值 = 0)。在 (N,S)(N, S) 形式中,目標序列將被填充到最長序列的長度,並堆疊在一起。在 (sum(target_lengths))(\operatorname{sum}(\text{target\_lengths})) 形式中,假定目標序列未填充並在一維中連線。

  • Input_lengths:形狀為 (N)(N)()() 的 Tuple 或 Tensor,其中 N=批次大小N = \text{批次大小}。它表示輸入的長度(每個長度必須 T\leq T)。並且為每個序列指定了長度,以便在假設序列已填充到相等長度的情況下實現掩碼。

  • Target_lengths:形狀為 (N)(N)()() 的 Tuple 或 Tensor,其中 N=批次大小N = \text{批次大小}。它表示目標的長度。為每個序列指定了長度,以便在假設序列已填充到相等長度的情況下實現掩碼。如果目標形狀是 (N,S)(N,S),則 target_lengths 實際上是每個目標序列的停止索引 sns_n,使得對於批次中的每個目標,有 target_n = targets[n,0:s_n]。每個長度必須 S\leq S。如果目標序列作為一維 Tensor 給出,該 Tensor 是各個目標序列的拼接,則 target_lengths 的總和必須等於該 Tensor 的總長度。

  • 輸出:如果 reduction'mean'(預設值)或 'sum',則為標量。如果 reduction'none',則如果輸入是批次處理的,形狀為 (N)(N);如果輸入未進行批次處理,形狀為 ()(),其中 N=批次大小N = \text{批次大小}

示例

>>> # Target are to be padded
>>> T = 50      # Input sequence length
>>> C = 20      # Number of classes (including blank)
>>> N = 16      # Batch size
>>> S = 30      # Target sequence length of longest target in batch (padding length)
>>> S_min = 10  # Minimum target length, for demonstration purposes
>>>
>>> # Initialize random batch of input vectors, for *size = (T,N,C)
>>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
>>>
>>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
>>> target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()
>>>
>>>
>>> # Target are to be un-padded
>>> T = 50      # Input sequence length
>>> C = 20      # Number of classes (including blank)
>>> N = 16      # Batch size
>>>
>>> # Initialize random batch of input vectors, for *size = (T,N,C)
>>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
>>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long)
>>> target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long)
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()
>>>
>>>
>>> # Target are to be un-padded and unbatched (effectively N=1)
>>> T = 50      # Input sequence length
>>> C = 20      # Number of classes (including blank)
>>>
>>> # Initialize random batch of input vectors, for *size = (T,C)
>>> input = torch.randn(T, C).log_softmax(1).detach().requires_grad_()
>>> input_lengths = torch.tensor(T, dtype=torch.long)
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target_lengths = torch.randint(low=1, high=T, size=(), dtype=torch.long)
>>> target = torch.randint(low=1, high=C, size=(target_lengths,), dtype=torch.long)
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()
參考

A. Graves 等人:連線主義時間分類:使用迴圈神經網路標記未分段序列資料 (https://www.cs.toronto.edu/~graves/icml_2006.pdf)

注意

為了使用 CuDNN,必須滿足以下條件:targets 必須採用拼接格式,所有 input_lengths 必須等於 Tblank=0blank=0target_lengths 256\leq 256,整型引數的 dtype 必須是 torch.int32

常規實現使用(在 PyTorch 中更常見)的 torch.long dtype。

注意

在使用 CUDA 後端和 CuDNN 的某些情況下,此運算元可能會選擇非確定性演算法以提高效能。如果這是不可取的,您可以透過設定 torch.backends.cudnn.deterministic = True 來嘗試使操作具有確定性(可能會犧牲效能)。背景資訊請參閱可復現性相關的注意事項。

文件

查閱 PyTorch 的完整開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲取問題解答

檢視資源