GaussianNLLLoss¶
- class torch.nn.GaussianNLLLoss(*, full=False, eps=1e-06, reduction='mean')[原始碼][原始碼]¶
高斯負對數似然損失。
目標被視為來自高斯分佈的樣本,其期望和方差由神經網路預測。對於一個被建模為具有高斯分佈的
target張量,其中期望張量為input,正方差張量為var,損失計算如下:其中
eps用於穩定性。預設情況下,除非full為True,否則損失函式中的常數項將被忽略。如果var的大小與input不同(由於同方差假設),則為了正確廣播,它必須要麼最後一個維度為 1,要麼維度比input少一個(並且所有其他維度的大小相同)。- 引數
- 形狀
輸入: 或 ,其中 表示任意數量的附加維度
目標: 或 ,形狀與輸入相同,或者形狀與輸入相同但有一個維度等於 1(以允許廣播)
方差: 或 ,形狀與輸入相同,或者形狀與輸入相同但有一個維度等於 1,或者形狀比輸入少一個維度(以允許廣播),或者是一個標量值
輸出:如果
reduction為'mean'(預設)或'sum',則為標量。如果reduction為'none',則為 ,形狀與輸入相同
- 示例:
>>> loss = nn.GaussianNLLLoss() >>> input = torch.randn(5, 2, requires_grad=True) >>> target = torch.randn(5, 2) >>> var = torch.ones(5, 2, requires_grad=True) # heteroscedastic >>> output = loss(input, target, var) >>> output.backward()
>>> loss = nn.GaussianNLLLoss() >>> input = torch.randn(5, 2, requires_grad=True) >>> target = torch.randn(5, 2) >>> var = torch.ones(5, 1, requires_grad=True) # homoscedastic >>> output = loss(input, target, var) >>> output.backward()
注意
`var` 的鉗制操作對於 autograd 是忽略的,因此梯度不受其影響。
- 參考
Nix, D. A. and Weigend, A. S., “Estimating the mean and variance of the target probability distribution”, Proceedings of 1994 IEEE International Conference on Neural Networks (ICNN’94), Orlando, FL, USA, 1994, pp. 55-60 vol.1, doi: 10.1109/ICNN.1994.374138.