CausalVariant¶
- class torch.nn.attention.bias.CausalVariant(value)[source][source]¶
用於注意力機制中的因果變體列舉。
定義了兩種型別的因果偏置
UPPER_LEFT:表示標準因果注意力機制中的左上三角偏置。構建此偏置的等效 PyTorch 程式碼是
torch.tril(torch.ones(size, dtype=torch.bool))
例如,當 shape=(3,4) 時,具體化的偏置張量將是
[[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0]]
LOWER_RIGHT:表示右下三角偏置,包含的值與矩陣的右下角對齊。
構建此偏置的等效 PyTorch 程式碼是
diagonal_offset = size[1] - size[0] torch.tril( torch.ones(size, dtype=torch.bool), diagonal=diagonal_offset, )
例如,當 shape=(3,4) 時,具體化的偏置張量將是
[[1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]]
請注意,當查詢和鍵/值張量的序列長度相等時,這些變體是等效的,因為此時三角矩陣是方陣。
警告
此列舉為原型,未來可能會發生變化。