捷徑

torch.masked

簡介

動機

警告

遮罩張量的 PyTorch API 處於原型階段,未來可能會或可能不會更改。

MaskedTensor 是 torch.Tensor 的擴充,它為使用者提供了以下功能:

  • 使用任何遮罩語義(例如,變長張量、nan* 運算子等)

  • 區分 0 和 NaN 梯度

  • 各種稀疏應用(請參閱以下教學課程)

「指定的」和「未指定的」在 PyTorch 中有著悠久的歷史,但沒有正式的語義,當然也沒有一致性;事實上,MaskedTensor 的誕生源於普通的 torch.Tensor 類別無法妥善解決的問題。因此,MaskedTensor 的主要目標是成為 PyTorch 中所述「指定的」和「未指定的」值的真實來源,在這些值中,它們是一等公民,而不是事後才想到的。反過來,這應該會進一步釋放 稀疏性 的潛力,實現更安全、更一致的運算子,並為使用者和開發者提供更順暢、更直觀的體驗。

什麼是 MaskedTensor?

MaskedTensor 是一個張量子類別,它由 1) 輸入(資料)和 2) 遮罩組成。遮罩告訴我們哪些輸入項目應該包含或忽略。

舉例來說,假設我們想要遮罩掉所有等於 0 的值(以灰色表示),並取最大值

_images/tensor_comparison.jpg

頂部是普通的張量範例,而底部是 MaskedTensor,其中所有 0 都被遮罩掉了。這顯然會根據我們是否有遮罩而產生不同的結果,但這種靈活的結構允許使用者在計算過程中系統地忽略他們想要的任何元素。

我們已經編寫了一些現有的教學課程來幫助使用者入門,例如:

支援的運算子

一元運算子

一元運算子是只包含單一輸入的運算子。將它們應用於 MaskedTensors 相對簡單:如果資料在給定索引處被遮罩,我們就應用運算子,否則我們將繼續遮罩資料。

可用的一元運算子如下:

abs

計算 input 中每個元素的絕對值。

absolute

torch.abs() 的別名

acos

計算 input 中每個元素的反餘弦。

arccos

torch.acos() 的別名。

acosh

傳回一個新張量,其中包含 input 元素的反雙曲餘弦。

arccosh

torch.acosh() 的別名。

angle

計算給定 input 張量的逐元素角度(以弧度為單位)。

asin

傳回一個新張量,其中包含 input 元素的反正弦。

arcsin

torch.asin() 的別名。

asinh

傳回一個新張量,其中包含 input 元素的反雙曲正弦。

arcsinh

torch.asinh() 的別名。

atan

傳回一個新的張量,其中包含 input 元素的反正切值。

arctan

torch.atan() 的別名。

atanh

傳回一個新的張量,其中包含 input 元素的反雙曲正切值。

arctanh

torch.atanh() 的別名。

bitwise_not

計算給定輸入張量的位元反轉。

ceil

傳回一個新的張量,其中包含 input 元素的ceil值,即大於或等於每個元素的最小整數。

clamp

input 中的所有元素限制在範圍 [ min, max ] 內。

clip

torch.clamp() 的別名。

conj_physical

計算給定 input 張量的元素共軛複數。

cos

傳回一個新的張量,其中包含 input 元素的餘弦值。

cosh

傳回一個新的張量,其中包含 input 元素的雙曲餘弦值。

deg2rad

傳回一個新的張量,其中包含將 input 的每個元素從角度轉換為弧度。

digamma

torch.special.digamma() 的別名。

erf

torch.special.erf() 的別名。

erfc

torch.special.erfc() 的別名。

erfinv

torch.special.erfinv() 的別名。

exp

傳回一個新的張量,其中包含輸入張量 input 元素的指數值。

exp2

torch.special.exp2() 的別名。

expm1

torch.special.expm1() 的別名。

fix

torch.trunc() 的別名

floor

傳回一個新的張量,其中包含 input 元素的floor值,即小於或等於每個元素的最大整數。

frac

計算 input 中每個元素的小數部分。

lgamma

計算 input 上伽瑪函數的絕對值的自然對數。

log

傳回一個新的張量,其中包含 input 元素的自然對數。

log10

傳回一個新的張量,其中包含 input 元素的以10為底的對數。

log1p

