torch.Tensor.masked_scatter_¶
- Tensor.masked_scatter_(mask, source)¶
在
mask為 True 的位置,將source中的元素複製到self張量中。從source的位置 0 開始,對於mask為 True 的每一個位置,按順序逐一將source中的元素複製到self中。mask的形狀必須與底層張量的形狀可廣播 (broadcastable)。source中的元素數量應至少與mask中值為 1 的數量相等。- 引數
mask (BoolTensor) – 布林掩碼
source (Tensor) – 要從中複製元素的張量
注意
mask是作用在self張量上,而不是給定的source張量上。示例
>>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]) >>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=torch.bool) >>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) >>> self.masked_scatter_(mask, source) tensor([[0, 0, 0, 0, 1], [2, 3, 0, 4, 5]])