torch.diag¶
- torch.diag(input, diagonal=0, *, out=None) Tensor¶
如果
input是一個向量 (1-D 張量),則返回一個 2-D 方陣張量,其對角線元素由input的元素構成。如果
input是一個矩陣 (2-D 張量),則返回一個 1-D 張量,其元素為input的對角線元素。
引數
diagonal控制考慮哪條對角線示例
獲取以輸入向量為對角線的方陣
>>> a = torch.randn(3) >>> a tensor([ 0.5950,-0.0872, 2.3298]) >>> torch.diag(a) tensor([[ 0.5950, 0.0000, 0.0000], [ 0.0000,-0.0872, 0.0000], [ 0.0000, 0.0000, 2.3298]]) >>> torch.diag(a, 1) tensor([[ 0.0000, 0.5950, 0.0000, 0.0000], [ 0.0000, 0.0000,-0.0872, 0.0000], [ 0.0000, 0.0000, 0.0000, 2.3298], [ 0.0000, 0.0000, 0.0000, 0.0000]])
獲取給定矩陣的第 k 條對角線
>>> a = torch.randn(3, 3) >>> a tensor([[-0.4264, 0.0255,-0.1064], [ 0.8795,-0.2429, 0.1374], [ 0.1029,-0.6482,-1.6300]]) >>> torch.diag(a, 0) tensor([-0.4264,-0.2429,-1.6300]) >>> torch.diag(a, 1) tensor([ 0.0255, 0.1374])