• 文件 >
  • torch.nn.functional
快捷方式

torch.nn.functional

卷積函式

conv1d

在由多個輸入平面組成的輸入訊號上應用一維卷積。

conv2d

在由多個輸入平面組成的輸入影像上應用二維卷積。

conv3d

在由多個輸入平面組成的輸入影像上應用三維卷積。

conv_transpose1d

在由多個輸入平面組成的輸入訊號上應用一維轉置卷積運算元,有時也稱為“反捲積”。

conv_transpose2d

在由多個輸入平面組成的輸入影像上應用二維轉置卷積運算元,有時也稱為“反捲積”。

conv_transpose3d

在由多個輸入平面組成的輸入影像上應用三維轉置卷積運算元,有時也稱為“反捲積”

unfold

從批次輸入的張量中提取滑動區域性塊。

fold

將滑動區域性塊陣列組合成一個大的包含張量。

池化函式

avg_pool1d

在由多個輸入平面組成的輸入訊號上應用一維平均池化。

avg_pool2d

kH×kWkH \times kW 區域上應用二維平均池化操作,步長為 sH×sWsH \times sW

avg_pool3d

kT×kH×kWkT \times kH \times kW 區域上應用三維平均池化操作,步長為 sT×sH×sWsT \times sH \times sW

max_pool1d

在由多個輸入平面組成的輸入訊號上應用一維最大池化。

max_pool2d

在由多個輸入平面組成的輸入訊號上應用二維最大池化。

max_pool3d

在由多個輸入平面組成的輸入訊號上應用三維最大池化。

max_unpool1d

計算 MaxPool1d 的部分逆運算。

max_unpool2d

計算 MaxPool2d 的部分逆運算。

max_unpool3d

計算 MaxPool3d 的部分逆運算。

lp_pool1d

在由多個輸入平面組成的輸入訊號上應用一維冪平均池化。

lp_pool2d

在由多個輸入平面組成的輸入訊號上應用二維冪平均池化。

lp_pool3d

在由多個輸入平面組成的輸入訊號上應用三維冪平均池化。

adaptive_max_pool1d

在由多個輸入平面組成的輸入訊號上應用一維自適應最大池化。

adaptive_max_pool2d

在由多個輸入平面組成的輸入訊號上應用二維自適應最大池化。

adaptive_max_pool3d

在由多個輸入平面組成的輸入訊號上應用三維自適應最大池化。

adaptive_avg_pool1d

在由多個輸入平面組成的輸入訊號上應用一維自適應平均池化。

adaptive_avg_pool2d

在由多個輸入平面組成的輸入訊號上應用二維自適應平均池化。

adaptive_avg_pool3d

在由多個輸入平面組成的輸入訊號上應用三維自適應平均池化。

fractional_max_pool2d

在由多個輸入平面組成的輸入訊號上應用二維分數最大池化。

fractional_max_pool3d

在由多個輸入平面組成的輸入訊號上應用三維分數最大池化。

注意力機制

模組 torch.nn.attention.bias 包含設計用於 scaled_dot_product_attention 的注意力偏置項(attention_biases)。

scaled_dot_product_attention

scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,

非線性啟用函式

threshold

對輸入張量的每個元素應用閾值。

threshold_

threshold() 的原地(in-place)版本。

relu

逐元素應用修正線性單元函式。

relu_

relu() 的原地(in-place)版本。

hardtanh

逐元素應用 HardTanh 函式。

hardtanh_

hardtanh() 的原地(in-place)版本。

hardswish

逐元素應用 hardswish 函式。

relu6

逐元素應用函式 ReLU6(x)=min(,x),6)\text{ReLU6}(x) = \min(\max(0,x), 6)

elu

逐元素應用指數線性單元 (ELU) 函式。

elu_

elu() 的原地(in-place)版本。

selu

逐元素應用函式 SELU(x)=scale(max(0,x)+min(0,α(exp(x)1)))\text{SELU}(x) = scale * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1))),其中 α=1.6732632423543772848170429916717\alpha=1.6732632423543772848170429916717scale=1.0507009873554804934193349852946scale=1.0507009873554804934193349852946

celu

逐元素應用函式 CELU(x)=max(0,x)+min(0,α(exp(x/α)1))\text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))

leaky_relu

逐元素應用函式 LeakyReLU(x)=max(0,x)+negative_slopemin(0,x)\text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)

leaky_relu_

leaky_relu() 的原地(in-place)版本。

prelu

逐元素應用函式 PReLU(x)=max(0,x)+weightmin(0,x)\text{PReLU}(x) = \max(0,x) + \text{weight} * \min(0,x),其中 weight 是一個可學習的引數。

rrelu

隨機化 Leaky ReLU。

rrelu_

rrelu() 的原地(in-place)版本。

glu

門控線性單元。

gelu

當 approximate 引數為 'none' 時,逐元素應用函式 GELU(x)=xΦ(x)\text{GELU}(x) = x * \Phi(x)

logsigmoid

逐元素應用函式 LogSigmoid(xi)=log(11+exp(xi))\text{LogSigmoid}(x_i) = \log \left(\frac{1}{1 + \exp(-x_i)}\right)

hardshrink

逐元素應用硬收縮函式

tanhshrink

逐元素應用函式 Tanhshrink(x)=xTanh(x)\text{Tanhshrink}(x) = x - \text{Tanh}(x)

softsign

逐元素應用函式 SoftSign(x)=x1+x\text{SoftSign}(x) = \frac{x}{1 + |x|}

