torch.cat¶
- torch.cat(tensors, dim=0, *, out=None) Tensor¶
在給定維度上拼接輸入序列
tensors中的張量。所有張量除了拼接維度外必須具有相同的形狀,或者是一個大小為(0,)的一維空張量。torch.cat()可以被視為torch.split()和torch.chunk()的逆操作。torch.cat()透過示例最容易理解。另請參閱
torch.stack()沿著一個新維度拼接輸入序列。- 引數
tensors (Tensor 序列) – 提供的非空張量除了拼接維度外必須具有相同的形狀。
dim (int, 可選) – 拼接張量的維度
- 關鍵字引數
out (Tensor, 可選) – 輸出張量。
示例
>>> x = torch.randn(2, 3) >>> x tensor([[ 0.6580, -1.0969, -0.4614], [-0.1034, -0.5790, 0.1497]]) >>> torch.cat((x, x, x), 0) tensor([[ 0.6580, -1.0969, -0.4614], [-0.1034, -0.5790, 0.1497], [ 0.6580, -1.0969, -0.4614], [-0.1034, -0.5790, 0.1497], [ 0.6580, -1.0969, -0.4614], [-0.1034, -0.5790, 0.1497]]) >>> torch.cat((x, x, x), 1) tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614], [-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497]])