捷徑

torch

torch 套件包含多維張量的資料結構,並定義了這些張量的數學運算。此外,它還提供了許多實用的工具,用於有效地序列化張量和任意類型,以及其他有用的工具。

它有一個 CUDA 對應項,使您能夠在運算能力 >= 3.0 的 NVIDIA GPU 上執行張量運算。

張量

is_tensor

如果 obj 是 PyTorch 張量,則傳回 True。

is_storage

如果 obj 是 PyTorch 儲存體物件,則傳回 True。

is_complex

如果 輸入 的資料類型是複數資料類型,即 torch.complex64torch.complex128 其中之一,則傳回 True。

is_conj

如果 輸入 是共軛張量,即其共軛位元設定為 True,則傳回 True。

is_floating_point

如果 輸入 的資料類型是浮點數資料類型,即 torch.float64torch.float32torch.float16torch.bfloat16 其中之一,則傳回 True。

is_nonzero

如果 輸入 是單一元素張量,並且在類型轉換後不等於零,則傳回 True。

set_default_dtype

將預設浮點數 dtype 設定為 d

get_default_dtype

取得目前的預設浮點數 torch.dtype

set_default_device

設定要在 裝置 上配置的預設 torch.Tensor

get_default_device

取得要在 裝置 上配置的預設 torch.Tensor

set_default_tensor_type

numel

傳回 輸入 張量中的元素總數。

set_printoptions

設定列印選項。

set_flush_denormal

停用 CPU 上的非正規浮點數。

建立運算

注意

隨機取樣建立運算列在 隨機取樣 下,包括: torch.rand() torch.rand_like() torch.randn() torch.randn_like() torch.randint() torch.randint_like() torch.randperm() 您也可以使用 torch.empty() 搭配 就地隨機取樣 方法,建立從更廣泛的分佈中取樣的 torch.Tensor

tensor

透過複製 資料 建構沒有 Autograd 歷程記錄的張量(也稱為「葉張量」,請參閱 Autograd 機制)。

sparse_coo_tensor

以給定的 索引 建構具有指定值的 COO(座標)格式的稀疏張量

sparse_csr_tensor

以給定的 crow_indicescol_indices 建構具有指定值的 CSR(壓縮稀疏列)格式的稀疏張量

sparse_csc_tensor

以給定的 ccol_indicesrow_indices 建構具有指定值的 CSC(壓縮稀疏欄)格式的稀疏張量

sparse_bsr_tensor

以給定的 crow_indicescol_indices 建構具有指定二維區塊的 BSR(區塊壓縮稀疏列)格式的稀疏張量

sparse_bsc_tensor

使用給定的 ccol_indicesrow_indices,構造一個具有指定二維塊的 BSC(塊壓縮稀疏列)格式的稀疏張量

asarray

obj 轉換為張量。

as_tensor

data 轉換為張量,盡可能共享數據並保留自動求導歷史記錄。

as_strided

使用指定的 sizestridestorage_offset,創建現有 torch.Tensor input 的視圖。

from_file

創建一個 CPU 張量,其存儲由內存映射文件支持。

from_numpy

numpy.ndarray 創建一個 Tensor

from_dlpack

將外部庫中的張量轉換為 torch.Tensor

frombuffer

從實現 Python 緩衝區協議的對象創建一個一維 Tensor

zeros

返回一個填充了標量值 0 的張量,其形狀由可變參數 size 定義。

zeros_like

返回一個填充了標量值 0 的張量,其大小與 input 相同。

ones

返回一個填充了標量值 1 的張量,其形狀由可變參數 size 定義。

ones_like

返回一個填充了標量值 1 的張量,其大小與 input 相同。

arange

返回一個大小為 endstartstep\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil 的一維張量,其值從區間 [start, end) 中以公差 stepstart 開始取值。

range

返回一個大小為 endstartstep+1\left\lfloor \frac{\text{end} - \text{start}}{\text{step}} \right\rfloor + 1 的一維張量,其值從 startend 以步長 step 取值。

linspace

創建一個大小為 steps 的一維張量,其值在 startend 之間均勻分佈,包括端點。

logspace

創建一個大小為 steps 的一維張量,其值在以 base 為底的對數刻度上,從 basestart{{\text{{base}}}}^{{\text{{start}}}}baseend{{\text{{base}}}}^{{\text{{end}}}} 之間均勻分佈,包括端點。

eye

返回一個二維張量,對角線上的元素為 1,其他元素為 0。

empty

返回一個填充了未初始化數據的張量。

empty_like

返回一個與 input 大小相同的未初始化張量。

empty_strided

創建一個具有指定 sizestride 的張量,並填充未定義的數據。

full

創建一個大小為 size 的張量,並填充 fill_value

full_like

返回一個與 input 大小相同的張量,並填充 fill_value

quantize_per_tensor

使用給定的比例因子和零點,將浮點數張量轉換為量化張量。

quantize_per_channel

使用給定的比例因子和零點,將浮點數張量轉換為逐通道量化張量。

dequantize

通過反量化量化張量,返回一個 fp32 張量。

complex

構造一個複數張量,其實部等於 real,虛部等於 imag

polar

構造一個複數張量,其元素是與極坐標相對應的笛卡爾坐標,其絕對值為 abs,角度為 angle

heaviside

計算 input 中每個元素的 Heaviside 階躍函數。

索引、切片、連接、變異操作

adjoint

返回張量的共軛視圖,並轉置最後兩個維度。

argwhere

返回一個張量,其中包含 input 中所有非零元素的索引。

