BCEWithLogitsLoss¶
- class torch.nn.BCEWithLogitsLoss(weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)[source][source]¶
此損失結合了 Sigmoid 層和 BCELoss,整合在一個類中。透過將操作合併為一個層,利用 log-sum-exp 技巧提高了數值穩定性,因此此版本比簡單地順序使用 Sigmoid 和 BCELoss 更具數值穩定性。
未歸約的損失(即,當
reduction設定為'none'時)可以描述為其中 是批次大小。如果
reduction不是'none'(預設為'mean'),則這用於衡量例如自編碼器中的重構誤差。注意,目標 t[i] 應該是 0 到 1 之間的數。
透過為正例新增權重,可以權衡召回率和精確率。在多標籤分類的情況下,損失可以描述為
其中 是類別號(對於多標籤二分類 ,對於單標籤二分類 ), 是批次中樣本的編號, 是類別 的正例權重。
增加召回率, 增加精確率。
例如,如果一個數據集包含一個類別的 100 個正樣本和 300 個負樣本,則該類別的
pos_weight應等於 。損失函式的作用就如同資料集包含 個正樣本一樣。示例
>>> target = torch.ones([10, 64], dtype=torch.float32) # 64 classes, batch size = 10 >>> output = torch.full([10, 64], 1.5) # A prediction (logit) >>> pos_weight = torch.ones([64]) # All weights are equal to 1 >>> criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight) >>> criterion(output, target) # -log(sigmoid(1.5)) tensor(0.20...)
在上面的例子中,
pos_weight張量中的元素對應於多標籤二分類場景中的 64 個不同類別。pos_weight中的每個元素都用於根據對應類別的負樣本和正樣本之間的不平衡來調整損失函式。這種方法在類別不平衡程度不同的資料集中非常有用,可確保損失計算準確地考慮每個類別的分佈。- 引數
weight (Tensor, 可選) – 對每個批次元素損失的手動重新縮放權重。如果給定,必須是大小為 nbatch 的 Tensor。
size_average (bool, 可選) – 已廢棄(請參閱
reduction)。預設情況下,損失會在批次中的每個損失元素上進行平均。注意,對於某些損失,每個樣本有多個元素。如果欄位size_average設定為False,則損失將改為對每個小批次求和。當reduce為False時忽略。預設值:Truereduce (bool, 可選) – 已廢棄(請參閱
reduction)。預設情況下,損失會根據size_average對每個小批次的觀測值進行平均或求和。當reduce為False時,將返回每個批次元素的損失,並忽略size_average。預設值:Truereduction (str, 可選) – 指定要應用於輸出的歸約方法:
'none'|'mean'|'sum'。'none':不應用歸約;'mean':輸出的總和將除以輸出中的元素數量;'sum':輸出將被求和。注意:size_average和reduce正在廢棄中,同時指定這兩個引數中的任何一個都會覆蓋reduction。預設值:'mean'pos_weight (Tensor, 可選) – 用於與目標張量進行廣播運算的正樣本權重。它必須是一個在類別維度上與類別數量大小相等的張量。請密切注意 PyTorch 的廣播語義,以實現所需的操作。對於大小為 [B, C, H, W] 的目標張量(其中 B 為批次大小),大小為 [B, C, H, W] 的 pos_weight 會對批次的每個元素應用不同的 pos_weights,而大小為 [C, H, W] 的 pos_weight 則會在整個批次中應用相同的 pos_weights。要對 2D 多類別目標 [C, H, W] 在所有空間維度上應用相同的正樣本權重,請使用:[C, 1, 1]。預設值:
None
- 形狀
輸入:,其中 表示任意數量的維度。
目標:,與輸入形狀相同。
輸出:標量。如果
reduction為'none',則為 ,與輸入形狀相同。
示例
>>> loss = nn.BCEWithLogitsLoss() >>> input = torch.randn(3, requires_grad=True) >>> target = torch.empty(3).random_(2) >>> output = loss(input, target) >>> output.backward()