快捷方式

多頭注意力¶

class torch.ao.nn.quantizable.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[source][source]
dequantize()[source][source]

將量化的 MHA 轉換回浮點格式的實用工具。

這樣做的動機是,將量化版本中使用的權重格式轉換回浮點格式並不容易。

forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[source][source]
注意:

更多資訊請參考 forward()

引數
  • query (Tensor) – 將查詢和一組鍵值對對映到輸出。更多詳細資訊請參閱“Attention Is All You Need”論文。

  • key (Tensor) – 將查詢和一組鍵值對對映到輸出。更多詳細資訊請參閱“Attention Is All You Need”論文。

  • value (Tensor) – 將查詢和一組鍵值對對映到輸出。更多詳細資訊請參閱“Attention Is All You Need”論文。

  • key_padding_mask (Optional[Tensor]) – 如果提供,鍵中指定的填充元素將被注意力機制忽略。當給定一個二值掩碼且值為 True 時,注意力層上的相應值將被忽略。

  • need_weights (bool) – 輸出 attn_output_weights。

  • attn_mask (Optional[Tensor]) – 用於阻止注意力關注特定位置的二維或三維掩碼。二維掩碼將廣播應用於所有批次,而三維掩碼允許為每個批次的條目指定不同的掩碼。

返回型別

tuple[torch.Tensor, Optional[torch.Tensor]]

形狀
  • 輸入

  • query: (L,N,E)(L, N, E) 其中 L 是目標序列長度,N 是批次大小,E 是嵌入維度。(N,L,E)(N, L, E) 如果 batch_firstTrue

  • key: (S,N,E)(S, N, E),其中 S 是源序列長度,N 是批次大小,E 是嵌入維度。(N,S,E)(N, S, E) 如果 batch_firstTrue

  • value: (S,N,E)(S, N, E) 其中 S 是源序列長度,N 是批次大小,E 是嵌入維度。(N,S,E)(N, S, E) 如果 batch_firstTrue

  • key_padding_mask: (N,S)(N, S) 其中 N 是批次大小,S 是源序列長度。如果提供了 BoolTensor,值為 True 的位置將被忽略,而值為 False 的位置保持不變。

  • attn_mask: 二維掩碼 (L,S)(L, S) 其中 L 是目標序列長度,S 是源序列長度。三維掩碼 (Nnumheads,L,S)(N*num_heads, L, S) 其中 N 是批次大小,L 是目標序列長度,S 是源序列長度。attn_mask 確保位置 i 能夠關注未被掩碼的位置。如果提供了 BoolTensor,值為 True 的位置不允許關注,而值為 False 的位置保持不變。如果提供了 FloatTensor,它將被新增到注意力權重中。

  • is_causal: 如果指定,將應用因果掩碼作為注意力掩碼。與提供 attn_mask 互斥。預設值:False

  • average_attn_weights: 如果為 True,表示返回的 attn_weights 應在所有注意力頭之間平均。否則,attn_weights 將按注意力頭單獨提供。請注意,此標誌僅在 need_weights=True 時有效。預設值:True(即在注意力頭之間平均權重)

  • 輸出

  • attn_output: (L,N,E)(L, N, E) 其中 L 是目標序列長度,N 是批次大小,E 是嵌入維度。(N,L,E)(N, L, E) 如果 batch_firstTrue

  • attn_output_weights: 如果 average_attn_weights=True,返回在注意力頭之間平均的注意力權重,形狀為 (N,L,S)(N, L, S),其中 N 是批次大小,L 是目標序列長度,S 是源序列長度。如果 average_attn_weights=False,返回每個注意力頭的注意力權重,形狀為 (N,numheads,L,S)(N, num_heads, L, S)

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發者資源並獲得問題解答

檢視資源