快捷方式

CausalVariant

class torch.nn.attention.bias.CausalVariant(value)[source][source]

用於注意力機制中的因果變體列舉。

定義了兩種型別的因果偏置

UPPER_LEFT:表示標準因果注意力機制中的左上三角偏置。構建此偏置的等效 PyTorch 程式碼是

torch.tril(torch.ones(size, dtype=torch.bool))

例如,當 shape=(3,4) 時,具體化的偏置張量將是

[[1, 0, 0, 0],
 [1, 1, 0, 0],
 [1, 1, 1, 0]]

LOWER_RIGHT:表示右下三角偏置,包含的值與矩陣的右下角對齊。

構建此偏置的等效 PyTorch 程式碼是

diagonal_offset = size[1] - size[0]
torch.tril(
    torch.ones(size, dtype=torch.bool),
    diagonal=diagonal_offset,
)

例如,當 shape=(3,4) 時,具體化的偏置張量將是

[[1, 1, 0, 0],
 [1, 1, 1, 0],
 [1, 1, 1, 1]]

請注意,當查詢和鍵/值張量的序列長度相等時,這些變體是等效的,因為此時三角矩陣是方陣。

警告

此列舉為原型,未來可能會發生變化。

文件

查閱 PyTorch 全面的開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並解答您的疑問

檢視資源