快捷方式

CrossEntropyLoss

類別 torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)[來源][來源]

此準則計算輸入 logits 和目標之間的交叉熵損失。

當訓練一個有 C 個類別的分類問題時,它很有用。如果提供,可選引數 weight 應該是一個 1D Tensor,為每個類別分配權重。這在訓練集不平衡時特別有用。

input 期望包含每個類別的未歸一化 logits(通常 不需要 為正或總和為 1)。對於非批處理輸入,input 必須是大小為 (C)(C)Tensor;對於批處理輸入,大小為 (minibatch,C)(minibatch, C);或者對於 K 維情況,大小為 (minibatch,C,d1,d2,...,dK)(minibatch, C, d_1, d_2, ..., d_K),其中 K1K \geq 1。最後一種情況對於更高維度的輸入很有用,例如計算 2D 影像的逐畫素交叉熵損失。

此準則期望的 target 應包含以下兩種形式之一:

  • 在範圍 [0,C)[0, C) 內的類別索引,其中 CC 是類別數量;如果指定了 ignore_index,該損失函式也會接受此類別索引(此索引不一定在類別範圍內)。在這種情況下,未進行歸約(即 reduction 設定為 'none')的損失可以描述為:

    (x,y)=L={l1,,lN},ln=wynlogexp(xn,yn)c=1Cexp(xn,c)1{ynignore_index}\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = - w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C \exp(x_{n,c})} \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}

    其中 xx 是輸入,yy 是目標,ww 是權重,CC 是類別數量,NN 涵蓋 minibatch 維度以及 K 維情況下的 d1,...,dkd_1, ..., d_k。如果 reduction 不是 'none'(預設為 'mean'),則

    (x,y)={n=1N1n=1Nwyn1{ynignore_index}ln,if reduction=‘mean’;n=1Nln,if reduction=‘sum’.\ell(x, y) = \begin{cases} \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}} l_n, & \text{if reduction} = \text{`mean';}\\ \sum_{n=1}^N l_n, & \text{if reduction} = \text{`sum'.} \end{cases}

    注意,這種情況等價於對輸入應用 LogSoftmax,然後應用 NLLLoss

  • 每類別的機率;在要求每個 mini-batch 專案的標籤超出單個類別時很有用,例如用於混合標籤(blended labels)、標籤平滑(label smoothing)等。在這種情況下,未歸約(即 reduction 設定為 'none')的損失可以描述為

    (x,y)=L={l1,,lN},ln=c=1Cwclogexp(xn,c)i=1Cexp(xn,i)yn,c\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = - \sum_{c=1}^C w_c \log \frac{\exp(x_{n,c})}{\sum_{i=1}^C \exp(x_{n,i})} y_{n,c}

    其中 xx 是輸入,yy 是目標,ww 是權重,CC 是類別數量,NN 涵蓋 minibatch 維度以及 K 維情況下的 d1,...,dkd_1, ..., d_k。如果 reduction 不是 'none'(預設為 'mean'),則

    (x,y)={n=1NlnN,if reduction=‘mean’;n=1Nln,if reduction=‘sum’.\ell(x, y) = \begin{cases} \frac{\sum_{n=1}^N l_n}{N}, & \text{if reduction} = \text{`mean';}\\ \sum_{n=1}^N l_n, & \text{if reduction} = \text{`sum'.} \end{cases}

注意

請注意,當 target 包含類別索引時,此標準的效能通常更佳,因為這允許進行最佳化的計算。僅當每個 mini-batch 專案的單個類別標籤限制性太大時,才考慮將 target 提供為類別機率。

引數
  • weight (Tensor, optional) – 給每個類別手動分配的重縮放權重。如果給出,必須是大小為 C 的 Tensor,且資料型別為浮點型。

  • size_average (bool, optional) – 已廢棄(參見 reduction)。預設情況下,損失在 batch 中的每個損失元素上求平均。請注意,對於某些損失,每個樣本有多個元素。如果欄位 size_average 設定為 False,則改為對每個 mini-batch 的損失求和。當 reduceFalse 時忽略。預設值:True

  • ignore_index (int, optional) – 指定一個被忽略且不計入輸入梯度的目標值。當 size_averageTrue 時,損失在非忽略目標上求平均。請注意,ignore_index 僅適用於 target 包含類別索引的情況。

  • reduce (bool, optional) – 已廢棄(參見 reduction)。預設情況下,根據 size_average 的設定,損失在每個 mini-batch 的觀測值上求平均或求和。當 reduceFalse 時,轉而返回每個 batch 元素的損失,並忽略 size_average。預設值:True

  • reduction (str, optional) – 指定應用於輸出的歸約方式:'none' | 'mean' | 'sum''none':不應用歸約,'mean':對輸出求加權平均,'sum':對輸出求和。注意:size_averagereduce 正在被廢棄,同時,指定這兩個引數中的任何一個都將覆蓋 reduction 的設定。預設值:'mean'

  • label_smoothing (float, optional) – 一個在 [0.0, 1.0] 範圍內的浮點數。指定計算損失時平滑的數量,其中 0.0 表示不進行平滑。目標變成原始真實標籤和均勻分佈的混合,如論文 Rethinking the Inception Architecture for Computer Vision 中所述。預設值:0.00.0

形狀
  • 輸入: 形狀為 (C)(C)(N,C)(N, C)(N,C,d1,d2,...,dK)(N, C, d_1, d_2, ..., d_K),其中 K1K \geq 1 表示 K 維損失的情況。

  • 目標:如果包含類別索引,形狀為 ()()(N)(N)(N,d1,d2,...,dK)(N, d_1, d_2, ..., d_K),在 K 維損失的情況下,其中 K1K \geq 1 且每個值應介於 [0,C)[0, C)。使用類別索引時,目標資料型別必須為 long。如果包含類別機率,目標必須與輸入形狀相同,並且每個值應介於 [0,1][0, 1]。這意味著使用類別機率時,目標資料型別必須為 float。

  • 輸出:如果 reduction 為 ‘none’,形狀為 ()()(N)(N)(N,d1,d2,...,dK)(N, d_1, d_2, ..., d_K),在 K 維損失的情況下,其中 K1K \geq 1,取決於輸入的形狀。否則,為標量。

其中

C=類別數N=批次大小\begin{aligned} C ={} & \text{number of classes} \\ N ={} & \text{batch size} \\ \end{aligned}

示例

>>> # Example of target with class indices
>>> loss = nn.CrossEntropyLoss()
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(5)
>>> output = loss(input, target)
>>> output.backward()
>>>
>>> # Example of target with class probabilities
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.randn(3, 5).softmax(dim=1)
>>> output = loss(input, target)
>>> output.backward()

文件

訪問 PyTorch 的完整開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲取問題解答

檢視資源