快捷方式

torch.masked

介紹

動機

警告

masked tensors 的 PyTorch API 處於原型階段,未來可能會發生變化。

MaskedTensor 是 torch.Tensor 的擴充套件,它提供了以下能力:

  • 使用任何 masked 語義(例如,變長張量、nan* 運算元等)

  • 區分 0 梯度和 NaN 梯度

  • 各種稀疏應用(見下方教程)

“指定”(Specified)和“未指定”(unspecified)在 PyTorch 中有著悠久的歷史,但缺乏正式的語義和一致性;事實上,MaskedTensor 的誕生是為了解決普通 torch.Tensor 類無法妥善處理的一系列問題。因此,MaskedTensor 的主要目標是成為 PyTorch 中這些“指定”和“未指定”值的唯一真理來源,讓它們成為一等公民而非事後補丁。這反過來應該進一步釋放稀疏性的潛力,實現更安全、更一致的運算元,併為使用者和開發者提供更流暢、更直觀的體驗。

什麼是 MaskedTensor?

MaskedTensor 是一種張量子類,由 1) 輸入(資料)和 2) 掩碼(mask)組成。掩碼告訴我們應包含或忽略輸入中的哪些條目。

舉個例子,假設我們想遮蔽掉所有等於 0 的值(用灰色表示)並取最大值

_images/tensor_comparison.jpg

上方是普通張量的例子,下方是 MaskedTensor 的例子,其中所有的 0 都被遮蔽掉了。這顯然會產生不同的結果,取決於我們是否有掩碼,但這種靈活的結構允許使用者在計算過程中系統地忽略他們希望忽略的任何元素。

我們已經編寫了一些現有教程來幫助使用者入門,例如

支援的運算元

一元運算元

一元運算元是僅包含一個輸入的運算元。將其應用於 MaskedTensors 相對簡單:如果在給定索引處資料被遮蔽,我們會應用該運算元;否則,我們將繼續遮蔽資料。

可用的一元運算元有

abs

計算 input 中每個元素的絕對值。

absolute

torch.abs() 的別名

acos

計算 input 中每個元素的反餘弦。

arccos

torch.acos() 的別名。

acosh

返回一個新張量,其中包含 input 中元素的反雙曲餘弦。

arccosh

torch.acosh() 的別名。

angle

計算給定 input 張量的逐元素角度(以弧度為單位)。

asin

返回一個新張量,其中包含 input 中元素的反正弦。

arcsin

torch.asin() 的別名。

asinh

返回一個新張量,其中包含 input 中元素的反雙曲正弦。

arcsinh

torch.asinh() 的別名。

atan

返回一個新張量,其中包含 input 中元素的反正切。

arctan

torch.atan() 的別名。

atanh

返回一個新張量,其中包含 input 中元素的反雙曲正切。

arctanh

torch.atanh() 的別名。

bitwise_not

計算給定輸入張量的按位非。

ceil

返回一個新張量,其中包含 input 中元素的向上取整結果,即大於或等於每個元素的最小整數。

clamp

input 中的所有元素限制在 [ min, max ] 範圍內。

clip

torch.clamp() 的別名。

conj_physical

計算給定 input 張量的逐元素共軛。

cos

返回一個新張量,其中包含 input 中元素的餘弦。

cosh

返回一個新張量,其中包含 input 中元素的雙曲餘弦。

deg2rad

返回一個新張量,其中包含 input 中每個元素從角度(度)轉換為弧度的結果。

digamma

torch.special.digamma() 的別名。

erf

torch.special.erf() 的別名。

erfc

torch.special.erfc() 的別名。

erfinv

torch.special.erfinv() 的別名。

exp

返回一個新張量,其中包含輸入張量 input 中元素的指數。

exp2

torch.special.exp2() 的別名。

expm1

torch.special.expm1() 的別名。

fix

torch.trunc() 的別名

floor

返回一個新張量,其中包含 input 中元素的向下取整結果,即小於或等於每個元素的最大整數。

frac

計算 input 中每個元素的小數部分。

lgamma

計算 input 上伽馬函式絕對值的自然對數。

log

返回一個新張量,其中包含 input 中元素的自然對數。

log10

返回一個新張量,其中包含 input 中元素以 10 為底的對數。

log1p

返回一個新張量,其中包含 (1 + input) 的自然對數。

log2