傳回一個新的張量,其中包含 (1 + input) 的自然對數。

log2

傳回一個新的張量,其中包含 input 元素的以2為底的對數。

logit

torch.special.logit() 的別名。

i0

torch.special.i0() 的別名。

isnan

傳回一個新的張量,其中包含布林元素,表示 input 的每個元素是否為 NaN。

nan_to_num

input 中的 NaN、正無窮大和負無窮大值分別替換為 nanposinfneginf 指定的值。

neg

傳回一個新的張量,其中包含 input 元素的相反數。

negative

torch.neg() 的別名

positive

傳回 input

pow

使用 exponent 計算 input 中每個元素的冪,並傳回一個包含結果的張量。

rad2deg

傳回一個新的張量,其中包含將 input 的每個元素從弧度轉換為角度。

reciprocal

傳回一個新的張量,其中包含 input 元素的倒數

round

input 的元素四捨五入到最接近的整數。

rsqrt

傳回一個新的張量,其中包含 input 中每個元素的平方根的倒數。

sigmoid

torch.special.expit() 的別名。

sign

傳回一個新的張量,其中包含 input 元素的正負號。

sgn

此函數是 torch.sign() 對複數張量的擴展。

signbit

測試 input 的每個元素是否設置了符號位。

sin

傳回一個新的張量,其中包含 input 元素的正弦值。

sinc

torch.special.sinc() 的別名。

sinh

傳回一個新的張量,其中包含 input 元素的雙曲正弦值。

sqrt

傳回一個新的張量,其中包含 input 元素的平方根。

square

傳回一個新的張量,其中包含 input 元素的平方。

tan

傳回一個新的張量,其中包含 input 元素的正切值。

tanh

傳回一個新的張量,其中包含 input 元素的雙曲正切值。

trunc

傳回一個新的張量,其中包含 input 元素的截斷整數值。

可用的原地單元運算符是以上所有運算符,除了

angle

計算給定 input 張量的逐元素角度(以弧度為單位)。

positive

傳回 input

signbit

測試 input 的每個元素是否設置了符號位。

isnan

傳回一個新的張量,其中包含布林元素,表示 input 的每個元素是否為 NaN。

二元運算符

您可能在教程中已經看到,MaskedTensor 也實現了二元運算,但需要注意的是,兩個 MaskedTensors 中的遮罩必須匹配,否則會引發錯誤。如錯誤訊息中所述,如果您需要支援特定運算符或建議它們應該如何運作的語義,請在 GitHub 上提交議題。目前,我們決定採用最保守的實現方式,以確保使用者確切了解發生了什麼事,並對遮罩語義做出有意識的決定。

可用的二元運算符有

add

other 乘以 alpha 後加到 input

atan2

考慮象限的元素反正切 inputi/otheri\text{input}_{i} / \text{other}_{i}

arctan2

torch.atan2() 的別名。

bitwise_and

計算 inputother 的位元 AND。

bitwise_or

計算 inputother 的位元 OR。

bitwise_xor

計算 inputother 的位元 XOR。

bitwise_left_shift

計算 input 向左位元移位 other 位元。

bitwise_right_shift

計算 input 向右位元移位 other 位元。

div

將輸入 input 的每個元素除以 other 的對應元素。

divide

torch.div() 的別名。

floor_divide

fmod

逐項應用 C++ 的 std::fmod

logaddexp

輸入值的指數總和之對數。

logaddexp2

以 2 為底數時,輸入值的指數總和之對數。

mul

input 乘以 other

multiply

torch.mul() 的別名。

nextafter

逐項傳回在 input 之後,朝向 other 的下一個浮點數值。

remainder

逐項計算 Python 的模數運算

sub

input 中減去縮放 alpha 倍的 other

subtract

torch.sub() 的別名。

true_divide

torch.div() 搭配 rounding_mode=None 的別名。

eq

逐項計算是否相等

ne

逐項計算 inputother\text{input} \neq \text{other}

le

逐項計算 inputother\text{input} \leq \text{other}

ge

逐項計算 inputother\text{input} \geq \text{other}

greater

torch.gt() 的別名。

greater_equal

torch.ge() 的別名。

gt

逐項計算 input>other\text{input} > \text{other}

less_equal

torch.le() 的別名。

lt

