快捷方式

torch.Tensor.masked_scatter_

Tensor.masked_scatter_(mask, source)

mask 為 True 的位置,將 source 中的元素複製到 self 張量中。從 source 的位置 0 開始,對於 mask 為 True 的每一個位置,按順序逐一將 source 中的元素複製到 self 中。mask 的形狀必須與底層張量的形狀可廣播 (broadcastable)source 中的元素數量應至少與 mask 中值為 1 的數量相等。

引數
  • mask (BoolTensor) – 布林掩碼

  • source (Tensor) – 要從中複製元素的張量

注意

mask 是作用在 self 張量上,而不是給定的 source 張量上。

示例

>>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
>>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=torch.bool)
>>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
>>> self.masked_scatter_(mask, source)
tensor([[0, 0, 0, 0, 1],
        [2, 3, 0, 4, 5]])

文件

查閱 PyTorch 全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源