• 文件 >
  • 命名張量運算元覆蓋範圍
快捷方式

命名張量運算元覆蓋範圍

請先閱讀 命名張量 以獲取命名張量介紹。

本文件是 名稱推斷 的參考資料,這個過程定義了命名張量如何

  1. 使用名稱來提供額外的自動執行時正確性檢查

  2. 從輸入張量傳播名稱到輸出張量

下面列出了命名張量支援的所有操作及其相關的名稱推斷規則。

如果您在此處未找到某個操作,但它對您的用例有幫助,請搜尋是否已有相關問題被提交,如果沒有,提交一個問題

警告

命名張量 API 是實驗性的,未來可能會發生變化。

支援的操作

API

名稱推斷規則

Tensor.abs(), torch.abs()

保留輸入名稱

Tensor.abs_()

保留輸入名稱

Tensor.acos(), torch.acos()

保留輸入名稱

Tensor.acos_()

保留輸入名稱

Tensor.add(), torch.add()

合併來自輸入的名稱

Tensor.add_()

合併來自輸入的名稱

Tensor.addmm(), torch.addmm()

收縮維度

Tensor.addmm_()

收縮維度

Tensor.addmv(), torch.addmv()

收縮維度

Tensor.addmv_()

收縮維度

Tensor.align_as()

參見文件

Tensor.align_to()

參見文件

Tensor.all(), torch.all()

Tensor.any(), torch.any()

Tensor.asin(), torch.asin()

保留輸入名稱

Tensor.asin_()

保留輸入名稱

Tensor.atan(), torch.atan()

保留輸入名稱

Tensor.atan2(), torch.atan2()

合併來自輸入的名稱

Tensor.atan2_()

合併來自輸入的名稱

Tensor.atan_()

保留輸入名稱

Tensor.bernoulli(), torch.bernoulli()

保留輸入名稱

Tensor.bernoulli_()

Tensor.bfloat16()

保留輸入名稱

Tensor.bitwise_not(), torch.bitwise_not()

保留輸入名稱

Tensor.bitwise_not_()

Tensor.bmm(), torch.bmm()

收縮維度

Tensor.bool()

保留輸入名稱

Tensor.byte()

保留輸入名稱

torch.cat()

合併來自輸入的名稱

Tensor.cauchy_()

Tensor.ceil(), torch.ceil()

保留輸入名稱

Tensor.ceil_()

Tensor.char()

保留輸入名稱

Tensor.chunk(), torch.chunk()

保留輸入名稱

Tensor.clamp(), torch.clamp()

保留輸入名稱

Tensor.clamp_()

Tensor.copy_()

out 函式和原地 (in-place) 變體

Tensor.cos(), torch.cos()

保留輸入名稱

Tensor.cos_()

Tensor.cosh(), torch.cosh()

保留輸入名稱

Tensor.cosh_()

Tensor.acosh(), torch.acosh()

保留輸入名稱

Tensor.acosh_()

Tensor.cpu()

保留輸入名稱

Tensor.cuda()

保留輸入名稱

Tensor.cumprod(), torch.cumprod()

保留輸入名稱

Tensor.cumsum(), torch.cumsum()

保留輸入名稱

Tensor.data_ptr()

Tensor.deg2rad(), torch.deg2rad()

保留輸入名稱

Tensor.deg2rad_()

Tensor.detach(), torch.detach()

保留輸入名稱

Tensor.detach_()

Tensor.device, torch.device()

Tensor.digamma(), torch.digamma()

保留輸入名稱

Tensor.digamma_()

Tensor.dim()

Tensor.div(), torch.div()

合併來自輸入的名稱

Tensor.div_()

合併來自輸入的名稱

Tensor.dot(), torch.dot()

Tensor.double()

保留輸入名稱

Tensor.element_size()

torch.empty()

工廠函式

torch.empty_like()

工廠函式

Tensor.eq(), torch.eq()

合併來自輸入的名稱

Tensor.erf(), torch.erf()

保留輸入名稱

Tensor.erf_()

Tensor.erfc(), torch.erfc()

保留輸入名稱

Tensor.erfc_()

Tensor.erfinv(), torch.erfinv()

保留輸入名稱

Tensor.erfinv_()

Tensor.exp(), torch.exp()

保留輸入名稱

Tensor.exp_()

Tensor.expand()

保留輸入名稱

Tensor.expm1(), torch.expm1()

保留輸入名稱

Tensor.expm1_()

Tensor.exponential_()

Tensor.fill_()

Tensor.flatten(), torch.flatten()

參見文件

