多頭注意力¶
- class torch.ao.nn.quantizable.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[source][source]¶
- forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[source][source]¶
- 注意:
更多資訊請參考
forward()
- 引數
query (Tensor) – 將查詢和一組鍵值對對映到輸出。更多詳細資訊請參閱“Attention Is All You Need”論文。
key (Tensor) – 將查詢和一組鍵值對對映到輸出。更多詳細資訊請參閱“Attention Is All You Need”論文。
value (Tensor) – 將查詢和一組鍵值對對映到輸出。更多詳細資訊請參閱“Attention Is All You Need”論文。
key_padding_mask (Optional[Tensor]) – 如果提供,鍵中指定的填充元素將被注意力機制忽略。當給定一個二值掩碼且值為 True 時,注意力層上的相應值將被忽略。
need_weights (bool) – 輸出 attn_output_weights。
attn_mask (Optional[Tensor]) – 用於阻止注意力關注特定位置的二維或三維掩碼。二維掩碼將廣播應用於所有批次,而三維掩碼允許為每個批次的條目指定不同的掩碼。
- 返回型別
- 形狀
輸入
query: 其中 L 是目標序列長度,N 是批次大小,E 是嵌入維度。 如果
batch_first為True。key: ,其中 S 是源序列長度,N 是批次大小,E 是嵌入維度。 如果
batch_first為True。value: 其中 S 是源序列長度,N 是批次大小,E 是嵌入維度。 如果
batch_first為True。key_padding_mask: 其中 N 是批次大小,S 是源序列長度。如果提供了 BoolTensor,值為
True的位置將被忽略,而值為False的位置保持不變。attn_mask: 二維掩碼 其中 L 是目標序列長度,S 是源序列長度。三維掩碼 其中 N 是批次大小,L 是目標序列長度,S 是源序列長度。attn_mask 確保位置 i 能夠關注未被掩碼的位置。如果提供了 BoolTensor,值為
True的位置不允許關注,而值為False的位置保持不變。如果提供了 FloatTensor,它將被新增到注意力權重中。is_causal: 如果指定,將應用因果掩碼作為注意力掩碼。與提供 attn_mask 互斥。預設值:
False。average_attn_weights: 如果為 True,表示返回的
attn_weights應在所有注意力頭之間平均。否則,attn_weights將按注意力頭單獨提供。請注意,此標誌僅在need_weights=True時有效。預設值:True(即在注意力頭之間平均權重)輸出
attn_output: 其中 L 是目標序列長度,N 是批次大小,E 是嵌入維度。 如果
batch_first為True。attn_output_weights: 如果
average_attn_weights=True,返回在注意力頭之間平均的注意力權重,形狀為 ,其中 N 是批次大小,L 是目標序列長度,S 是源序列長度。如果average_attn_weights=False,返回每個注意力頭的注意力權重,形狀為 。