快捷方式

GaussianNLLLoss

class torch.nn.GaussianNLLLoss(*, full=False, eps=1e-06, reduction='mean')[原始碼][原始碼]

高斯負對數似然損失。

目標被視為來自高斯分佈的樣本,其期望和方差由神經網路預測。對於一個被建模為具有高斯分佈的 target 張量,其中期望張量為 input,正方差張量為 var,損失計算如下:

loss=12(log(max(var, eps))+(inputtarget)2max(var, eps))+const.\text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var}, \ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{target}\right)^2} {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.}

其中 eps 用於穩定性。預設情況下,除非 fullTrue,否則損失函式中的常數項將被忽略。如果 var 的大小與 input 不同(由於同方差假設),則為了正確廣播,它必須要麼最後一個維度為 1,要麼維度比 input 少一個(並且所有其他維度的大小相同)。

引數
  • full (bool, optional) – 在損失計算中包含常數項。預設值:False

  • eps (float, optional) – 用於鉗制 var 的值(見下注),以提高穩定性。預設值:1e-6。

  • reduction (str, optional) – 指定要應用於輸出的縮減方式:'none' | 'mean' | 'sum''none':不應用任何縮減,'mean':輸出是所有批次成員損失的平均值,'sum':輸出是所有批次成員損失的總和。預設值:'mean'

形狀
  • 輸入:(N,)(N, *)()(*),其中 * 表示任意數量的附加維度

  • 目標:(N,)(N, *)()(*),形狀與輸入相同,或者形狀與輸入相同但有一個維度等於 1(以允許廣播)

  • 方差:(N,)(N, *)()(*),形狀與輸入相同,或者形狀與輸入相同但有一個維度等於 1,或者形狀比輸入少一個維度(以允許廣播),或者是一個標量值

  • 輸出:如果 reduction'mean'(預設)或 'sum',則為標量。如果 reduction'none',則為 (N,)(N, *),形狀與輸入相同

示例:
>>> loss = nn.GaussianNLLLoss()
>>> input = torch.randn(5, 2, requires_grad=True)
>>> target = torch.randn(5, 2)
>>> var = torch.ones(5, 2, requires_grad=True)  # heteroscedastic
>>> output = loss(input, target, var)
>>> output.backward()
>>> loss = nn.GaussianNLLLoss()
>>> input = torch.randn(5, 2, requires_grad=True)
>>> target = torch.randn(5, 2)
>>> var = torch.ones(5, 1, requires_grad=True)  # homoscedastic
>>> output = loss(input, target, var)
>>> output.backward()

注意

`var` 的鉗制操作對於 autograd 是忽略的,因此梯度不受其影響。

參考

Nix, D. A. and Weigend, A. S., “Estimating the mean and variance of the target probability distribution”, Proceedings of 1994 IEEE International Conference on Neural Networks (ICNN’94), Orlando, FL, USA, 1994, pp. 55-60 vol.1, doi: 10.1109/ICNN.1994.374138.

文件

訪問 PyTorch 全面的開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲取問題解答

檢視資源