Tensor.float()

保留輸入名稱

Tensor.floor(), torch.floor()

保留輸入名稱

Tensor.floor_()

Tensor.frac(), torch.frac()

保留輸入名稱

Tensor.frac_()

Tensor.ge(), torch.ge()

合併來自輸入的名稱

Tensor.get_device(), torch.get_device()

Tensor.grad

Tensor.gt(), torch.gt()

合併來自輸入的名稱

Tensor.half()

保留輸入名稱

Tensor.has_names()

參見文件

Tensor.index_fill(), torch.index_fill()

保留輸入名稱

Tensor.index_fill_()

Tensor.int()

保留輸入名稱

Tensor.is_contiguous()

Tensor.is_cuda

Tensor.is_floating_point(), torch.is_floating_point()

Tensor.is_leaf

Tensor.is_pinned()

Tensor.is_shared()

Tensor.is_signed(), torch.is_signed()

Tensor.is_sparse

Tensor.is_sparse_csr

torch.is_tensor()

Tensor.item()

Tensor.itemsize

Tensor.kthvalue(), torch.kthvalue()

移除維度

Tensor.le(), torch.le()

合併來自輸入的名稱

Tensor.log(), torch.log()

保留輸入名稱

Tensor.log10(), torch.log10()

保留輸入名稱

Tensor.log10_()

Tensor.log1p(), torch.log1p()

保留輸入名稱

Tensor.log1p_()

Tensor.log2(), torch.log2()

保留輸入名稱

Tensor.log2_()

Tensor.log_()

Tensor.log_normal_()

Tensor.logical_not(), torch.logical_not()

保留輸入名稱

Tensor.logical_not_()

Tensor.logsumexp(), torch.logsumexp()

移除維度

Tensor.long()

保留輸入名稱

Tensor.lt(), torch.lt()

合併來自輸入的名稱

torch.manual_seed()

Tensor.masked_fill(), torch.masked_fill()

保留輸入名稱

Tensor.masked_fill_()

Tensor.masked_select(), torch.masked_select()

將掩碼與輸入對齊,然後合併來自輸入張量的名稱

Tensor.matmul(), torch.matmul()

收縮維度

Tensor.mean(), torch.mean()

移除維度

Tensor.median(), torch.median()

移除維度

Tensor.nanmedian(), torch.nanmedian()

移除維度

Tensor.mm(), torch.mm()

收縮維度

Tensor.mode(), torch.mode()

移除維度

Tensor.mul(), torch.mul()

合併來自輸入的名稱

Tensor.mul_()

合併來自輸入的名稱

Tensor.mv(), torch.mv()

收縮維度

Tensor.names

參見文件

Tensor.narrow(), torch.narrow()

保留輸入名稱

Tensor.nbytes

Tensor.ndim

Tensor.ndimension()

Tensor.ne(), torch.ne()

合併來自輸入的名稱

Tensor.neg(), torch.neg()

保留輸入名稱

Tensor.neg_()

torch.normal()

保留輸入名稱

Tensor.normal_()

Tensor.numel(), torch.numel()

torch.ones()

工廠函式

Tensor.pow(), torch.pow()

合併來自輸入的名稱

Tensor.pow_()

Tensor.prod(), torch.prod()

移除維度

Tensor.rad2deg(), torch.rad2deg()

保留輸入名稱

Tensor.rad2deg_()

torch.rand()

工廠函式

torch.rand()

工廠函式

torch.randn()

工廠函式

torch.randn()

工廠函式

Tensor.random_()

Tensor.reciprocal(), torch.reciprocal()

保留輸入名稱

Tensor.reciprocal_()

Tensor.refine_names()

參見文件

Tensor.register_hook()

Tensor.register_post_accumulate_grad_hook()

Tensor.rename()

參見文件

Tensor.rename_()

參見文件

Tensor.requires_grad

Tensor.requires_grad_()

Tensor.resize_()

只允許不改變形狀的 resize 操作

Tensor.resize_as_()

只允許不改變形狀的 resize 操作

Tensor.round(), torch.round()

保留輸入名稱

Tensor.round_()

Tensor.rsqrt(), torch.rsqrt()

保留輸入名稱

Tensor.rsqrt_()

Tensor.select(), torch.select()

移除維度

Tensor.short()

保留輸入名稱

Tensor.sigmoid(), torch.sigmoid()

保留輸入名稱

Tensor.sigmoid_()

Tensor.sign(), torch.sign()

保留輸入名稱

Tensor.sign_()

Tensor.sgn(), torch.sgn()

保留輸入名稱