返回一個新張量,其中包含 input 中元素以 2 為底的對數。

logit

torch.special.logit() 的別名。

i0

torch.special.i0() 的別名。

isnan

返回一個新張量,其中包含布林元素,表示 input 的每個元素是否為 NaN。

nan_to_num

NaN、正無窮和負無窮值在 input 中分別替換為由 nanposinfneginf 指定的值。

neg

返回一個新張量,其中包含 input 中元素的負值。

negative

torch.neg() 的別名

positive

返回 input

pow

計算 input 中每個元素的 exponent 次冪,並返回包含結果的張量。

rad2deg

返回一個新張量,其中包含 input 中每個元素從角度(弧度)轉換為度的結果。

reciprocal

返回一個新張量,其中包含 input 中元素的倒數

round

input 中的元素四捨五入到最接近的整數。

rsqrt

返回一個新張量,其中包含 input 中每個元素平方根的倒數。

sigmoid

torch.special.expit() 的別名。

sign

返回一個新張量,其中包含 input 中元素的符號。

sgn

此函式是 torch.sign() 對於複數張量的擴充套件。

signbit

測試 input 的每個元素是否設定了符號位。

sin

返回一個新張量,其中包含 input 中元素的正弦。

sinc

torch.special.sinc() 的別名。

sinh

返回一個新張量,其中包含 input 中元素的雙曲正弦。

sqrt

返回一個新張量,其中包含 input 中元素的平方根。

square

返回一個新張量,其中包含 input 中元素的平方。

tan

返回一個新張量,其中包含 input 中元素的正切。

tanh

返回一個新張量,其中包含 input 中元素的雙曲正切。

trunc

返回一個新張量,其中包含 input 中元素的截斷整數值。

可用的就地(inplace)一元運算元包括上述所有運算元,**除了**

angle

計算給定 input 張量的逐元素角度(以弧度為單位)。

positive

返回 input

signbit

測試 input 的每個元素是否設定了符號位。

isnan

返回一個新張量,其中包含布林元素,表示 input 的每個元素是否為 NaN。

二元運算元

如您在教程中可能看到的,MaskedTensor 也實現了二元操作,但需要注意的是,兩個 MaskedTensors 中的掩碼必須匹配,否則會引發錯誤。正如錯誤資訊中指出的,如果您需要支援某個特定的運算元,或者對它們應該如何表現有提議的語義,請在 GitHub 上開啟一個 issue。目前,我們決定採用最保守的實現方式,以確保使用者清楚地瞭解正在發生的事情,並慎重地對待 masked 語義相關的決策。

可用的二元運算元有

add

將按 alpha 縮放的 other 新增到 input

atan2

逐元素計算 inputi/otheri\text{input}_{i} / \text{other}_{i} 的反正切,並考慮象限。

arctan2

torch.atan2() 的別名。

bitwise_and

計算 inputother 的按位與。

bitwise_or

計算 inputother 的按位或。

bitwise_xor

計算 inputother 的按位異或。

bitwise_left_shift

計算 inputother 位進行的左算術移位。

bitwise_right_shift

計算 inputother 位進行的右算術移位。

div

將輸入 input 的每個元素除以 other 的對應元素。

divide

torch.div() 的別名。

floor_divide

fmod

逐元素應用 C++ 的 std::fmod

logaddexp

輸入指數之和的對數。

logaddexp2

輸入指數之和以 2 為底的對數。

mul

input 乘以 other

multiply

torch.mul() 的別名。

nextafter

逐元素返回 input 朝向 other 方向的下一個浮點值。

remainder

逐元素計算 Python 的模運算

sub

input 中減去按 alpha 縮放的 other

subtract

torch.sub() 的別名。

true_divide

torch.div() 的別名,其中 rounding_mode=None

eq

計算逐元素相等性

ne

逐元素計算 inputother\text{input} \neq \text{other}

le

逐元素計算 inputother\text{input} \leq \text{other}

ge

逐元素計算 inputother\text{input} \geq \text{other}

greater

torch.gt() 的別名。

大於等於

torch.ge() 的別名。

gt

按元素計算 input>other\text{input} > \text{other}

小於等於

torch.le() 的別名。

lt

按元素計算 input<other\text{input} < \text{other}

小於

torch.lt() 的別名。

最大值

按元素計算 inputother 的最大值。

最小值

