快捷方式

torch.repeat_interleave

torch.repeat_interleave(input, repeats, dim=None, *, output_size=None) Tensor

重複張量的元素。

警告

這與 torch.Tensor.repeat() 不同,但與 numpy.repeat 類似。

引數
  • input (Tensor) – 輸入張量。

  • repeats (Tensorint) – 每個元素的重複次數。 repeats 被廣播以適應給定軸的形狀。

  • dim (int, 可選) – 沿其重複值的維度。 預設情況下,使用展平的輸入陣列,並返回一個展平的輸出陣列。

關鍵字引數

output_size (int, 可選) – 給定軸的總輸出大小(例如,repeats 的總和)。如果給定,將避免計算張量輸出形狀所需的流同步。

返回

重複後的張量,其形狀與輸入相同,但沿給定軸除外。

返回型別

Tensor

示例

>>> x = torch.tensor([1, 2, 3])
>>> x.repeat_interleave(2)
tensor([1, 1, 2, 2, 3, 3])
>>> y = torch.tensor([[1, 2], [3, 4]])
>>> torch.repeat_interleave(y, 2)
tensor([1, 1, 2, 2, 3, 3, 4, 4])
>>> torch.repeat_interleave(y, 3, dim=1)
tensor([[1, 1, 1, 2, 2, 2],
        [3, 3, 3, 4, 4, 4]])
>>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
tensor([[1, 2],
        [3, 4],
        [3, 4]])
>>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0, output_size=3)
tensor([[1, 2],
        [3, 4],
        [3, 4]])

如果 repeatstensor([n1, n2, n3, …]),則輸出將是 tensor([0, 0, …, 1, 1, …, 2, 2, …, …]),其中 0 出現 n1 次,1 出現 n2 次,2 出現 n3 次,依此類推。

torch.repeat_interleave(repeats, *) Tensor

將 0 重複 repeats[0] 次,1 重複 repeats[1] 次,2 重複 repeats[2] 次,依此類推。

引數

repeats (Tensor) – 每個元素的重複次數。

返回

重複後的張量,其大小為 sum(repeats)

返回型別

Tensor

示例

>>> torch.repeat_interleave(torch.tensor([1, 2, 3]))
tensor([0, 1, 1, 2, 2, 2])

文件

訪問 PyTorch 的完整開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源