torch.quantile¶
- torch.quantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) Tensor¶
計算
input張量沿維度dim的每一行的 q 分位數。為了計算分位數,我們將 [0, 1] 範圍內的 q 對映到 [0, n] 的索引範圍,以找到分位數在排序輸入中的位置。如果分位數位於排序順序中索引為
i和j的兩個資料點a < b之間,則結果根據給定的interpolation方法計算如下:linear:a + (b - a) * fraction,其中fraction是計算出的分位數索引的小數部分。lower:a。higher:b。nearest:a或b,以索引更接近計算出的分位數索引者為準(對於 .5 的小數部分向下取整)。midpoint:(a + b) / 2。
如果
q是一個 1D 張量,則輸出的第一個維度代表分位數,其大小等於q的大小,其餘維度是規約後剩餘的維度。注意
預設情況下,
dim為None,這將導致input張量在計算前被展平。- 引數
- 關鍵字引數
示例
>>> a = torch.randn(2, 3) >>> a tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]]) >>> q = torch.tensor([0.25, 0.5, 0.75]) >>> torch.quantile(a, q, dim=1, keepdim=True) tensor([[[-0.5661], [ 0.5795]], [[ 0.0795], [ 0.6706]], [[ 0.5280], [ 0.9206]]]) >>> torch.quantile(a, q, dim=1, keepdim=True).shape torch.Size([3, 2, 1]) >>> a = torch.arange(4.) >>> a tensor([0., 1., 2., 3.]) >>> torch.quantile(a, 0.6, interpolation='linear') tensor(1.8000) >>> torch.quantile(a, 0.6, interpolation='lower') tensor(1.) >>> torch.quantile(a, 0.6, interpolation='higher') tensor(2.) >>> torch.quantile(a, 0.6, interpolation='midpoint') tensor(1.5000) >>> torch.quantile(a, 0.6, interpolation='nearest') tensor(2.) >>> torch.quantile(a, 0.4, interpolation='nearest') tensor(1.)