cat

在給定維度上連接給定的 seq 張量序列。

concat

torch.cat() 的別名。

concatenate

torch.cat() 的別名。

conj

返回 input 的視圖,其中共軛位已翻轉。

chunk

嘗試將張量拆分為指定數量的塊。

dsplit

根據 indices_or_sections,將具有三個或更多維度的張量 input 在深度方向上拆分為多個張量。

column_stack

通過水平堆疊 tensors 中的張量來創建一個新的張量。

dstack

按深度方向(沿第三個軸)依次堆疊張量。

gather

沿 dim 指定的軸收集值。

hsplit

根據 indices_or_sections,水平地將具有一個或多個維度的張量 input 拆分為多個張量。

hstack

依序水平堆疊張量(以列為單位)。

index_add

函數說明請參閱 index_add_()

index_copy

函數說明請參閱 index_add_()

index_reduce

函數說明請參閱 index_reduce_()

index_select

傳回一個新的張量,它會使用 LongTensor 類型的 index 中的項目,沿著維度 diminput 張量進行索引。

masked_select

傳回一個新的 1 維張量,它會根據 BoolTensor 類型的布林遮罩 mask,對 input 張量進行索引。

movedim

input 中位於 source 位置的維度移動到 destination 中的位置。

moveaxis

torch.movedim() 的別名。

narrow

傳回一個新的張量,它是 input 張量的縮小版本。

narrow_copy

Tensor.narrow() 相同,但會傳回一個副本而不是共用儲存空間。

nonzero

permute

傳回原始張量 input 的視圖,其維度會被置換。

reshape

傳回一個與 input 具有相同資料和元素數量的張量,但具有指定的形狀。

row_stack

torch.vstack() 的別名。

select

沿著選定的維度,在給定的索引處切割 input 張量。

scatter

torch.Tensor.scatter_() 的非就地版本

diagonal_scatter

src 張量的值嵌入到 input 中沿著 input 的對角線元素,相對於 dim1dim2

select_scatter

src 張量的值嵌入到 input 中的給定索引處。

slice_scatter

src 張量的值嵌入到 input 中的給定維度處。

scatter_add

torch.Tensor.scatter_add_() 的非就地版本

scatter_reduce

torch.Tensor.scatter_reduce_() 的非就地版本

split

將張量拆分為多個區塊。

squeeze

傳回一個張量,其中所有大小為 1 的指定維度 input 都會被移除。

stack

沿著新的維度串聯多個張量。

swapaxes

torch.transpose() 的別名。

swapdims

torch.transpose() 的別名。

t

預期 input 是 <= 2 維張量,並轉置維度 0 和 1。

take

傳回一個新的張量,其中包含位於給定索引處的 input 元素。

take_along_dim

沿著給定的 dim,從 indices 的一維索引處選取 input 中的值。

tensor_split

沿著維度 dim,根據 indices_or_sections 指定的索引或區塊數,將張量拆分為多個子張量,所有子張量都是 input 的視圖。

tile

藉由重複 input 的元素來構造張量。

transpose

傳回一個張量,它是 input 的轉置版本。

unbind

移除張量維度。

unravel_index

將扁平索引的張量轉換為坐標張量的元組,這些坐標張量會索引到指定形狀的任意張量。

unsqueeze

傳回一個新的張量,其中在指定位置插入一個大小為一的維度。

vsplit

根據 indices_or_sections,垂直地將具有兩個或多個維度的張量 input 拆分為多個張量。

vstack

依序垂直堆疊張量(以行為單位)。

where

根據 condition,傳回從 inputother 中選擇的元素張量。

產生器

產生器

建立並傳回一個產生器物件,用於管理產生偽亂數的演算法狀態。

隨機取樣

seed

將所有裝置上用於產生亂數的種子設定為非確定性亂數。

manual_seed

設定所有裝置上用於產生亂數的種子。

initial_seed

以 Python long 型別傳回用於產生亂數的初始種子。

get_rng_state

torch.ByteTensor 型別傳回亂數產生器狀態。

set_rng_state

設定亂數產生器狀態。

torch.default_generator 傳回 預設的 CPU torch.Generator

bernoulli

從伯努利分佈中繪製二元亂數(0 或 1)。

multinomial

傳回一個張量,其中每一列包含從多項式機率分佈(更嚴格的定義是多變量,有關詳細資訊,請參閱 torch.distributions.multinomial.Multinomial)中取樣的 num_samples 個索引,該機率分佈位於張量 input 的對應列中。

normal

傳回一個由亂數組成的張量,這些亂數是從獨立的常態分佈中繪製的,其均值和標準差已給定。

poisson

傳回一個與 input 大小相同的張量,其中每個元素都是從卜瓦松分佈中取樣的,其速率參數由 input 中的對應元素給出,即

rand

傳回一個張量,其中填滿了從區間 [0,1)[0, 1) 上的均勻分佈中生成的亂數。

rand_like

傳回一個與 input 大小相同的張量,其中填滿了從區間 [0,1)[0, 1) 上的均勻分佈中生成的亂數。

randint

傳回一個張量,其中填滿了在 low(含)和 high(不含)之間均勻生成的亂數整數。

randint_like

傳回一個與張量 input 形狀相同的張量,其中填滿了在 low(含)和 high(不含)之間均勻生成的亂數整數。

randn

傳回一個張量,其中填滿了從均值為 0、變異數為 1 的常態分佈(也稱為標準常態分佈)中生成的亂數。

randn_like

