快捷方式

torch.nn.functional.gumbel_softmax

torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1)[source][source]

從 Gumbel-Softmax 分佈(連結 1 連結 2)中取樣,並可選擇進行離散化。

引數
  • logits (Tensor) – […, num_features] 未歸一化的對數機率

  • tau (float) – 非負標量溫度

  • hard (bool) – 如果為 True,返回的樣本將被離散化為 one-hot 向量,但在 autograd 中會按軟樣本進行微分

  • dim (int) – 將計算 softmax 的維度。預設為 -1。

返回

從 Gumbel-Softmax 分佈中取樣的張量,形狀與 logits 相同。如果 hard=True,返回的樣本將是 one-hot 向量;否則,它們將是在 dim 維度上求和為 1 的機率分佈。

返回型別

Tensor

注意

此函數出於歷史原因保留,未來可能從 nn.Functional 中移除。

注意

hard 的主要技巧是執行 y_hard - y_soft.detach() + y_soft

這實現了兩件事: - 使輸出值精確為 one-hot(因為我們先加後減了 y_soft 的值) - 使梯度等於 y_soft 的梯度(因為我們去除了所有其他梯度)

示例:
>>> logits = torch.randn(20, 32)
>>> # Sample soft categorical using reparametrization trick:
>>> F.gumbel_softmax(logits, tau=1, hard=False)
>>> # Sample hard categorical using "Straight-through" trick:
>>> F.gumbel_softmax(logits, tau=1, hard=True)

文件

訪問 PyTorch 的綜合開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並解答您的問題

檢視資源