按元素計算 inputother 的最小值。

fmax

按元素計算 inputother 的最大值。

fmin

按元素計算 inputother 的最小值。

不等於

torch.ne() 的別名。

可用的原地二元運算子包含以上所有,**除了**

logaddexp

輸入指數之和的對數。

logaddexp2

輸入指數之和以 2 為底的對數。

等於

如果兩個張量具有相同的大小和元素,則為 True,否則為 False

fmin

按元素計算 inputother 的最小值。

最小值

按元素計算 inputother 的最小值。

fmax

按元素計算 inputother 的最大值。

規約

以下規約可用(支援 autograd)。更多資訊請參閱 概述 教程,其中詳細介紹了一些規約示例;而 高階語義 教程則對某些規約語義的決定方式進行了深入探討。

求和

返回 input 張量中所有元素的和。

均值

最小值

返回 input 張量在給定維度 dim 中每個切片的最小值。

最大值

返回 input 張量在給定維度 dim 中每個切片的最大值。

最小值索引

返回展平張量或沿某個維度的最小值索引

最大值索引

返回 input 張量中所有元素的最大值索引。

乘積

返回 input 張量中所有元素的乘積。

全部為真

測試 input 中的所有元素是否都評估為 True

範數

返回給定張量的矩陣範數或向量範數。

方差

計算由 dim 指定的維度上的方差。

標準差

計算由 dim 指定的維度上的標準差。

檢視和選擇函式

我們還包含了一些檢視和選擇函式;直觀上,這些運算子將同時應用於資料和掩碼,然後將結果包裝在 MaskedTensor 中。舉個簡單的例子,考慮 select()

>>> data = torch.arange(12, dtype=torch.float).reshape(3, 4)
>>> data
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
>>> mask = torch.tensor([[True, False, False, True], [False, True, False, False], [True, True, True, True]])
>>> mt = masked_tensor(data, mask)
>>> data.select(0, 1)
tensor([4., 5., 6., 7.])
>>> mask.select(0, 1)
tensor([False,  True, False, False])
>>> mt.select(0, 1)
MaskedTensor(
  [      --,   5.0000,       --,       --]
)

當前支援以下操作:

至少一維

返回每個零維輸入張量的一維檢視。

廣播張量

根據 廣播語義 廣播給定張量。

廣播到

input 廣播到形狀 shape

連線

在給定維度連線 tensors 中的給定張量序列。

分塊

嘗試將張量分割成指定數量的塊。

列堆疊

透過水平堆疊 tensors 中的張量建立一個新張量。

深度分割

根據 indices_or_sectionsinput (一個三維或更多維度的張量) 深度分割成多個張量。

展平

透過將其重塑為一維張量來展平 input

水平分割

根據 indices_or_sectionsinput (一個一維或更多維度的張量) 水平分割成多個張量。

水平堆疊

按順序水平堆疊張量(按列)。

Kronecker 積

計算 inputother 的 Kronecker 積,記為 \otimes

網格化

根據 attr:tensors 中的一維輸入建立座標網格。

窄化

返回一個新張量,它是 input 張量的窄化版本。

nn.functional.unfold

從批次輸入張量中提取滑動區域性塊。

展平

返回一個連續的展平張量。

選擇

在給定索引處沿選定維度對 input 張量進行切片。

分割

將張量分割成塊。

堆疊

沿新維度連線張量序列。

轉置

要求 input 是 <= 2維張量,並轉置維度 0 和 1。

轉置

返回 input 張量的轉置版本。

垂直分割

根據 indices_or_sectionsinput (一個二維或更多維度的張量) 垂直分割成多個張量。

垂直堆疊

按順序垂直堆疊張量(按行)。

Tensor.expand

返回 self 張量的一個新檢視,其中單例維度被擴充套件到更大尺寸。

Tensor.expand_as

將此張量擴充套件到與 other 相同的大小。

Tensor.reshape

返回一個與 self 具有相同資料和元素數量但具有指定形狀的張量。

Tensor.reshape_as

返回此張量,使其形狀與 other 相同。

Tensor.unfold

返回原始張量的一個檢視,該檢視包含 self 張量中維度 dimension 上所有大小為 size 的切片。

Tensor.view

返回一個新張量,其資料與 self 張量相同但具有不同的 shape

文件

訪問 PyTorch 的完整開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源