softplus

逐元素應用函式 Softplus(x)=1βlog(1+exp(βx))\text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))

softmin

應用 softmin 函式。

softmax

應用 softmax 函式。

softshrink

逐元素應用 soft shrinkage 函式

gumbel_softmax

從 Gumbel-Softmax 分佈中取樣(連結 1 連結 2),並可選擇進行離散化。

log_softmax

應用 softmax 後接對數函式。

tanh

逐元素應用,Tanh(x)=tanh(x)=exp(x)exp(x)exp(x)+exp(x)\text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}

sigmoid

逐元素應用函式 Sigmoid(x)=11+exp(x)\text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}

hardsigmoid

逐元素應用 Hardsigmoid 函式。

silu

逐元素應用 Sigmoid Linear Unit (SiLU) 函式。

mish

逐元素應用 Mish 函式。

batch_norm

對一批資料中的每個通道應用批次歸一化。

group_norm

對最後若干個維度應用組歸一化。

instance_norm

對一批資料中每個資料樣本的每個通道獨立應用例項歸一化。

layer_norm

對最後若干個維度應用層歸一化。

local_response_norm

對輸入訊號應用區域性響應歸一化。

rms_norm

應用均方根層歸一化。

normalize

對輸入沿指定維度執行 LpL_p 歸一化。

線性函式

linear

對傳入資料應用線性變換:y=xAT+by = xA^T + b

bilinear

對傳入資料應用雙線性變換:y=x1TAx2+by = x_1^T A x_2 + b

Dropout 函式

dropout

在訓練期間,以機率 p 隨機將輸入張量的某些元素置零。

alpha_dropout

對輸入應用 alpha dropout。

feature_alpha_dropout

隨機遮蔽整個通道(一個通道是一個特徵圖)。

dropout1d

隨機將整個通道置零(一個通道是一個一維特徵圖)。

dropout2d

隨機將整個通道置零(一個通道是一個二維特徵圖)。

dropout3d

隨機將整個通道置零(一個通道是一個三維特徵圖)。

稀疏函式

embedding

生成一個簡單的查詢表,用於在固定字典和大小中查詢嵌入。

embedding_bag

計算嵌入“袋”的總和、均值或最大值。

one_hot

接受形狀為 (*) 的 LongTensor 索引值,並返回形狀為 (*, num_classes) 的張量,該張量除最後一維索引與其對應的輸入張量值匹配的位置為 1 外,其餘位置均為零。

距離函式

pairwise_distance

詳細資訊請參閱 torch.nn.PairwiseDistance

cosine_similarity

返回 x1x2 之間的餘弦相似度,沿 dim 計算。

pdist

計算輸入中每對行向量之間的 p-範數距離。

損失函式

binary_cross_entropy

計算目標與輸入機率之間的二元交叉熵。

binary_cross_entropy_with_logits

計算目標與輸入 logits 之間的二元交叉熵。

poisson_nll_loss

泊松負對數似然損失。

cosine_embedding_loss

詳細資訊請參閱 CosineEmbeddingLoss

cross_entropy

計算輸入 logits 與目標之間的交叉熵損失。

ctc_loss

應用聯結主義時間分類 (CTC) 損失。

gaussian_nll_loss

高斯負對數似然損失。

hinge_embedding_loss

詳細資訊請參閱 HingeEmbeddingLoss

kl_div

計算 KL 散度損失。

l1_loss

計算逐元素絕對值差的平均值的函式。

mse_loss

計算逐元素均方誤差,支援可選加權。

margin_ranking_loss

詳細資訊請參閱 MarginRankingLoss

multilabel_margin_loss

詳細資訊請參閱 MultiLabelMarginLoss

multilabel_soft_margin_loss

詳細資訊請參閱 MultiLabelSoftMarginLoss

multi_margin_loss

詳細資訊請參閱 MultiMarginLoss

nll_loss

計算負對數似然損失。

huber_loss

計算 Huber 損失,支援可選加權。

smooth_l1_loss

計算平滑 L1 損失。

soft_margin_loss

詳細資訊請參閱 SoftMarginLoss

triplet_margin_loss

計算給定輸入張量與大於 0 的 margin 之間的 triplet 損失。

triplet_margin_with_distance_loss

使用自定義距離函式計算輸入張量的 triplet margin 損失。

視覺函式

pixel_shuffle

將形狀為 (,C×r2,H,W)(*, C \times r^2, H, W) 的張量中的元素重新排列成形狀為 (,C,H×r,W×r)(*, C, H \times r, W \times r) 的張量,其中 r 是 upscale_factor

pixel_unshuffle

透過將形狀為 (,C,H×r,W×r)(*, C, H \times r, W \times r) 的張量中的元素重新排列成形狀為 (,C×r2,H,W)(*, C \times r^2, H, W) 的張量,來反轉 PixelShuffle 操作,其中 r 是 downscale_factor

pad

填充張量。

interpolate

對輸入進行下/上取樣。

upsample

對輸入進行上取樣。

upsample_nearest

使用最近鄰畫素值對輸入進行上取樣。

upsample_bilinear

使用雙線性上取樣對輸入進行上取樣。

grid_sample

計算網格取樣。

affine_grid

給定一批仿射矩陣 theta,生成二維或三維流場(取樣網格)。

資料並行函式(多 GPU,分散式)

data_parallel

torch.nn.parallel.data_parallel

在 device_ids 中指定的多個 GPU 上並行評估模組(輸入)。

文件

訪問 PyTorch 全面的開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源