傳回一個與 input 大小相同的張量,其中填滿了從均值為 0、變異數為 1 的常態分佈中生成的亂數。

randperm

傳回從 0n - 1 的整數的隨機排列。

就地隨機取樣

在張量上也定義了一些其他的就地隨機取樣函數。點選以參閱其說明文件。

準隨機取樣

quasirandom.SobolEngine

torch.quasirandom.SobolEngine 是一個用於產生(加擾)Sobol 序列的引擎。

序列化

save

將物件儲存到磁碟檔案。

load

從使用 torch.save() 儲存的檔案載入物件。

平行化

get_num_threads

傳回用於 CPU 運算平行化的執行緒數量

set_num_threads

設定用於 CPU 上運算內部平行化的執行緒數量。

get_num_interop_threads

傳回用於 CPU 上運算間平行化的執行緒數量(例如

set_num_interop_threads

設定用於運算間平行化的執行緒數量(例如

局部停用梯度計算

上下文管理器 torch.no_grad()torch.enable_grad()torch.set_grad_enabled() 有助於局部停用和啟用梯度計算。如需有關其用法的更多詳細資訊,請參閱局部停用梯度計算。這些上下文管理器是執行緒本機的,因此如果您使用 threading 模組等將工作發送到另一個執行緒,它們將無法運作。

範例

>>> x = torch.zeros(1, requires_grad=True)
>>> with torch.no_grad():
...     y = x * 2
>>> y.requires_grad
False

>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
...     y = x * 2
>>> y.requires_grad
False

>>> torch.set_grad_enabled(True)  # this can also be used as a function
>>> y = x * 2
>>> y.requires_grad
True

>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False

no_grad

停用梯度計算的上下文管理器。

enable_grad

啟用梯度計算的上下文管理器。

autograd.grad_mode.set_grad_enabled

開啟或關閉梯度計算的上下文管理器。

is_grad_enabled

如果目前已啟用梯度模式,則傳回 True。

autograd.grad_mode.inference_mode

啟用或停用推斷模式的上下文管理器。

is_inference_mode_enabled

如果目前已啟用推斷模式,則傳回 True。

數學運算

逐點運算

abs

計算 input 中每個元素的絕對值。

absolute

torch.abs() 的別名

acos

計算 input 中每個元素的反餘弦值。

arccos

torch.acos() 的別名。

acosh

傳回一個新的張量,其中包含 input 元素的反雙曲餘弦值。

arccosh

torch.acosh() 的別名。

add

other 乘以 alpha 後加到 input

addcdiv

執行 tensor1tensor2 的逐元素除法,將結果乘以純量 value,然後將其加到 input

addcmul

執行 tensor1tensor2 的逐元素乘法,將結果乘以純量 value,然後將其加到 input

angle

計算給定 input 張量的逐元素角度(以弧度為單位)。

asin

傳回一個新的張量,其中包含 input 元素的反正弦值。

arcsin

torch.asin() 的別名。

asinh

傳回一個新的張量,其中包含 input 元素的反雙曲正弦值。

arcsinh

torch.asinh() 的別名。

atan

傳回一個新的張量,其中包含 input 元素的反正切值。

arctan

torch.atan() 的別名。

atanh

傳回一個新的張量,其中包含 input 元素的反雙曲正切值。

arctanh

torch.atanh() 的別名。

atan2

inputi/otheri\text{input}_{i} / \text{other}_{i} 的逐元素反正切,並考慮象限。

arctan2

torch.atan2() 的別名。

bitwise_not

計算給定輸入張量的位元反轉。

bitwise_and

計算 inputother 的位元 AND 運算。

bitwise_or

計算 inputother 的位元 OR 運算。

bitwise_xor

計算 inputother 的位元 XOR 運算。

bitwise_left_shift

計算 input 的左算術移位 other 位元。

bitwise_right_shift

計算 input 的右算術移位 other 位元。

ceil

傳回一個新的張量,其中包含 input 元素的天花板值,即大於或等於每個元素的最小整數。

clamp

input 中的所有元素限制在範圍 [ min, max ] 內。

clip

torch.clamp() 的別名。

conj_physical

計算給定 input 張量的逐元素共軛複數。

copysign

建立一個新的浮點數張量,其大小與 input 相同,並帶有 other 的符號,逐元素計算。

cos

傳回一個新的張量,其中包含 input 元素的餘弦值。

cosh

傳回一個新的張量,其中包含 input 元素的雙曲餘弦值。

deg2rad

傳回一個新的張量,其中包含 input 的每個元素,從角度單位轉換為弧度單位。

div

將輸入 input 的每個元素除以 other 中的對應元素。

divide

torch.div() 的別名。

digamma

torch.special.digamma() 的別名。

erf

torch.special.erf() 的別名。

erfc

torch.special.erfc() 的別名。

erfinv

torch.special.erfinv() 的別名。

exp

傳回一個新的張量,其中包含輸入張量 input 元素的指數值。

exp2

torch.special.exp2() 的別名。

expm1

torch.special.expm1() 的別名。

fake_quantize_per_channel_affine

傳回一個新的張量,其中 input 中的數據使用 scalezero_pointquant_minquant_max 沿著 axis 指定的通道進行模擬量化。

fake_quantize_per_tensor_affine

傳回一個新的張量,其中 input 中的數據使用 scalezero_pointquant_minquant_max 進行模擬量化。

fix

torch.trunc() 的別名

float_power

input 提高到 exponent 的冪,逐元素計算,使用雙精度浮點數。

floor

傳回一個新的張量,其中包含 input 元素的向下取整值,即小於或等於每個元素的最大整數。

floor_divide

fmod

逐元素應用 C++ 的 std::fmod

frac

計算 input 中每個元素的小數部分。

frexp

input 分解為尾數和指數張量,使得 input=mantissa×2exponent\text{input} = \text{mantissa} \times 2^{\text{exponent}}.

gradient

使用 二階精確中央差分法 估計函數 g:RnRg : \mathbb{R}^n \rightarrow \mathbb{R} 在一個或多個維度上的梯度,並在邊界處使用一階或二階估計。

imag

傳回一個新的張量,其中包含 self 張量的虛部值。

ldexp

input 乘以 2 的 other 次方。

lerp

根據標量或張量 weight 對兩個張量 start(由 input 指定)和 end 進行線性插值,並傳回結果張量 out

lgamma

計算 input 上伽瑪函數絕對值的自然對數。

log

傳回一個新的張量,其中包含 input 元素的自然對數。

log10

傳回一個新的張量,其中包含 input 元素以 10 為底的對數。

log1p

傳回一個新的張量,其中包含 (1 + input) 的自然對數。

log2

傳回一個新的張量,其中包含 input 元素以 2 為底的對數。

logaddexp

輸入值的指數總和的對數。

logaddexp2

輸入值以 2 為底的指數總和的對數。

logical_and

計算給定輸入張量的逐元素邏輯 AND 運算。

logical_not

計算給定輸入張量的逐元素邏輯 NOT 運算。

logical_or

計算給定輸入張量的逐元素邏輯 OR 運算。

logical_xor

計算給定輸入張量的逐元素邏輯 XOR 運算。

logit

torch.special.logit() 的別名。

hypot

給定直角三角形的兩股,傳回其斜邊長。

i0

torch.special.i0() 的別名。

igamma

torch.special.gammainc() 的別名。

igammac

torch.special.gammaincc() 的別名。

mul

input 乘以 other

multiply

torch.mul() 的別名。

mvlgamma

torch.special.multigammaln() 的別名。

nan_to_num

input 中的 NaN、正無窮大和負無窮大值分別替換為 nanposinfneginf 指定的值。

neg

傳回一個新的張量,其中包含 input 元素的相反數。

negative

torch.neg() 的別名

nextafter

傳回 input 朝向 other 的下一個浮點數值,逐元素計算。

polygamma

torch.special.polygamma() 的別名。

positive

傳回 input

pow

使用 exponent 計算 input 中每個元素的冪,並傳回包含結果的張量。

quantized_batch_norm

對 4D (NCHW) 量化張量應用批次正規化。

quantized_max_pool1d

對由多個輸入平面組成的輸入量化張量應用 1D 最大池化。

quantized_max_pool2d

對由多個輸入平面組成的輸入量化張量應用 2D 最大池化。

rad2deg

傳回一個新的張量,其中包含 input 的每個元素,從弧度單位轉換為角度單位。

real

傳回一個新的張量,其中包含 self 張量的實數值。

reciprocal(倒數)

傳回一個新的張量,其中包含 input 各元素的倒數。

remainder(餘數)

逐元素計算 Python 的模數運算

round(四捨五入)

input 的元素四捨五入到最接近的整數。

rsqrt(平方根倒數)

傳回一個新的張量,其中包含 input 各元素平方根的倒數。

sigmoid( sigmoid 函數)

torch.special.expit() 的別名。

sign(正負號函數)

傳回一個新的張量,其中包含 input 各元素的正負號。

sgn(正負號函數,支援複數)

此函數是 torch.sign() 對複數張量的擴展。

signbit(符號位元函數)

測試 input 的每個元素是否設置了符號位元。

sin(正弦函數)

傳回一個新的張量,其中包含 input 各元素的正弦值。

sinc(辛格函數)

torch.special.sinc() 的別名。

sinh(雙曲正弦函數)

傳回一個新的張量,其中包含 input 各元素的雙曲正弦值。

softmax( softmax 函數)

torch.nn.functional.softmax() 的別名。

sqrt(平方根函數)

傳回一個新的張量,其中包含 input 各元素的平方根。

square(平方函數)

傳回一個新的張量,其中包含 input 各元素的平方。

sub(減法)

input 中減去 other 乘以 alpha 的結果。

subtract(減法)

torch.sub() 的別名。

tan(正切函數)

傳回一個新的張量,其中包含 input 各元素的正切值。

tanh(雙曲正切函數)

傳回一個新的張量,其中包含 input 各元素的雙曲正切值。

true_divide(真除法)

torch.div() 的別名,其中 rounding_mode=None

trunc(截斷函數)

傳回一個新的張量,其中包含 input 各元素截斷後的整數值。

xlogy(xlogy 函數)

torch.special.xlogy() 的別名。

歸約運算

argmax(最大值索引)

傳回 input 張量中所有元素最大值的索引。

argmin(最小值索引)

傳回扁平化張量或沿某個維度的最小值的索引。

amax(最大值)

傳回給定維度 diminput 張量每個切片的最大值。

amin(最小值)

傳回給定維度 diminput 張量每個切片的最小值。

aminmax(最小值和最大值)

計算 input 張量的最小值和最大值。

all(所有元素皆為真)

測試 input 中的所有元素是否都評估為 True

any(任一元素為真)

測試 input 中的任何元素是否評估為 True

max(最大值)

傳回 input 張量中所有元素的最大值。

min(最小值)

傳回 input 張量中所有元素的最小值。

dist(距離)

傳回 (input - other) 的 p 范數。

logsumexp(對數指數和)

傳回給定維度 diminput 張量每一行的指數和的對數。

mean(平均值)

傳回 input 張量中所有元素的平均值。

nanmean(忽略 NaN 的平均值)

計算沿指定維度所有 非 NaN 元素的平均值。

median(中位數)

傳回 input 中值的中位數。

nanmedian(忽略 NaN 的中位數)

傳回 input 中值的中位數,忽略 NaN 值。

mode(眾數)

傳回一個命名元組 (values, indices),其中 values 是給定維度 diminput 張量每一行的眾數值,即該行中出現次數最多的值,而 indices 是找到的每個眾數值的索引位置。

norm(範數)

傳回給定張量的矩陣範數或向量範數。

nansum(忽略 NaN 的總和)

傳回所有元素的總和,將非數值 (NaN) 視為零。

prod(乘積)

傳回 input 張量中所有元素的乘積。

quantile(分位數)

計算 input 張量沿維度 dim 的每一行的第 q 個分位數。

nanquantile(忽略 NaN 的分位數)

這是 torch.quantile() 的變體,它“忽略” NaN 值,計算分位數 q,就好像 input 中不存在 NaN 值一樣。

std(標準差)

計算 dim 指定的維度上的標準差。

std_mean(標準差和平均值)

計算 dim 指定的維度上的標準差和平均值。

sum(總和)

傳回 input 張量中所有元素的總和。

unique(唯一元素)

傳回輸入張量的唯一元素。

unique_consecutive(連續唯一元素)

從每個等效元素的連續組中刪除除第一個元素之外的所有元素。

var(變異數)

計算 dim 指定的維度上的變異數。

var_mean(變異數和平均值)

計算 dim 指定的維度上的變異數和平均值。

count_nonzero(非零元素計數)

計算張量 input 中沿給定 dim 的非零值的數量。

比較運算

allclose(近似相等)

此函數檢查 inputother 是否滿足條件

argsort(排序索引)

傳回按值升序對張量沿給定維度排序的索引。

eq(等於)

逐元素計算相等性

equal(相等)

如果兩個張量的大小和元素相同,則為 True,否則為 False

ge(大於或等於)

逐元素計算 inputother\text{input} \geq \text{other}

greater_equal(大於或等於)

torch.ge() 的別名。

gt(大於)

逐元素計算 input>other\text{input} > \text{other}

greater(大於)

torch.gt() 的別名。

isclose(近似相等)

傳回一個新的張量,其中包含布林元素,表示 input 的每個元素是否“接近”於 other 中的相應元素。

isfinite(有限數)

傳回一個新的張量,其中包含布林元素,表示每個元素是否為 有限數

isin(是否在集合中)

測試 elements 的每個元素是否在 test_elements 中。

isinf(是否為無限大)

測試 input 的每個元素是否為無限大(正無限大或負無限大)。

isposinf(是否為正無限大)

測試 input 的每個元素是否為正無限大。

isneginf(是否為負無限大)

測試 input 的每個元素是否為負無限大。

isnan(是否為 NaN)

傳回一個新的張量,其中包含布林元素,表示 input 的每個元素是否為 NaN。

isreal(是否為實數)

傳回一個新的張量,其中包含布林元素,表示 input 的每個元素是否為實數值。

kthvalue

傳回一個 namedtuple (values, indices),其中 values 是在給定維度 dim 中,input 張量每一列的第 k 個最小元素。

le

以元素方式計算 inputother\text{input} \leq \text{other}

less_equal

torch.le() 的別名。

lt

以元素方式計算 input<other\text{input} < \text{other}

less

torch.lt() 的別名。

maximum

計算 inputother 的元素最大值。

minimum

計算 inputother 的元素最小值。

fmax

計算 inputother 的元素最大值。

fmin

計算 inputother 的元素最小值。

ne

以元素方式計算 inputother\text{input} \neq \text{other}

not_equal

torch.ne() 的別名。

sort

沿著給定維度,按值遞增排序 input 張量的元素。

topk

沿著給定維度,傳回給定 input 張量的 k 個最大元素。

msort

沿著第一個維度,按值遞增排序 input 張量的元素。

頻譜運算

stft

短時傅立葉變換 (STFT)。

istft

反向短時傅立葉變換。

bartlett_window

Bartlett 窗函數。

blackman_window

Blackman 窗函數。

hamming_window

Hamming 窗函數。

hann_window

Hann 窗函數。

kaiser_window

計算窗口長度為 window_length 且形狀參數為 beta 的 Kaiser 窗。

其他運算

atleast_1d

傳回每個輸入張量的 1 維視圖,維度為零。

atleast_2d

傳回每個輸入張量的 2 維視圖,維度為零。

atleast_3d

傳回每個輸入張量的 3 維視圖,維度為零。

bincount

計算非負整數陣列中每個值的出現頻率。

block_diag

從提供的張量建立區塊對角矩陣。

broadcast_tensors

根據 廣播語義 廣播給定的張量。

broadcast_to

input 廣播到形狀 shape

broadcast_shapes

類似於 broadcast_tensors(),但用於形狀。

bucketize

傳回 input 中每個值所屬區間的索引,其中區間的邊界由 boundaries 設定。

cartesian_prod

對給定張量序列進行笛卡爾積。

cdist

計算兩個行向量集合中每對向量之間的批量 p-範數距離。

clone

傳回 input 的副本。

combinations

計算給定張量的長度為 rr 的組合。

corrcoef

估計由 input 矩陣給出的變數的 Pearson 積差相關係數矩陣,其中列是變數,行是觀測值。

cov

估計由 input 矩陣給出的變數的共變異數矩陣,其中列是變數,行是觀測值。

cross

傳回 inputother 在維度 dim 中的向量的外積。

cummax

傳回一個 namedtuple (values, indices),其中 valuesinput 在維度 dim 中的元素的累積最大值。

cummin

傳回一個 namedtuple (values, indices),其中 valuesinput 在維度 dim 中的元素的累積最小值。

cumprod

傳回 input 在維度 dim 中的元素的累積乘積。

cumsum

傳回 input 在維度 dim 中的元素的累積總和。

diag

  • 如果 input 是一個向量(1 維張量),則傳回一個 2 維方形張量

diag_embed

建立一個張量,其某些 2D 平面(由 dim1dim2 指定)的對角線由 input 填充。

diagflat

  • 如果 input 是一個向量(1 維張量),則傳回一個 2 維方形張量

diagonal

傳回 input 的部分視圖,其對角線元素相對於 dim1dim2 附加為形狀末尾的維度。

diff

沿著給定維度計算第 n 個前向差。

einsum

沿著使用基於愛因斯坦求和約定的符號指定的維度,對輸入 operands 的元素的乘積求和。

flatten

透過將 input 重塑為一維張量來將其展平。

flip

沿著 dims 中的給定軸反轉 n 維張量的順序。

fliplr

在左/右方向翻轉張量,傳回一個新的張量。

flipud

在上/下方向翻轉張量,傳回一個新的張量。

kron

計算 inputother 的克羅內克積,表示為 \otimes

rot90

在由 dims 軸指定的平面上將 n 維張量旋轉 90 度。

gcd

計算 inputother 的元素最大公因數 (GCD)。

histc

計算張量的直方圖。

histogram

計算張量中值的直方圖。

histogramdd

計算張量中值的多維直方圖。

meshgrid

建立由 attr:tensors 中的 1D 輸入指定的座標網格。

lcm

計算 inputother 的元素最小公倍數 (LCM)。

logcumsumexp

傳回 input 在維度 dim 中的元素的指數的累積總和的對數。

ravel

傳回一個連續的展平張量。

renorm

傳回一個張量,其中沿著維度 diminput 的每個子張量都被正規化,使得子張量的 p-範數低於值 maxnorm

repeat_interleave

重複張量的元素。

反轉 (roll)

沿給定維度反轉張量 input

搜尋排序 (searchsorted)

sorted_sequence 的*最內層*維度找出索引,以便在排序時,如果將 values 中的對應值插入索引之前,則會保留 sorted_sequence 中對應*最內層*維度的順序。

張量點積 (tensordot)

傳回 a 和 b 在多個維度上的縮併。

跡 (trace)

傳回輸入二維矩陣對角線元素的總和。

下三角矩陣 (tril)

傳回矩陣(二維張量)或矩陣批次 input 的下三角部分,結果張量 out 的其他元素設為 0。

下三角矩陣索引 (tril_indices)

以 2×N 張量的形式傳回 row×col 矩陣下三角部分的索引,其中第一列包含所有索引的列坐標,第二列包含欄坐標。

上三角矩陣 (triu)

傳回矩陣(二維張量)或矩陣批次 input 的上三角部分,結果張量 out 的其他元素設為 0。

上三角矩陣索引 (triu_indices)

以 2×N 張量的形式傳回 row×col 矩陣上三角部分的索引,其中第一列包含所有索引的列坐標,第二列包含欄坐標。

展開 (unflatten)

展開輸入張量在多個維度上的一個維度。

萬德蒙矩陣 (vander)

產生萬德蒙矩陣。

視為實數 (view_as_real)

input 的視圖傳回為實數張量。

視為複數 (view_as_complex)

input 的視圖傳回為複數張量。

解析共軛 (resolve_conj)

如果 input 的共軛位元設為 True,則傳回具體化共軛的新張量,否則傳回 input

解析負數 (resolve_neg)

如果 input 的負數位元設為 True,則傳回具體化負數的新張量,否則傳回 input

BLAS 和 LAPACK 運算

批次矩陣加法乘法 (addbmm)

對儲存在 batch1batch2 中的矩陣執行批次矩陣乘法,並減少加法步驟(所有矩陣乘法沿第一個維度累加)。

矩陣加法乘法 (addmm)

對矩陣 mat1mat2 執行矩陣乘法。

矩陣向量加法乘積 (addmv)

對矩陣 mat 和向量 vec 執行矩陣向量乘積。

向量外積加法 (addr)

執行向量 vec1vec2 的外積,並將其加到矩陣 input

批次矩陣乘法 (baddbmm)

batch1batch2 中的矩陣執行批次矩陣乘法。

批次矩陣乘法 (bmm)

對儲存在 inputmat2 中的矩陣執行批次矩陣乘法。

鏈式矩陣乘法 (chain_matmul)

傳回 NN 個二維張量的矩陣乘積。

Cholesky 分解 (cholesky)

計算對稱正定矩陣 AA 或對稱正定矩陣批次的 Cholesky 分解。

Cholesky 逆矩陣 (cholesky_inverse)

計算複數 Hermitian 或實數對稱正定矩陣在其 Cholesky 分解下的逆矩陣。

Cholesky 求解 (cholesky_solve)

計算具有複數 Hermitian 或實數對稱正定 lhs 的線性方程組的解,給定其 Cholesky 分解。

點積 (dot)

計算兩個一維張量的點積。

QR 分解 (geqrf)

這是直接呼叫 LAPACK 的 geqrf 的低階函數。

ger

torch.outer() 的別名。

內積 (inner)

計算一維張量的點積。

逆矩陣 (inverse)

torch.linalg.inv() 的別名

行列式 (det)

torch.linalg.det() 的別名

對數行列式 (logdet)

計算方陣或方陣批次的對數行列式。

符號與對數行列式 (slogdet)

torch.linalg.slogdet() 的別名

LU 分解 (lu)

計算矩陣或矩陣批次 A 的 LU 分解。

LU 求解 (lu_solve)

使用 lu_factor() 中 A 的部分樞軸 LU 分解,傳回線性系統 Ax=bAx = b 的 LU 解。

LU 解壓縮 (lu_unpack)

lu_factor() 傳回的 LU 分解解壓縮為 P、L、U 矩陣。

矩陣乘法 (matmul)

兩個張量的矩陣乘積。

矩陣冪 (matrix_power)

torch.linalg.matrix_power() 的別名

矩陣指數 (matrix_exp)

torch.linalg.matrix_exp() 的別名。

矩陣乘法 (mm)

對矩陣 inputmat2 執行矩陣乘法。

矩陣向量乘積 (mv)

對矩陣 input 和向量 vec 執行矩陣向量乘積。

orgqr

torch.linalg.householder_product() 的別名。

ormqr

計算 Householder 矩陣乘積與一般矩陣的矩陣乘法。

外積 (outer)

inputvec2 的外積。

偽逆矩陣 (pinverse)

torch.linalg.pinv() 的別名

QR 分解 (qr)

計算矩陣或矩陣批次 input 的 QR 分解,並傳回張量的命名元組 (Q, R),使得 input=QR\text{input} = Q R,其中 QQ 是正交矩陣或正交矩陣批次,RR 是上三角矩陣或上三角矩陣批次。

奇異值分解 (svd)

計算矩陣或矩陣批次 input 的奇異值分解。

低秩奇異值分解 (svd_lowrank)

傳回矩陣、矩陣批次或稀疏矩陣 AA 的奇異值分解 (U, S, V),使得 AUdiag(S)VHA \approx U \operatorname{diag}(S) V^{\text{H}}

pca_lowrank

對低秩矩陣、此類矩陣的批次或稀疏矩陣執行線性主成分分析 (PCA)。

lobpcg

使用無矩陣 LOBPCG 方法,找出對稱正定廣義特徵值問題的 k 個最大(或最小)特徵值和對應的特徵向量。

trapz

torch.trapezoid() 的別名。

trapezoid

沿著 dim 計算 梯形法則

cumulative_trapezoid

沿著 dim 累積計算 梯形法則

triangular_solve

使用方形上三角或下三角可逆矩陣 AA 和多個右側向量 bb 解決方程式系統。

vdot

沿著維度計算兩個一維向量的點積。

Foreach 操作

警告

此 API 處於測試階段,未來可能會有所變更。不支援正向模式 AD。

_foreach_abs

torch.abs() 套用至輸入清單的每個張量。

_foreach_abs_

torch.abs() 套用至輸入清單的每個張量。

_foreach_acos

torch.acos() 套用至輸入清單的每個張量。

_foreach_acos_

torch.acos() 套用至輸入清單的每個張量。

_foreach_asin

torch.asin() 套用至輸入清單的每個張量。

_foreach_asin_

torch.asin() 套用至輸入清單的每個張量。

_foreach_atan

torch.atan() 套用至輸入清單的每個張量。

_foreach_atan_

torch.atan() 套用至輸入清單的每個張量。

_foreach_ceil

torch.ceil() 套用至輸入清單的每個張量。

_foreach_ceil_

torch.ceil() 套用至輸入清單的每個張量。

_foreach_cos

torch.cos() 套用至輸入清單的每個張量。

_foreach_cos_

torch.cos() 套用至輸入清單的每個張量。

_foreach_cosh

torch.cosh() 套用至輸入清單的每個張量。

_foreach_cosh_

torch.cosh() 套用至輸入清單的每個張量。

_foreach_erf

torch.erf() 套用至輸入清單的每個張量。

_foreach_erf_

torch.erf() 套用至輸入清單的每個張量。

_foreach_erfc

torch.erfc() 套用至輸入清單的每個張量。

_foreach_erfc_

torch.erfc() 套用至輸入清單的每個張量。

_foreach_exp

torch.exp() 套用至輸入清單的每個張量。

_foreach_exp_

torch.exp() 套用至輸入清單的每個張量。

_foreach_expm1

torch.expm1() 套用至輸入清單的每個張量。

_foreach_expm1_

torch.expm1() 套用至輸入清單的每個張量。

_foreach_floor

torch.floor() 套用至輸入清單的每個張量。

_foreach_floor_

torch.floor() 套用至輸入清單的每個張量。

_foreach_log

torch.log() 套用至輸入清單的每個張量。

_foreach_log_

torch.log() 套用至輸入清單的每個張量。

_foreach_log10

torch.log10() 套用至輸入清單的每個張量。

_foreach_log10_

torch.log10() 套用至輸入清單的每個張量。

_foreach_log1p

torch.log1p() 套用至輸入清單的每個張量。

_foreach_log1p_

torch.log1p() 套用至輸入清單的每個張量。

_foreach_log2

torch.log2() 套用至輸入清單的每個張量。

_foreach_log2_

torch.log2() 套用至輸入清單的每個張量。

_foreach_neg

torch.neg() 套用至輸入清單的每個張量。

_foreach_neg_

torch.neg() 套用至輸入清單的每個張量。

_foreach_tan

torch.tan() 套用至輸入清單的每個張量。

_foreach_tan_

torch.tan() 套用至輸入清單的每個張量。

_foreach_sin

torch.sin() 套用至輸入清單的每個張量。

_foreach_sin_

torch.sin() 套用至輸入清單的每個張量。

_foreach_sinh

torch.sinh() 套用至輸入清單的每個張量。

_foreach_sinh_

torch.sinh() 套用至輸入清單的每個張量。

_foreach_round

torch.round() 套用至輸入清單的每個張量。

_foreach_round_

torch.round() 套用至輸入清單的每個張量。

_foreach_sqrt

torch.sqrt() 套用至輸入清單的每個張量。

_foreach_sqrt_

torch.sqrt() 套用至輸入清單的每個張量。

_foreach_lgamma

torch.lgamma() 套用至輸入清單的每個張量。

_foreach_lgamma_

torch.lgamma() 套用至輸入清單的每個張量。

_foreach_frac

torch.frac() 套用至輸入清單的每個張量。

_foreach_frac_

torch.frac() 套用至輸入清單的每個張量。

_foreach_reciprocal

torch.reciprocal() 套用至輸入清單的每個張量。

_foreach_reciprocal_

torch.reciprocal() 套用至輸入清單的每個張量。

_foreach_sigmoid

torch.sigmoid() 套用至輸入清單的每個張量。

_foreach_sigmoid_

torch.sigmoid() 套用至輸入清單的每個張量。

_foreach_trunc

torch.trunc() 套用至輸入清單的每個張量。

_foreach_trunc_

torch.trunc() 套用至輸入清單的每個張量。

_foreach_zero_

torch.zero() 套用至輸入清單的每個張量。

公用程式

compiled_with_cxx11_abi

傳回 PyTorch 是否使用 _GLIBCXX_USE_CXX11_ABI=1 建置

result_type

傳回對提供的輸入張量執行算術運算後會產生的 torch.dtype

can_cast

判斷在類型提升 文件 中說明的 PyTorch 轉換規則下,是否允許類型轉換。

promote_types

傳回大小和純量種類最小,且不小於或種類低於 type1type2torch.dtype

use_deterministic_algorithms

設定 PyTorch 操作是否必須使用「確定性」演算法。

are_deterministic_algorithms_enabled

如果已開啟全域確定性旗標,則傳回 True。

is_deterministic_algorithms_warn_only_enabled

如果全域確定性旗標設定為僅警告,則傳回 True。

set_deterministic_debug_mode

設定確定性操作的除錯模式。

get_deterministic_debug_mode

傳回確定性操作除錯模式的目前值。

set_float32_matmul_precision

設定 float32 矩陣乘法的內部精度。

get_float32_matmul_precision

傳回 float32 矩陣乘法精度的目前值。

set_warn_always

如果此旗標為 False(預設值),則某些 PyTorch 警告可能每個處理程序只會出現一次。

get_device_module

返回與給定裝置相關聯的模組(例如,torch.device('cuda')、"mtia:0"、"xpu" 等)。

is_warn_always_enabled

如果全域 warn_always 旗標已開啟,則返回 True。

vmap

vmap 是向量化映射;vmap(func) 返回一個新函數,該函數將 func 映射到輸入的某個維度。

_assert

Python assert 的包裝器,可進行符號追蹤。

符號數

類別 torch.SymInt(node)[原始碼]

類似於 int(包括魔術方法),但會將所有操作重定向到包裝的節點。這特別用於在符號形狀工作流程中符號記錄操作。

類別 torch.SymFloat(node)[原始碼]

類似於 float(包括魔術方法),但會將所有操作重定向到包裝的節點。這特別用於在符號形狀工作流程中符號記錄操作。

is_integer()[原始碼]

如果浮點數是整數,則返回 True。

類別 torch.SymBool(node)[原始碼]

類似於 bool(包括魔術方法),但會將所有操作重定向到包裝的節點。這特別用於在符號形狀工作流程中符號記錄操作。

與常規布林值不同,常規布林運算符會強制執行額外的防護,而不是符號求值。請改用位元運算符來處理此問題。

sym_float

SymInt 感知的浮點數轉換工具。

sym_int

SymInt 感知的整數轉換工具。

sym_max

SymInt 感知的最大值工具,可避免在 a < b 時分支。

sym_min

SymInt 感知的最小值工具。

sym_not

SymInt 感知的邏輯非工具。

sym_ite

匯出路徑

警告

此功能為原型,未來可能會發生不相容的變更。

匯出 generated/exportdb/index

控制流程

警告

此功能為原型,未來可能會發生不相容的變更。

cond

有條件地套用 true_fnfalse_fn

優化

compile

使用 TorchDynamo 和指定的後端優化給定的模型/函數。

torch.compile 文件

運算符標籤

類別 torch.Tag

成員

core

data_dependent_output

dynamic_output_shape

generated

inplace_view

needs_fixed_stride_order

nondeterministic_bitwise

nondeterministic_seeded

pointwise

pt2_compliant_tag

view_copy

屬性 name

文件

取得 PyTorch 的完整開發人員文件

查看文件

教學課程

取得適合初學者和進階開發人員的深入教學課程

查看教學課程

資源

尋找開發資源並取得問題解答

查看資源