快捷方式

torch.einsum

torch.einsum(equation, *operands) 張量[source][source]

根據愛因斯坦求和約定表示法,沿指定維度對輸入 operands 的元素乘積求和。

Einsum 允許使用基於愛因斯坦求和約定的簡寫格式(由 equation 指定)來計算許多常見的多維線性代數陣列操作。這種格式的詳細資訊將在下面描述,但其總體思想是為輸入 operands 的每個維度標記一個下標,並定義哪些下標屬於輸出。然後,透過對 operands 元素沿著不屬於輸出的下標維度求積再求和,來計算輸出。例如,矩陣乘法可以使用 einsum 表示為 torch.einsum(“ij,jk->ik”, A, B)。這裡,j 是求和下標,i 和 k 是輸出下標(關於原因的更多詳細資訊請參見下面的章節)。

方程表示式

equation 字串指定了輸入 operands 的每個維度所對應的下標([a-zA-Z] 中的字母),順序與維度順序一致,使用逗號 (‘,’) 分隔每個運算元的下標,例如 ‘ij,jk’ 指定了兩個 2D 運算元的下標。標記有相同下標的維度必須是可廣播的,也就是說,它們的尺寸必須匹配或為 1。例外情況是,如果一個下標在同一個輸入運算元中重複出現,則此運算元中標記該下標的維度尺寸必須匹配,並且該運算元將沿這些維度被其對角線替換。在 equation 中只出現一次的下標將成為輸出的一部分,並按字母升序排列。輸出是透過將輸入 operands 元素逐個相乘(根據下標對齊維度),然後對不屬於輸出的下標維度求和計算得出的。

此外,可以透過在方程末尾新增箭頭 (‘->’) 並跟隨輸出下標來顯式定義輸出下標。例如,以下方程計算矩陣乘積的轉置:‘ij,jk->ki’。輸出下標必須至少在某個輸入運算元中出現一次,且在輸出中最多出現一次。

可以使用省略號 (‘…’) 代替下標,以廣播省略號所覆蓋的維度。每個輸入運算元最多可以包含一個省略號,它將覆蓋未被下標覆蓋的維度,例如,對於一個 5 維的輸入運算元,方程 ‘ab…c’ 中的省略號覆蓋第三和第四維。省略號在不同運算元中不必覆蓋相同數量的維度,但省略號的“形狀”(它們覆蓋的維度尺寸)必須能夠一起廣播。如果未使用箭頭 (‘->’) 表示法顯式定義輸出,則省略號將首先出現在輸出中(最左邊的維度),然後才是輸入運算元中只出現一次的下標標籤。例如,以下方程實現了批次矩陣乘法 ‘…ij,…jk’

最後幾點注意事項:方程中可以在不同元素(下標、省略號、箭頭和逗號)之間包含空格,但類似 ‘…’ 的寫法是無效的。空字串 ‘’ 對於標量運算元是有效的。

注意

torch.einsum 對省略號 (‘…’) 的處理與 NumPy 不同,它允許對省略號覆蓋的維度進行求和,也就是說,省略號不強制要求成為輸出的一部分。

注意

請安裝 opt-einsum (https://optimized-einsum.readthedocs.io/en/stable/) 以獲得性能更好的 einsum。您可以在安裝 torch 時一起安裝:pip install torch[opt-einsum],或者單獨安裝:pip install opt-einsum

如果 opt-einsum 可用,此函式將透過我們的 opt_einsum 後端 torch.backends.opt_einsum(我知道 _ 和 - 容易混淆)最佳化收縮順序,從而自動加速計算和/或減少記憶體消耗。當輸入至少有三個時才會進行此最佳化,否則順序無關緊要。請注意,找到最優路徑是 NP 難問題,因此 opt-einsum 依賴於不同的啟發式方法來獲得接近最優的結果。如果 opt-einsum 不可用,預設順序是從左到右收縮。

要繞過此預設行為,新增以下程式碼以停用 opt_einsum 並跳過路徑計算:torch.backends.opt_einsum.enabled = False

要指定 opt_einsum 計算收縮路徑的策略,新增以下程式碼:torch.backends.opt_einsum.strategy = 'auto'。預設策略是 ‘auto’,我們也支援 ‘greedy’ 和 ‘optimal’。請注意,‘optimal’ 策略的執行時間是輸入數量的階乘!有關更多詳細資訊,請參閱 opt-einsum 文件 (https://optimized-einsum.readthedocs.io/en/stable/path_finding.html)。

注意

自 PyTorch 1.10 起,torch.einsum() 也支援子列表格式(參見下面的示例)。在這種格式中,每個運算元的下標由子列表([0, 52) 範圍內的整數列表)指定。這些子列表跟在其運算元後面,並且可以在輸入末尾出現一個額外的子列表來指定輸出的下標,例如 torch.einsum(op1, sublist1, op2, sublist2, …, [subslist_out])。Python 的 Ellipsis 物件可以在子列表中提供,以實現上面方程表示式部分描述的廣播功能。

引數
  • equation (str) – 愛因斯坦求和的下標表達式。

  • operands (List[張量]) – 用於計算愛因斯坦求和的張量列表。

返回型別

張量

示例

>>> # trace
>>> torch.einsum('ii', torch.randn(4, 4))
tensor(-1.2104)

>>> # diagonal
>>> torch.einsum('ii->i', torch.randn(4, 4))
tensor([-0.1034,  0.7952, -0.2433,  0.4545])

>>> # outer product
>>> x = torch.randn(5)
>>> y = torch.randn(4)
>>> torch.einsum('i,j->ij', x, y)
tensor([[ 0.1156, -0.2897, -0.3918,  0.4963],
        [-0.3744,  0.9381,  1.2685, -1.6070],
        [ 0.7208, -1.8058, -2.4419,  3.0936],
        [ 0.1713, -0.4291, -0.5802,  0.7350],
        [ 0.5704, -1.4290, -1.9323,  2.4480]])

>>> # batch matrix multiplication
>>> As = torch.randn(3, 2, 5)
>>> Bs = torch.randn(3, 5, 4)
>>> torch.einsum('bij,bjk->bik', As, Bs)
tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
        [-1.6706, -0.8097, -0.8025, -2.1183]],

        [[ 4.2239,  0.3107, -0.5756, -0.2354],
        [-1.4558, -0.3460,  1.5087, -0.8530]],

        [[ 2.8153,  1.8787, -4.3839, -1.2112],
        [ 0.3728, -2.1131,  0.0921,  0.8305]]])

>>> # with sublist format and ellipsis
>>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
        [-1.6706, -0.8097, -0.8025, -2.1183]],

        [[ 4.2239,  0.3107, -0.5756, -0.2354],
        [-1.4558, -0.3460,  1.5087, -0.8530]],

        [[ 2.8153,  1.8787, -4.3839, -1.2112],
        [ 0.3728, -2.1131,  0.0921,  0.8305]]])

>>> # batch permute
>>> A = torch.randn(2, 3, 4, 5)
>>> torch.einsum('...ij->...ji', A).shape
torch.Size([2, 3, 5, 4])

>>> # equivalent to torch.nn.functional.bilinear
>>> A = torch.randn(3, 5, 4)
>>> l = torch.randn(2, 5)
>>> r = torch.randn(2, 4)
>>> torch.einsum('bn,anm,bm->ba', l, A, r)
tensor([[-0.3430, -5.2405,  0.4494],
        [ 0.3311,  5.5201, -3.0356]])

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源