CrossEntropyLoss¶
- 類別 torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)[來源][來源]¶
此準則計算輸入 logits 和目標之間的交叉熵損失。
當訓練一個有 C 個類別的分類問題時,它很有用。如果提供,可選引數
weight應該是一個 1D Tensor,為每個類別分配權重。這在訓練集不平衡時特別有用。input 期望包含每個類別的未歸一化 logits(通常 不需要 為正或總和為 1)。對於非批處理輸入,input 必須是大小為 的 Tensor;對於批處理輸入,大小為 ;或者對於 K 維情況,大小為 ,其中 。最後一種情況對於更高維度的輸入很有用,例如計算 2D 影像的逐畫素交叉熵損失。
此準則期望的 target 應包含以下兩種形式之一:
在範圍 內的類別索引,其中 是類別數量;如果指定了 ignore_index,該損失函式也會接受此類別索引(此索引不一定在類別範圍內)。在這種情況下,未進行歸約(即
reduction設定為'none')的損失可以描述為:其中 是輸入, 是目標, 是權重, 是類別數量, 涵蓋 minibatch 維度以及 K 維情況下的 。如果
reduction不是'none'(預設為'mean'),則注意,這種情況等價於對輸入應用
LogSoftmax,然後應用NLLLoss。每類別的機率;在要求每個 mini-batch 專案的標籤超出單個類別時很有用,例如用於混合標籤(blended labels)、標籤平滑(label smoothing)等。在這種情況下,未歸約(即
reduction設定為'none')的損失可以描述為其中 是輸入, 是目標, 是權重, 是類別數量, 涵蓋 minibatch 維度以及 K 維情況下的 。如果
reduction不是'none'(預設為'mean'),則
注意
請注意,當 target 包含類別索引時,此標準的效能通常更佳,因為這允許進行最佳化的計算。僅當每個 mini-batch 專案的單個類別標籤限制性太大時,才考慮將 target 提供為類別機率。
- 引數
weight (Tensor, optional) – 給每個類別手動分配的重縮放權重。如果給出,必須是大小為 C 的 Tensor,且資料型別為浮點型。
size_average (bool, optional) – 已廢棄(參見
reduction)。預設情況下,損失在 batch 中的每個損失元素上求平均。請注意,對於某些損失,每個樣本有多個元素。如果欄位size_average設定為False,則改為對每個 mini-batch 的損失求和。當reduce為False時忽略。預設值:Trueignore_index (int, optional) – 指定一個被忽略且不計入輸入梯度的目標值。當
size_average為True時,損失在非忽略目標上求平均。請注意,ignore_index僅適用於 target 包含類別索引的情況。reduce (bool, optional) – 已廢棄(參見
reduction)。預設情況下,根據size_average的設定,損失在每個 mini-batch 的觀測值上求平均或求和。當reduce為False時,轉而返回每個 batch 元素的損失,並忽略size_average。預設值:Truereduction (str, optional) – 指定應用於輸出的歸約方式:
'none'|'mean'|'sum'。'none':不應用歸約,'mean':對輸出求加權平均,'sum':對輸出求和。注意:size_average和reduce正在被廢棄,同時,指定這兩個引數中的任何一個都將覆蓋reduction的設定。預設值:'mean'label_smoothing (float, optional) – 一個在 [0.0, 1.0] 範圍內的浮點數。指定計算損失時平滑的數量,其中 0.0 表示不進行平滑。目標變成原始真實標籤和均勻分佈的混合,如論文 Rethinking the Inception Architecture for Computer Vision 中所述。預設值:。
- 形狀
輸入: 形狀為 、 或 ,其中 表示 K 維損失的情況。
目標:如果包含類別索引,形狀為 、 或 ,在 K 維損失的情況下,其中 且每個值應介於 。使用類別索引時,目標資料型別必須為 long。如果包含類別機率,目標必須與輸入形狀相同,並且每個值應介於 。這意味著使用類別機率時,目標資料型別必須為 float。
輸出:如果 reduction 為 ‘none’,形狀為 、 或 ,在 K 維損失的情況下,其中 ,取決於輸入的形狀。否則,為標量。
其中
示例
>>> # Example of target with class indices >>> loss = nn.CrossEntropyLoss() >>> input = torch.randn(3, 5, requires_grad=True) >>> target = torch.empty(3, dtype=torch.long).random_(5) >>> output = loss(input, target) >>> output.backward() >>> >>> # Example of target with class probabilities >>> input = torch.randn(3, 5, requires_grad=True) >>> target = torch.randn(3, 5).softmax(dim=1) >>> output = loss(input, target) >>> output.backward()