torch.einsum¶
- torch.einsum(equation, *operands) 張量[source][source]¶
根據愛因斯坦求和約定表示法,沿指定維度對輸入
operands的元素乘積求和。Einsum 允許使用基於愛因斯坦求和約定的簡寫格式(由
equation指定)來計算許多常見的多維線性代數陣列操作。這種格式的詳細資訊將在下面描述,但其總體思想是為輸入operands的每個維度標記一個下標,並定義哪些下標屬於輸出。然後,透過對operands元素沿著不屬於輸出的下標維度求積再求和,來計算輸出。例如,矩陣乘法可以使用 einsum 表示為 torch.einsum(“ij,jk->ik”, A, B)。這裡,j 是求和下標,i 和 k 是輸出下標(關於原因的更多詳細資訊請參見下面的章節)。方程表示式
equation字串指定了輸入operands的每個維度所對應的下標([a-zA-Z] 中的字母),順序與維度順序一致,使用逗號 (‘,’) 分隔每個運算元的下標,例如 ‘ij,jk’ 指定了兩個 2D 運算元的下標。標記有相同下標的維度必須是可廣播的,也就是說,它們的尺寸必須匹配或為 1。例外情況是,如果一個下標在同一個輸入運算元中重複出現,則此運算元中標記該下標的維度尺寸必須匹配,並且該運算元將沿這些維度被其對角線替換。在equation中只出現一次的下標將成為輸出的一部分,並按字母升序排列。輸出是透過將輸入operands元素逐個相乘(根據下標對齊維度),然後對不屬於輸出的下標維度求和計算得出的。此外,可以透過在方程末尾新增箭頭 (‘->’) 並跟隨輸出下標來顯式定義輸出下標。例如,以下方程計算矩陣乘積的轉置:‘ij,jk->ki’。輸出下標必須至少在某個輸入運算元中出現一次,且在輸出中最多出現一次。
可以使用省略號 (‘…’) 代替下標,以廣播省略號所覆蓋的維度。每個輸入運算元最多可以包含一個省略號,它將覆蓋未被下標覆蓋的維度,例如,對於一個 5 維的輸入運算元,方程 ‘ab…c’ 中的省略號覆蓋第三和第四維。省略號在不同運算元中不必覆蓋相同數量的維度,但省略號的“形狀”(它們覆蓋的維度尺寸)必須能夠一起廣播。如果未使用箭頭 (‘->’) 表示法顯式定義輸出,則省略號將首先出現在輸出中(最左邊的維度),然後才是輸入運算元中只出現一次的下標標籤。例如,以下方程實現了批次矩陣乘法 ‘…ij,…jk’。
最後幾點注意事項:方程中可以在不同元素(下標、省略號、箭頭和逗號)之間包含空格,但類似 ‘…’ 的寫法是無效的。空字串 ‘’ 對於標量運算元是有效的。
注意
torch.einsum對省略號 (‘…’) 的處理與 NumPy 不同,它允許對省略號覆蓋的維度進行求和,也就是說,省略號不強制要求成為輸出的一部分。注意
請安裝 opt-einsum (https://optimized-einsum.readthedocs.io/en/stable/) 以獲得性能更好的 einsum。您可以在安裝 torch 時一起安裝:pip install torch[opt-einsum],或者單獨安裝:pip install opt-einsum。
如果 opt-einsum 可用,此函式將透過我們的 opt_einsum 後端
torch.backends.opt_einsum(我知道 _ 和 - 容易混淆)最佳化收縮順序,從而自動加速計算和/或減少記憶體消耗。當輸入至少有三個時才會進行此最佳化,否則順序無關緊要。請注意,找到最優路徑是 NP 難問題,因此 opt-einsum 依賴於不同的啟發式方法來獲得接近最優的結果。如果 opt-einsum 不可用,預設順序是從左到右收縮。要繞過此預設行為,新增以下程式碼以停用 opt_einsum 並跳過路徑計算:
torch.backends.opt_einsum.enabled = False要指定 opt_einsum 計算收縮路徑的策略,新增以下程式碼:
torch.backends.opt_einsum.strategy = 'auto'。預設策略是 ‘auto’,我們也支援 ‘greedy’ 和 ‘optimal’。請注意,‘optimal’ 策略的執行時間是輸入數量的階乘!有關更多詳細資訊,請參閱 opt-einsum 文件 (https://optimized-einsum.readthedocs.io/en/stable/path_finding.html)。注意
自 PyTorch 1.10 起,
torch.einsum()也支援子列表格式(參見下面的示例)。在這種格式中,每個運算元的下標由子列表([0, 52) 範圍內的整數列表)指定。這些子列表跟在其運算元後面,並且可以在輸入末尾出現一個額外的子列表來指定輸出的下標,例如 torch.einsum(op1, sublist1, op2, sublist2, …, [subslist_out])。Python 的 Ellipsis 物件可以在子列表中提供,以實現上面方程表示式部分描述的廣播功能。示例
>>> # trace >>> torch.einsum('ii', torch.randn(4, 4)) tensor(-1.2104) >>> # diagonal >>> torch.einsum('ii->i', torch.randn(4, 4)) tensor([-0.1034, 0.7952, -0.2433, 0.4545]) >>> # outer product >>> x = torch.randn(5) >>> y = torch.randn(4) >>> torch.einsum('i,j->ij', x, y) tensor([[ 0.1156, -0.2897, -0.3918, 0.4963], [-0.3744, 0.9381, 1.2685, -1.6070], [ 0.7208, -1.8058, -2.4419, 3.0936], [ 0.1713, -0.4291, -0.5802, 0.7350], [ 0.5704, -1.4290, -1.9323, 2.4480]]) >>> # batch matrix multiplication >>> As = torch.randn(3, 2, 5) >>> Bs = torch.randn(3, 5, 4) >>> torch.einsum('bij,bjk->bik', As, Bs) tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], [-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354], [-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112], [ 0.3728, -2.1131, 0.0921, 0.8305]]]) >>> # with sublist format and ellipsis >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2]) tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], [-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354], [-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112], [ 0.3728, -2.1131, 0.0921, 0.8305]]]) >>> # batch permute >>> A = torch.randn(2, 3, 4, 5) >>> torch.einsum('...ij->...ji', A).shape torch.Size([2, 3, 5, 4]) >>> # equivalent to torch.nn.functional.bilinear >>> A = torch.randn(3, 5, 4) >>> l = torch.randn(2, 5) >>> r = torch.randn(2, 4) >>> torch.einsum('bn,anm,bm->ba', l, A, r) tensor([[-0.3430, -5.2405, 0.4494], [ 0.3311, 5.5201, -3.0356]])