命名張量運算元覆蓋範圍¶
請先閱讀 命名張量 以獲取命名張量介紹。
本文件是 名稱推斷 的參考資料,這個過程定義了命名張量如何
使用名稱來提供額外的自動執行時正確性檢查
從輸入張量傳播名稱到輸出張量
下面列出了命名張量支援的所有操作及其相關的名稱推斷規則。
如果您在此處未找到某個操作,但它對您的用例有幫助,請搜尋是否已有相關問題被提交,如果沒有,提交一個問題。
警告
命名張量 API 是實驗性的,未來可能會發生變化。
API |
名稱推斷規則 |
|---|---|
參見文件 |
|
參見文件 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
|
無 |
|
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
參見文件 |
|
無 |
|
無 |
|
|
無 |
無 |
|
|
參見文件 |
|
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
|
無 |
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
|
|
無 |
|
將掩碼與輸入對齊,然後合併來自輸入張量的名稱 |
|
參見文件 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
|
無 |
無 |
|
無 |
|
參見文件 |
|
無 |
|
無 |
|
參見文件 |
|
參見文件 |
|
無 |
|
無 |
|
只允許不改變形狀的 resize 操作 |
|
只允許不改變形狀的 resize 操作 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
無 |
|
參見文件 |
|
無 |
|
無 |
|
保留輸入名稱¶
所有逐點一元函式以及其他一些一元函式都遵循此規則。
檢查名稱:無
傳播名稱:將輸入張量的名稱傳播到輸出張量。
>>> x = torch.randn(3, 3, names=('N', 'C'))
>>> x.abs().names
('N', 'C')
移除維度¶
所有歸約操作(如 sum())會透過在指定維度上進行歸約來移除維度。其他操作(如 select() 和 squeeze())也會移除維度。
在任何可以向運算子傳遞整數維度索引的地方,也可以傳遞維度名稱。接受維度索引列表的函式也可以接受維度名稱列表。
檢查名稱:如果將
dim或dims作為名稱列表傳入,檢查這些名稱是否存在於self中。傳播名稱:如果輸入張量中由
dim或dims指定的維度不存在於輸出張量中,則這些維度的對應名稱不會出現在output.names中。
>>> x = torch.randn(1, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.squeeze('N').names
('C', 'H', 'W')
>>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.sum(['N', 'C']).names
('H', 'W')
# Reduction ops with keepdim=True don't actually remove dimensions.
>>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.sum(['N', 'C'], keepdim=True).names
('N', 'C', 'H', 'W')
統一輸入名稱¶
所有二元算術操作都遵循此規則。廣播操作仍然從右側按位置廣播,以保留與未命名張量的相容性。若要按名稱執行顯式廣播,請使用 Tensor.align_as()。
檢查名稱:所有名稱必須從右側按位置匹配。例如,在
tensor + other中,對於(-min(tensor.dim(), other.dim()) + 1, -1]中的所有i,match(tensor.names[i], other.names[i])必須為真。檢查名稱:此外,所有已命名維度必須從右側對齊。在匹配過程中,如果我們將命名維度
A與未命名維度None匹配,則A不得出現在帶有未命名維度的張量中。傳播名稱:從右側統一兩個張量中的名稱對,以生成輸出名稱。
例如,
# tensor: Tensor[ N, None]
# other: Tensor[None, C]
>>> tensor = torch.randn(3, 3, names=('N', None))
>>> other = torch.randn(3, 3, names=(None, 'C'))
>>> (tensor + other).names
('N', 'C')
檢查名稱
match(tensor.names[-1], other.names[-1])isTruematch(tensor.names[-2], tensor.names[-2])isTrue因為我們將
tensor中的None與'C'匹配,請檢查確保'C'不存在於tensor中(它確實不存在)。檢查確保
'N'不存在於other中(它確實不存在)。
最後,輸出名稱透過 [unify('N', None), unify(None, 'C')] = ['N', 'C'] 計算得出。
更多示例
# Dimensions don't match from the right:
# tensor: Tensor[N, C]
# other: Tensor[ N]
>>> tensor = torch.randn(3, 3, names=('N', 'C'))
>>> other = torch.randn(3, names=('N',))
>>> (tensor + other).names
RuntimeError: Error when attempting to broadcast dims ['N', 'C'] and dims
['N']: dim 'C' and dim 'N' are at the same position from the right but do
not match.
# Dimensions aren't aligned when matching tensor.names[-1] and other.names[-1]:
# tensor: Tensor[N, None]
# other: Tensor[ N]
>>> tensor = torch.randn(3, 3, names=('N', None))
>>> other = torch.randn(3, names=('N',))
>>> (tensor + other).names
RuntimeError: Misaligned dims when attempting to broadcast dims ['N'] and
dims ['N', None]: dim 'N' appears in a different position from the right
across both lists.
注意
在最後兩個示例中,都可以按名稱對齊張量,然後執行加法。使用 Tensor.align_as() 按名稱對齊張量,或使用 Tensor.align_to() 將張量對齊到自定義的維度順序。
置換維度¶
一些操作,如 Tensor.t(),會置換維度順序。維度名稱附加到各個維度上,因此它們也會被置換。
如果運算子接受位置索引 dim,它也可以接受維度名稱作為 dim。
檢查名稱:如果將
dim作為名稱傳遞,檢查它是否存在於張量中。傳播名稱:以與被置換維度相同的方式置換維度名稱。
>>> x = torch.randn(3, 3, names=('N', 'C'))
>>> x.transpose('N', 'C').names
('C', 'N')
收縮維度¶
矩陣乘法函式遵循此規則的某種變體。我們首先來看 torch.mm(),然後概括批次矩陣乘法的規則。
對於 torch.mm(tensor, other)
檢查名稱:無
傳播名稱:結果名稱為
(tensor.names[-2], other.names[-1])。
>>> x = torch.randn(3, 3, names=('N', 'D'))
>>> y = torch.randn(3, 3, names=('in', 'out'))
>>> x.mm(y).names
('N', 'out')
本質上,矩陣乘法在兩個維度上執行點積,並將它們摺疊。當兩個張量進行矩陣乘法時,被收縮的維度會消失,並且不會出現在輸出張量中。
torch.mv()、torch.dot() 的工作方式類似:名稱推斷不檢查輸入名稱,並移除參與點積的維度。
>>> x = torch.randn(3, 3, names=('N', 'D'))
>>> y = torch.randn(3, names=('something',))
>>> x.mv(y).names
('N',)
現在,我們來看看 torch.matmul(tensor, other)。假設 tensor.dim() >= 2 且 other.dim() >= 2。
檢查名稱:檢查輸入的批次維度是否對齊且可廣播。關於輸入對齊的含義,請參閱統一輸入名稱。
傳播名稱:結果名稱透過統一批次維度和移除被收縮維度獲得:
unify(tensor.names[:-2], other.names[:-2]) + (tensor.names[-2], other.names[-1])。
示例
# Batch matrix multiply of matrices Tensor['C', 'D'] and Tensor['E', 'F'].
# 'A', 'B' are batch dimensions.
>>> x = torch.randn(3, 3, 3, 3, names=('A', 'B', 'C', 'D'))
>>> y = torch.randn(3, 3, 3, names=('B', 'E', 'F'))
>>> torch.matmul(x, y).names
('A', 'B', 'C', 'F')
最後,許多矩陣乘法函式都有融合的 add 版本。例如,addmm() 和 addmv()。這些函式被視為組合了例如 mm() 的名稱推斷和 add() 的名稱推斷。
工廠函式¶
工廠函式現在接受一個新的 names 引數,該引數為每個維度關聯一個名稱。
>>> torch.zeros(2, 3, names=('N', 'C'))
tensor([[0., 0., 0.],
[0., 0., 0.]], names=('N', 'C'))
out 函式和原地(in-place)變體¶
指定為 out= 張量的行為如下:
如果它沒有命名維度,則從操作中計算出的名稱會傳播給它。
如果它有任何命名維度,則從操作中計算出的名稱必須與現有名稱完全一致。否則,操作會報錯。
所有原地(in-place)方法都會修改輸入,使其名稱等於名稱推斷計算出的名稱。例如
>>> x = torch.randn(3, 3)
>>> y = torch.randn(3, 3, names=('N', 'C'))
>>> x.names
(None, None)
>>> x += y
>>> x.names
('N', 'C')