快捷方式

torch.nn.functional.embedding

torch.nn.functional.embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False)[source][source]

生成一個簡單的查詢表,用於在固定字典和大小中查詢嵌入。

此模組通常用於使用索引檢索詞嵌入。模組的輸入是索引列表和嵌入矩陣,輸出是相應的詞嵌入。

有關更多詳細資訊,請參閱 torch.nn.Embedding

注意

請注意,此函式相對於 weight 中由 padding_idx 指定的行的梯度的解析值預計與數值值不同。

注意

請注意,:class:`torch.nn.Embedding 與此函式的區別在於,它在構建時會將由 padding_idx 指定的 weight 行初始化為全零。

引數
  • input (LongTensor) – 包含嵌入矩陣索引的 Tensor

  • weight (Tensor) – 浮點型的嵌入矩陣,其行數等於最大可能索引 + 1,列數等於嵌入大小

  • padding_idx (int, 可選) – 如果指定,padding_idx 處的條目不參與梯度計算;因此,訓練期間不會更新 padding_idx 處的嵌入向量,即它保持為固定的“填充”。

  • max_norm (float, 可選) – 如果給定,每個範數大於 max_norm 的嵌入向量將被重新歸一化,使其範數等於 max_norm。注意:這將原地修改 weight

  • norm_type (float, 可選) – 用於 max_norm 選項的 p-範數中的 p。預設為 2

  • scale_grad_by_freq (bool, 可選) – 如果給定,這將根據 mini-batch 中詞語頻率的倒數來縮放梯度。預設為 False

  • sparse (bool, 可選) – 如果為 True,則相對於 weight 的梯度將是一個稀疏 tensor。有關稀疏梯度的更多詳細資訊,請參閱 torch.nn.Embedding 下的註釋。

返回型別

Tensor

形狀
  • Input: 任意形狀的 LongTensor,包含要提取的索引

  • Weight: 浮點型的嵌入矩陣,形狀為 (V, embedding_dim),其中 V = 最大索引 + 1,embedding_dim = 嵌入大小

  • Output: (*, embedding_dim),其中 * 是輸入的形狀

示例

>>> # a batch of 2 samples of 4 indices each
>>> input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
>>> # an embedding matrix containing 10 tensors of size 3
>>> embedding_matrix = torch.rand(10, 3)
>>> F.embedding(input, embedding_matrix)
tensor([[[ 0.8490,  0.9625,  0.6753],
         [ 0.9666,  0.7761,  0.6108],
         [ 0.6246,  0.9751,  0.3618],
         [ 0.4161,  0.2419,  0.7383]],

        [[ 0.6246,  0.9751,  0.3618],
         [ 0.0237,  0.7794,  0.0528],
         [ 0.9666,  0.7761,  0.6108],
         [ 0.3385,  0.8612,  0.1867]]])

>>> # example with padding_idx
>>> weights = torch.rand(10, 3)
>>> weights[0, :].zero_()
>>> embedding_matrix = weights
>>> input = torch.tensor([[0, 2, 0, 5]])
>>> F.embedding(input, embedding_matrix, padding_idx=0)
tensor([[[ 0.0000,  0.0000,  0.0000],
         [ 0.5609,  0.5384,  0.8720],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.6262,  0.2438,  0.7471]]])

文件

查閱 PyTorch 的完整開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源