Tensor.sgn_()

Tensor.sin(), torch.sin()

保留輸入名稱

Tensor.sin_()

Tensor.sinh(), torch.sinh()

保留輸入名稱

Tensor.sinh_()

Tensor.asinh(), torch.asinh()

保留輸入名稱

Tensor.asinh_()

Tensor.size()

Tensor.softmax(), torch.softmax()

保留輸入名稱

Tensor.split(), torch.split()

保留輸入名稱

Tensor.sqrt(), torch.sqrt()

保留輸入名稱

Tensor.sqrt_()

Tensor.squeeze(), torch.squeeze()

移除維度

Tensor.std(), torch.std()

移除維度

torch.std_mean()

移除維度

Tensor.stride()

Tensor.sub(), torch.sub()

合併來自輸入的名稱

Tensor.sub_()

合併來自輸入的名稱

Tensor.sum(), torch.sum()

移除維度

Tensor.tan(), torch.tan()

保留輸入名稱

Tensor.tan_()

Tensor.tanh(), torch.tanh()

保留輸入名稱

Tensor.tanh_()

Tensor.atanh(), torch.atanh()

保留輸入名稱

Tensor.atanh_()

torch.tensor()

工廠函式

Tensor.to()

保留輸入名稱

Tensor.topk(), torch.topk()

移除維度

Tensor.transpose(), torch.transpose()

置換維度

Tensor.trunc(), torch.trunc()

保留輸入名稱

Tensor.trunc_()

Tensor.type()

Tensor.type_as()

保留輸入名稱

Tensor.unbind(), torch.unbind()

移除維度

Tensor.unflatten()

參見文件

Tensor.uniform_()

Tensor.var(), torch.var()

移除維度

torch.var_mean()

移除維度

Tensor.zero_()

torch.zeros()

工廠函式

保留輸入名稱

所有逐點一元函式以及其他一些一元函式都遵循此規則。

  • 檢查名稱:無

  • 傳播名稱:將輸入張量的名稱傳播到輸出張量。

>>> x = torch.randn(3, 3, names=('N', 'C'))
>>> x.abs().names
('N', 'C')

移除維度

所有歸約操作(如 sum())會透過在指定維度上進行歸約來移除維度。其他操作(如 select()squeeze())也會移除維度。

在任何可以向運算子傳遞整數維度索引的地方,也可以傳遞維度名稱。接受維度索引列表的函式也可以接受維度名稱列表。

  • 檢查名稱:如果將 dimdims 作為名稱列表傳入,檢查這些名稱是否存在於 self 中。

  • 傳播名稱:如果輸入張量中由 dimdims 指定的維度不存在於輸出張量中,則這些維度的對應名稱不會出現在 output.names 中。

