CTCLoss¶
- 類 torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)[原始碼][原始碼]¶
連線主義時間分類損失 (Connectionist Temporal Classification loss)。
計算連續(未分段)時間序列與目標序列之間的損失。CTCLoss 對輸入與目標序列所有可能的對齊方式的機率求和,從而生成一個相對於每個輸入節點都可微分的損失值。輸入與目標的對齊被假定為“多對一”關係,這將目標序列的長度限制為必須 輸入長度。
- 引數
- 形狀
Log_probs:形狀為 或 的 Tensor,其中 ,,而 。輸出的對數化機率(例如,使用
torch.nn.functional.log_softmax()獲得)。Targets:形狀為 或 的 Tensor,其中 ;如果形狀是 ,則 。它表示目標序列。目標序列中的每個元素都是一個類別索引。並且目標索引不能是空格標籤(預設值 = 0)。在 形式中,目標序列將被填充到最長序列的長度,並堆疊在一起。在 形式中,假定目標序列未填充並在一維中連線。
Input_lengths:形狀為 或 的 Tuple 或 Tensor,其中 。它表示輸入的長度(每個長度必須 )。並且為每個序列指定了長度,以便在假設序列已填充到相等長度的情況下實現掩碼。
Target_lengths:形狀為 或 的 Tuple 或 Tensor,其中 。它表示目標的長度。為每個序列指定了長度,以便在假設序列已填充到相等長度的情況下實現掩碼。如果目標形狀是 ,則 target_lengths 實際上是每個目標序列的停止索引 ,使得對於批次中的每個目標,有
target_n = targets[n,0:s_n]。每個長度必須 。如果目標序列作為一維 Tensor 給出,該 Tensor 是各個目標序列的拼接,則 target_lengths 的總和必須等於該 Tensor 的總長度。輸出:如果
reduction是'mean'(預設值)或'sum',則為標量。如果reduction是'none',則如果輸入是批次處理的,形狀為 ;如果輸入未進行批次處理,形狀為 ,其中 。
示例
>>> # Target are to be padded >>> T = 50 # Input sequence length >>> C = 20 # Number of classes (including blank) >>> N = 16 # Batch size >>> S = 30 # Target sequence length of longest target in batch (padding length) >>> S_min = 10 # Minimum target length, for demonstration purposes >>> >>> # Initialize random batch of input vectors, for *size = (T,N,C) >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_() >>> >>> # Initialize random batch of targets (0 = blank, 1:C = classes) >>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long) >>> >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long) >>> target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long) >>> ctc_loss = nn.CTCLoss() >>> loss = ctc_loss(input, target, input_lengths, target_lengths) >>> loss.backward() >>> >>> >>> # Target are to be un-padded >>> T = 50 # Input sequence length >>> C = 20 # Number of classes (including blank) >>> N = 16 # Batch size >>> >>> # Initialize random batch of input vectors, for *size = (T,N,C) >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_() >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long) >>> >>> # Initialize random batch of targets (0 = blank, 1:C = classes) >>> target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long) >>> target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long) >>> ctc_loss = nn.CTCLoss() >>> loss = ctc_loss(input, target, input_lengths, target_lengths) >>> loss.backward() >>> >>> >>> # Target are to be un-padded and unbatched (effectively N=1) >>> T = 50 # Input sequence length >>> C = 20 # Number of classes (including blank) >>> >>> # Initialize random batch of input vectors, for *size = (T,C) >>> input = torch.randn(T, C).log_softmax(1).detach().requires_grad_() >>> input_lengths = torch.tensor(T, dtype=torch.long) >>> >>> # Initialize random batch of targets (0 = blank, 1:C = classes) >>> target_lengths = torch.randint(low=1, high=T, size=(), dtype=torch.long) >>> target = torch.randint(low=1, high=C, size=(target_lengths,), dtype=torch.long) >>> ctc_loss = nn.CTCLoss() >>> loss = ctc_loss(input, target, input_lengths, target_lengths) >>> loss.backward()
- 參考
A. Graves 等人:連線主義時間分類:使用迴圈神經網路標記未分段序列資料 (https://www.cs.toronto.edu/~graves/icml_2006.pdf)
注意
為了使用 CuDNN,必須滿足以下條件:
targets必須採用拼接格式,所有input_lengths必須等於 T。,target_lengths,整型引數的 dtype 必須是torch.int32。常規實現使用(在 PyTorch 中更常見)的 torch.long dtype。
注意
在使用 CUDA 後端和 CuDNN 的某些情況下,此運算元可能會選擇非確定性演算法以提高效能。如果這是不可取的,您可以透過設定
torch.backends.cudnn.deterministic = True來嘗試使操作具有確定性(可能會犧牲效能)。背景資訊請參閱可復現性相關的注意事項。