逐項計算 input<other\text{input} < \text{other}

less

torch.lt() 的別名。

maximum

逐項計算 inputother 的最大值。

minimum

逐項計算 inputother 的最小值。

fmax

逐項計算 inputother 的最大值。

fmin

逐項計算 inputother 的最小值。

not_equal

torch.ne() 的別名。

可用的就地二元運算子為以上所有運算子,除了

logaddexp

輸入值的指數總和之對數。

logaddexp2

以 2 為底數時,輸入值的指數總和之對數。

equal

如果兩個張量具有相同的大小和元素,則為 True,否則為 False

fmin

逐項計算 inputother 的最小值。

minimum

逐項計算 inputother 的最小值。

fmax

逐項計算 inputother 的最大值。

縮減

以下縮減操作可用(支援自動梯度)。如需更多資訊,概覽 教學課程詳細說明了一些縮減範例,而 進階語義 教學課程則深入探討了我們如何決定某些縮減語義。

sum

傳回 input 張量中所有元素的總和。

mean

傳回 input 張量中所有元素的平均值。

amin

傳回給定維度 dim 中,input 張量每個切片的最小值。

amax

傳回給定維度 dim 中,input 張量每個切片的最大值。

argmin

傳回扁平化張量或沿著維度的最小值索引

argmax

傳回 input 張量中所有元素最大值的索引。

prod

傳回 input 張量中所有元素的乘積。

all

測試 input 中的所有元素是否評估為 True

norm

傳回給定張量的矩陣範數或向量範數。

var

計算 dim 指定維度上的變異數。

std

計算 dim 指定維度上的標準差。

檢視和選擇函數

我們也包含了一些檢視和選擇函數;直觀地說,這些運算子將應用於資料和遮罩,然後將結果包裝在 MaskedTensor 中。如需快速範例,請考慮 select()

>>> data = torch.arange(12, dtype=torch.float).reshape(3, 4)
>>> data
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
>>> mask = torch.tensor([[True, False, False, True], [False, True, False, False], [True, True, True, True]])
>>> mt = masked_tensor(data, mask)
>>> data.select(0, 1)
tensor([4., 5., 6., 7.])
>>> mask.select(0, 1)
tensor([False,  True, False, False])
>>> mt.select(0, 1)
MaskedTensor(
  [      --,   5.0000,       --,       --]
)

目前支援以下操作

atleast_1d

傳回每個輸入張量的 1 維檢視,維度為零。

broadcast_tensors

根據 廣播語義 廣播給定的張量。

broadcast_to

input 廣播到形狀 shape

cat

在給定維度中串聯給定的 seq 張量序列。

chunk

嘗試將張量分割成指定數量的區塊。

column_stack

藉由水平堆疊 tensors 中的張量來建立新的張量。

dsplit

根據 indices_or_sections,沿深度方向將具有三個或更多維度的張量 input 分割成多個張量。

flatten

藉由將 input 重塑為一維張量來將其扁平化。

hsplit

根據 indices_or_sections,將具有一個或多個維度的張量 input 水平分割成多個張量。

hstack

水平依序堆疊張量(以欄為單位)。

kron

計算 inputother 的克羅內克積,表示為 \otimes

meshgrid

建立由 attr:tensors 中的一維輸入指定的座標格線。

narrow

傳回一個新的張量,它是 input 張量的縮小版本。

ravel

傳回連續的扁平化張量。

select

沿著所選維度,在給定索引處切片 input 張量。

split

將張量分割成多個區塊。

t

預期 input 是 <= 2 維張量,並轉置維度 0 和 1。

transpose

傳回 input 的轉置版本張量。

vsplit

根據 indices_or_sections,將具有兩個或多個維度的張量 input 垂直分割成多個張量。

vstack

按順序垂直堆疊張量(逐行)。

Tensor.expand

傳回 self 張量的新視圖,其中單例維度擴展到更大的大小。

Tensor.expand_as

將此張量擴展到與 other 相同的大小。

Tensor.reshape

傳回一個張量,其數據和元素個數與 self 相同,但具有指定的形狀。

Tensor.reshape_as

傳回與 other 形狀相同的張量。

Tensor.view

傳回一個新張量,其數據與 self 張量相同,但具有不同的 shape

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學

取得適用於初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源