>>> x = torch.randn(1, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.squeeze('N').names
('C', 'H', 'W')

>>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.sum(['N', 'C']).names
('H', 'W')

# Reduction ops with keepdim=True don't actually remove dimensions.
>>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.sum(['N', 'C'], keepdim=True).names
('N', 'C', 'H', 'W')

統一輸入名稱

所有二元算術操作都遵循此規則。廣播操作仍然從右側按位置廣播,以保留與未命名張量的相容性。若要按名稱執行顯式廣播,請使用 Tensor.align_as()

  • 檢查名稱:所有名稱必須從右側按位置匹配。例如,在 tensor + other 中,對於 (-min(tensor.dim(), other.dim()) + 1, -1] 中的所有 imatch(tensor.names[i], other.names[i]) 必須為真。

  • 檢查名稱:此外,所有已命名維度必須從右側對齊。在匹配過程中,如果我們將命名維度 A 與未命名維度 None 匹配,則 A 不得出現在帶有未命名維度的張量中。

  • 傳播名稱:從右側統一兩個張量中的名稱對,以生成輸出名稱。

例如,

# tensor: Tensor[   N, None]
# other:  Tensor[None,    C]
>>> tensor = torch.randn(3, 3, names=('N', None))
>>> other = torch.randn(3, 3, names=(None, 'C'))
>>> (tensor + other).names
('N', 'C')

檢查名稱

  • match(tensor.names[-1], other.names[-1]) is True

  • match(tensor.names[-2], tensor.names[-2]) is True

  • 因為我們將 tensor 中的 None'C' 匹配,請檢查確保 'C' 不存在於 tensor 中(它確實不存在)。

  • 檢查確保 'N' 不存在於 other 中(它確實不存在)。

最後,輸出名稱透過 [unify('N', None), unify(None, 'C')] = ['N', 'C'] 計算得出。

更多示例

# Dimensions don't match from the right:
# tensor: Tensor[N, C]
# other:  Tensor[   N]
>>> tensor = torch.randn(3, 3, names=('N', 'C'))
>>> other = torch.randn(3, names=('N',))
>>> (tensor + other).names
RuntimeError: Error when attempting to broadcast dims ['N', 'C'] and dims
['N']: dim 'C' and dim 'N' are at the same position from the right but do
not match.

# Dimensions aren't aligned when matching tensor.names[-1] and other.names[-1]:
# tensor: Tensor[N, None]
# other:  Tensor[      N]
>>> tensor = torch.randn(3, 3, names=('N', None))
>>> other = torch.randn(3, names=('N',))
>>> (tensor + other).names
RuntimeError: Misaligned dims when attempting to broadcast dims ['N'] and
dims ['N', None]: dim 'N' appears in a different position from the right
across both lists.

注意

在最後兩個示例中,都可以按名稱對齊張量,然後執行加法。使用 Tensor.align_as() 按名稱對齊張量,或使用 Tensor.align_to() 將張量對齊到自定義的維度順序。

置換維度

一些操作,如 Tensor.t(),會置換維度順序。維度名稱附加到各個維度上,因此它們也會被置換。

如果運算子接受位置索引 dim,它也可以接受維度名稱作為 dim

  • 檢查名稱:如果將 dim 作為名稱傳遞,檢查它是否存在於張量中。

  • 傳播名稱:以與被置換維度相同的方式置換維度名稱。

>>> x = torch.randn(3, 3, names=('N', 'C'))
>>> x.transpose('N', 'C').names
('C', 'N')

收縮維度

矩陣乘法函式遵循此規則的某種變體。我們首先來看 torch.mm(),然後概括批次矩陣乘法的規則。

對於 torch.mm(tensor, other)

  • 檢查名稱:無

  • 傳播名稱:結果名稱為 (tensor.names[-2], other.names[-1])

>>> x = torch.randn(3, 3, names=('N', 'D'))
>>> y = torch.randn(3, 3, names=('in', 'out'))
>>> x.mm(y).names
('N', 'out')

本質上,矩陣乘法在兩個維度上執行點積,並將它們摺疊。當兩個張量進行矩陣乘法時,被收縮的維度會消失,並且不會出現在輸出張量中。

torch.mv()torch.dot() 的工作方式類似:名稱推斷不檢查輸入名稱,並移除參與點積的維度。

>>> x = torch.randn(3, 3, names=('N', 'D'))
>>> y = torch.randn(3, names=('something',))
>>> x.mv(y).names
('N',)

現在,我們來看看 torch.matmul(tensor, other)。假設 tensor.dim() >= 2other.dim() >= 2

  • 檢查名稱:檢查輸入的批次維度是否對齊且可廣播。關於輸入對齊的含義,請參閱統一輸入名稱

  • 傳播名稱:結果名稱透過統一批次維度和移除被收縮維度獲得:unify(tensor.names[:-2], other.names[:-2]) + (tensor.names[-2], other.names[-1])

示例

# Batch matrix multiply of matrices Tensor['C', 'D'] and Tensor['E', 'F'].
# 'A', 'B' are batch dimensions.
>>> x = torch.randn(3, 3, 3, 3, names=('A', 'B', 'C', 'D'))
>>> y = torch.randn(3, 3, 3, names=('B', 'E', 'F'))
>>> torch.matmul(x, y).names
('A', 'B', 'C', 'F')

最後,許多矩陣乘法函式都有融合的 add 版本。例如,addmm()addmv()。這些函式被視為組合了例如 mm() 的名稱推斷和 add() 的名稱推斷。

工廠函式

工廠函式現在接受一個新的 names 引數,該引數為每個維度關聯一個名稱。

>>> torch.zeros(2, 3, names=('N', 'C'))
tensor([[0., 0., 0.],
        [0., 0., 0.]], names=('N', 'C'))

out 函式和原地(in-place)變體

指定為 out= 張量的行為如下:

  • 如果它沒有命名維度,則從操作中計算出的名稱會傳播給它。

  • 如果它有任何命名維度,則從操作中計算出的名稱必須與現有名稱完全一致。否則,操作會報錯。

所有原地(in-place)方法都會修改輸入,使其名稱等於名稱推斷計算出的名稱。例如

>>> x = torch.randn(3, 3)
>>> y = torch.randn(3, 3, names=('N', 'C'))
>>> x.names
(None, None)

>>> x += y
>>> x.names
('N', 'C')

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源