快捷方式

BCELoss

class torch.nn.BCELoss(weight=None, size_average=None, reduce=None, reduction='mean')[source][source]

建立一個衡量目標和輸入機率之間的二元交叉熵(Binary Cross Entropy)的準則。

未降維(即 reduction 設定為 'none')的損失可以描述為

(x,y)=L={l1,,lN},ln=wn[ynlogxn+(1yn)log(1xn)],\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right],

其中 NN 是批處理大小。如果 reduction 不是 'none'(預設值為 'mean'),則

(x,y)={mean(L),if reduction=‘mean’;sum(L),if reduction=‘sum’.\ell(x, y) = \begin{cases} \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} \end{cases}

這用於衡量例如自編碼器中的重構誤差。請注意,目標 yy 的值應介於 0 和 1 之間。

請注意,如果 xnx_n 的值為 0 或 1,則上述損失公式中的對數項之一在數學上是未定義的。PyTorch 選擇將 log(0)\log (0) 設定為 -\infty,因為 limx0log(x)=\lim_{x\to 0} \log (x) = -\infty。然而,損失公式中的無限項有幾個不希望出現的原因。

首先,如果 yn=0y_n = 0(1yn)=0(1 - y_n) = 0,我們將得到 0 乘以無窮大。其次,如果損失值為無限大,那麼我們的梯度中也會有一個無限項,因為 limx0ddxlog(x)=\lim_{x\to 0} \frac{d}{dx} \log (x) = \infty。這將導致 BCELoss 的 backward 方法相對於 xnx_n 變成非線性,並且將其用於諸如線性迴歸之類的任務將不那麼直接。

我們的解決方案是 BCELoss 將其對數函式的輸出鉗位到大於等於 -100。這樣,我們總能得到有限的損失值和線性的 backward 方法。

引數
  • weight (Tensor, optional) – 手動調整每個批處理元素損失權重的引數。如果給定,必須是一個大小為 nbatch 的 Tensor。

  • size_average (bool, optional) – 已棄用(請參閱 reduction)。預設情況下,損失在批處理中的每個損失元素上進行平均。請注意,對於某些損失,每個樣本有多個元素。如果欄位 size_average 設定為 False,則損失將改為對每個 minibatch 進行求和。當 reduceFalse 時忽略此引數。預設值:True

  • reduce (bool, optional) – 已棄用(請參閱 reduction)。預設情況下,損失會根據 size_average 對每個 minibatch 的觀察值進行平均或求和。當 reduceFalse 時,將返回每個批處理元素的損失,並忽略 size_average。預設值:True

  • reduction (str, optional) – 指定應用於輸出的歸約方式:'none' | 'mean' | 'sum''none': 不應用歸約;'mean': 輸出的總和將除以輸出中的元素數量;'sum': 輸出將被求和。注意:size_averagereduce 正在被棄用,在此期間,指定這兩個引數中的任何一個都將覆蓋 reduction。預設值: 'mean'

形狀
  • 輸入: ()(*),其中 * 表示任意維數。

  • 目標: ()(*),與輸入形狀相同。

  • 輸出: 標量。如果 reduction'none',則為 ()(*),與輸入形狀相同。

示例

>>> m = nn.Sigmoid()
>>> loss = nn.BCELoss()
>>> input = torch.randn(3, 2, requires_grad=True)
>>> target = torch.rand(3, 2, requires_grad=False)
>>> output = loss(m(input), target)
>>> output.backward()

文件

查閱 PyTorch 的綜合開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源