快捷方式

torch.combinations

torch.combinations(input: Tensor, r: int = 2, with_replacement: bool = False) seq

計算給定張量的長度為 rr 的組合。當 with_replacement 設定為 False 時,其行為類似於 Python 的 itertools.combinations;當 with_replacement 設定為 True 時,其行為類似於 itertools.combinations_with_replacement

引數
  • input (Tensor) – 1D 向量。

  • r (int, 可選) – 組合中元素的數量

  • with_replacement (bool, 可選) – 是否允許組合中的元素重複

返回

一個張量,等同於將所有輸入張量轉換為列表,對這些列表執行 itertools.combinationsitertools.combinations_with_replacement,最後將結果列表轉換回張量。

返回型別

張量

示例

>>> a = [1, 2, 3]
>>> list(itertools.combinations(a, r=2))
[(1, 2), (1, 3), (2, 3)]
>>> list(itertools.combinations(a, r=3))
[(1, 2, 3)]
>>> list(itertools.combinations_with_replacement(a, r=2))
[(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]
>>> tensor_a = torch.tensor(a)
>>> torch.combinations(tensor_a)
tensor([[1, 2],
        [1, 3],
        [2, 3]])
>>> torch.combinations(tensor_a, r=3)
tensor([[1, 2, 3]])
>>> torch.combinations(tensor_a, with_replacement=True)
tensor([[1, 1],
        [1, 2],
        [1, 3],
        [2, 2],
        [2, 3],
        [3, 3]])

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源