快捷方式

torch.nn

它們是構建圖的基本模組

Buffer

一種不應被視為模型引數的 Tensor。

Parameter

一種應被視為模組引數的 Tensor。

UninitializedParameter

一個未初始化的引數。

UninitializedBuffer

一個未初始化的 buffer。

容器

Module

所有神經網路模組的基類。

Sequential

一個序列容器。

ModuleList

以列表形式儲存子模組。

ModuleDict

以字典形式儲存子模組。

ParameterList

以列表形式儲存引數。

ParameterDict

以字典形式儲存引數。

模組的全域性鉤子

register_module_forward_pre_hook

註冊一個適用於所有模組的前向預鉤子。

register_module_forward_hook

註冊一個適用於所有模組的全域性前向鉤子。

register_module_backward_hook

註冊一個適用於所有模組的後向鉤子。

register_module_full_backward_pre_hook

註冊一個適用於所有模組的完整後向預鉤子。

register_module_full_backward_hook

註冊一個適用於所有模組的後向鉤子。

register_module_buffer_registration_hook

註冊一個適用於所有模組的 buffer 註冊鉤子。

register_module_module_registration_hook

註冊一個適用於所有模組的 module 註冊鉤子。

register_module_parameter_registration_hook

註冊一個適用於所有模組的 parameter 註冊鉤子。

卷積層

nn.Conv1d

對由若干輸入平面組成的輸入訊號應用 1D 卷積。

nn.Conv2d

對由若干輸入平面組成的輸入訊號應用 2D 卷積。

nn.Conv3d

對由若干輸入平面組成的輸入訊號應用 3D 卷積。

nn.ConvTranspose1d

對由若干輸入平面組成的輸入影像應用 1D 轉置卷積運算元。

nn.ConvTranspose2d

對由若干輸入平面組成的輸入影像應用 2D 轉置卷積運算元。

nn.ConvTranspose3d

對由若干輸入平面組成的輸入影像應用 3D 轉置卷積運算元。

nn.LazyConv1d

一個 torch.nn.Conv1d 模組,其 in_channels 引數支援延遲初始化。

nn.LazyConv2d

一個 torch.nn.Conv2d 模組,其 in_channels 引數支援延遲初始化。

nn.LazyConv3d

一個 torch.nn.Conv3d 模組,其 in_channels 引數支援延遲初始化。

nn.LazyConvTranspose1d

一個 torch.nn.ConvTranspose1d 模組,其 in_channels 引數支援延遲初始化。

nn.LazyConvTranspose2d

一個 torch.nn.ConvTranspose2d 模組,其 in_channels 引數支援延遲初始化。

nn.LazyConvTranspose3d

一個 torch.nn.ConvTranspose3d 模組,其 in_channels 引數支援延遲初始化。

nn.Unfold

從批次輸入張量中提取滑動的區域性塊。

nn.Fold

將滑動的區域性塊陣列組合成一個大的包含張量。

池化層

nn.MaxPool1d

對由若干輸入平面組成的輸入訊號應用 1D 最大池化。

nn.MaxPool2d

對由若干輸入平面組成的輸入訊號應用 2D 最大池化。

nn.MaxPool3d

對由若干輸入平面組成的輸入訊號應用 3D 最大池化。

nn.MaxUnpool1d

計算 MaxPool1d 的部分逆運算。

nn.MaxUnpool2d

計算 MaxPool2d 的部分逆運算。

nn.MaxUnpool3d

計算 MaxPool3d 的部分逆運算。

nn.AvgPool1d

對由若干輸入平面組成的輸入訊號應用 1D 平均池化。

nn.AvgPool2d

對由若干輸入平面組成的輸入訊號應用 2D 平均池化。

nn.AvgPool3d

對由若干輸入平面組成的輸入訊號應用 3D 平均池化。

nn.FractionalMaxPool2d

對由若干輸入平面組成的輸入訊號應用 2D 分數最大池化。

nn.FractionalMaxPool3d

對由若干輸入平面組成的輸入訊號應用 3D 分數最大池化。

nn.LPPool1d

對由若干輸入平面組成的輸入訊號應用 1D 冪平均池化。

nn.LPPool2d

對由若干輸入平面組成的輸入訊號應用 2D 冪平均池化。

nn.LPPool3d

對由若干輸入平面組成的輸入訊號應用 3D 冪平均池化。

nn.AdaptiveMaxPool1d

對由若干輸入平面組成的輸入訊號應用 1D 自適應最大池化。

nn.AdaptiveMaxPool2d

對由若干輸入平面組成的輸入訊號應用 2D 自適應最大池化。

nn.AdaptiveMaxPool3d

對由若干輸入平面組成的輸入訊號應用 3D 自適應最大池化。

nn.AdaptiveAvgPool1d

對由若干輸入平面組成的輸入訊號應用 1D 自適應平均池化。

nn.AdaptiveAvgPool2d

對由若干輸入平面組成的輸入訊號應用 2D 自適應平均池化。

nn.AdaptiveAvgPool3d

對由若干輸入平面組成的輸入訊號應用 3D 自適應平均池化。

填充層

nn.ReflectionPad1d

使用輸入邊界的反射來填充輸入張量。

nn.ReflectionPad2d

使用輸入邊界的反射來填充輸入張量。

nn.ReflectionPad3d

使用輸入邊界的反射來填充輸入張量。

nn.ReplicationPad1d

使用輸入邊界的複製來填充輸入張量。

nn.ReplicationPad2d

使用輸入邊界的複製來填充輸入張量。

nn.ReplicationPad3d

使用輸入邊界的複製來填充輸入張量。

nn.ZeroPad1d

使用零填充輸入張量的邊界。

nn.ZeroPad2d

使用零填充輸入張量的邊界。

nn.ZeroPad3d

使用零填充輸入張量的邊界。

nn.ConstantPad1d

使用常量值填充輸入張量的邊界。

nn.ConstantPad2d

使用常量值填充輸入張量的邊界。

nn.ConstantPad3d

使用常量值填充輸入張量的邊界。

nn.CircularPad1d

使用輸入邊界的迴圈填充來填充輸入張量。

nn.CircularPad2d

使用輸入邊界的迴圈填充來填充輸入張量。

nn.CircularPad3d

使用輸入邊界的迴圈填充來填充輸入張量。

非線性啟用函式 (加權和相關)

nn.ELU

逐元素應用 Exponential Linear Unit (ELU) 函式。

nn.Hardshrink

逐元素應用 Hard Shrinkage (Hardshrink) 函式。

nn.Hardsigmoid

逐元素應用 Hardsigmoid 函式。

nn.Hardtanh

逐元素應用 HardTanh 函式。

nn.Hardswish

逐元素應用 Hardswish 函式。

nn.LeakyReLU

逐元素應用 LeakyReLU 函式。

nn.LogSigmoid

逐元素應用 Logsigmoid 函式。

nn.MultiheadAttention

允許模型共同關注來自不同表示子空間的資訊。

nn.PReLU

逐元素應用 PReLU 函式。

nn.ReLU

逐元素應用整流線性單元函式。

nn.ReLU6

逐元素應用 ReLU6 函式。

nn.RReLU

逐元素應用隨機 leaky 整流線性單元函式。

nn.SELU

逐元素應用 SELU 函式。

nn.CELU

逐元素應用 CELU 函式。

nn.GELU

應用高斯誤差線性單元函式。

nn.Sigmoid

逐元素應用 Sigmoid 函式。

nn.SiLU

逐元素應用 Sigmoid Linear Unit (SiLU) 函式。

nn.Mish

逐元素應用 Mish 函式。

nn.Softplus

逐元素應用 Softplus 函式。

nn.Softshrink

逐元素應用 soft shrinkage 函式。

nn.Softsign

逐元素應用 Softsign 函式。

nn.Tanh

逐元素應用 Hyperbolic Tangent (Tanh) 函式。

nn.Tanhshrink

逐元素應用 Tanhshrink 函式。

nn.Threshold

對輸入 Tensor 的每個元素進行閾值處理。

nn.GLU

應用門控線性單元函式。

非線性啟用函式 (其他)

nn.Softmin

對 n 維輸入 Tensor 應用 Softmin 函式。

nn.Softmax

對 n 維輸入 Tensor 應用 Softmax 函式。

nn.Softmax2d

對每個空間位置的特徵應用 SoftMax。

nn.LogSoftmax

對 n 維輸入 Tensor 應用 log(Softmax(x))\log(\text{Softmax}(x)) 函式。

nn.AdaptiveLogSoftmaxWithLoss

高效的 softmax 近似。

歸一化層

nn.BatchNorm1d

對 2D 或 3D 輸入應用批歸一化。

nn.BatchNorm2d

對 4D 輸入應用批歸一化。

nn.BatchNorm3d

對 5D 輸入應用批歸一化。

nn.LazyBatchNorm1d

一個支援延遲初始化的 torch.nn.BatchNorm1d 模組。

nn.LazyBatchNorm2d

一個支援延遲初始化的 torch.nn.BatchNorm2d 模組。

nn.LazyBatchNorm3d

一個支援延遲初始化的 torch.nn.BatchNorm3d 模組。

nn.GroupNorm

對 mini-batch 輸入應用組歸一化。

nn.SyncBatchNorm

對 N 維輸入應用批歸一化。

nn.InstanceNorm1d

應用例項歸一化。

nn.InstanceNorm2d

應用例項歸一化。

nn.InstanceNorm3d

應用例項歸一化。

nn.LazyInstanceNorm1d

一個 torch.nn.InstanceNorm1d 模組,其 num_features 引數支援延遲初始化。

nn.LazyInstanceNorm2d

一個 torch.nn.InstanceNorm2d 模組,其 num_features 引數支援延遲初始化。

nn.LazyInstanceNorm3d

一個 torch.nn.InstanceNorm3d 模組,其 num_features 引數支援延遲初始化。

nn.LayerNorm

對 mini-batch 輸入應用層歸一化。

nn.LocalResponseNorm

對輸入訊號應用區域性響應歸一化。

nn.RMSNorm

對 mini-batch 輸入應用均方根層歸一化。

迴圈層

nn.RNNBase

RNN 模組 (RNN, LSTM, GRU) 的基類。

nn.RNN

對輸入序列應用帶有 tanh\tanhReLU\text{ReLU} 非線性的多層 Elman RNN。

nn.LSTM

對輸入序列應用多層長短期記憶 (LSTM) RNN。

nn.GRU

對輸入序列應用多層門控迴圈單元 (GRU) RNN。

nn.RNNCell

一個帶有 tanh 或 ReLU 非線性的 Elman RNN cell。

nn.LSTMCell

一個長短期記憶 (LSTM) cell。

nn.GRUCell

一個門控迴圈單元 (GRU) cell。

Transformer 層

nn.Transformer

一個 transformer 模型。

nn.TransformerEncoder

TransformerEncoder 是 N 個編碼器層的堆疊。

nn.TransformerDecoder

TransformerDecoder 是 N 個解碼器層的堆疊。

nn.TransformerEncoderLayer

TransformerEncoderLayer 由自注意力和前饋網路組成。

nn.TransformerDecoderLayer

TransformerDecoderLayer 由自注意力、多頭注意力和前饋網路組成。

線性層

nn.Identity

一個佔位符恆等運算元,對引數不敏感。

nn.Linear

對輸入資料應用仿射線性變換:y=xAT+by = xA^T + b

nn.Bilinear

對輸入資料應用雙線性變換:y=x1TAx2+by = x_1^T A x_2 + b

nn.LazyLinear

一個 torch.nn.Linear 模組,其 in_features 引數是推斷的。

Dropout 層

nn.Dropout

在訓練期間,以機率 p 隨機將輸入張量的一些元素歸零。

nn.Dropout1d

隨機將整個通道歸零。

nn.Dropout2d

隨機將整個通道歸零。

nn.Dropout3d

隨機將整個通道歸零。

nn.AlphaDropout

對輸入應用 Alpha Dropout。

nn.FeatureAlphaDropout

隨機遮蔽整個通道。

稀疏層

nn.Embedding

一個簡單的查詢表,用於儲存固定詞典和大小的嵌入。

nn.EmbeddingBag

計算嵌入“包”的總和或平均值,無需例項化中間嵌入。

距離函式

nn.CosineSimilarity

返回 x1x_1x2x_2 之間的餘弦相似度,沿 dim 計算。

nn.PairwiseDistance

計算輸入向量之間或輸入矩陣列之間的成對距離。

損失函式

nn.L1Loss

建立一個準則,用於衡量輸入 xx 和目標 yy 中每個元素之間的平均絕對誤差 (MAE)。

nn.MSELoss

建立一個準則,用於衡量輸入 xx 和目標 yy 中每個元素之間的均方誤差(平方 L2 範數)。

nn.CrossEntropyLoss

此準則計算輸入 logit 和目標之間的交叉熵損失。

nn.CTCLoss

連線主義時間分類損失。

nn.NLLLoss

負對數似然損失。

nn.PoissonNLLLoss

目標服從泊松分佈的負對數似然損失。

nn.GaussianNLLLoss

高斯負對數似然損失。

nn.KLDivLoss

Kullback-Leibler 散度損失。

nn.BCELoss

建立一個準則,用於衡量目標和輸入機率之間的二元交叉熵。

nn.BCEWithLogitsLoss

此損失函式在一個類中結合了 Sigmoid 層和 BCELoss

nn.MarginRankingLoss

建立一個準則,用於衡量給定輸入 x1x1、$x2$(兩個 1D 小批次或 0D 張量)和標籤 yy(一個 1D 小批次或 0D 張量,包含 1 或 -1)時的損失。

nn.HingeEmbeddingLoss

衡量給定輸入張量 xx 和標籤張量 yy(包含 1 或 -1)時的損失。

nn.MultiLabelMarginLoss

建立一個準則,用於最佳化輸入 xx(一個 2D 小批次 張量)和輸出 yy(一個 2D 目標類別索引 張量)之間的多類別多分類鉸鏈損失(基於間隔的損失)。

nn.HuberLoss

建立一個準則,當逐元素的絕對誤差小於 delta 時使用平方項,否則使用 delta 放大的 L1 項。

nn.SmoothL1Loss

建立一個準則,當逐元素的絕對誤差小於 beta 時使用平方項,否則使用 L1 項。

nn.SoftMarginLoss

建立一個準則,用於最佳化輸入張量 xx 和目標張量 yy(包含 1 或 -1)之間的兩類別分類邏輯損失。

nn.MultiLabelSoftMarginLoss

建立一個準則,用於最佳化基於最大熵的多標籤一對多損失,該損失衡量輸入 xx 和大小為 (N,C)(N, C) 的目標 yy 之間的差異。

nn.CosineEmbeddingLoss

建立一個準則,用於衡量給定輸入張量 x1x_1、$x_2$ 和值為 1 或 -1 的 張量 標籤 yy 時的損失。

nn.MultiMarginLoss

建立一個準則,用於最佳化輸入 xx(一個 2D 小批次 張量)和輸出 yy(一個 1D 目標類別索引張量,$0 \leq y \leq \text{x.size}(1)-1$)之間的多類別分類鉸鏈損失(基於間隔的損失)。

nn.TripletMarginLoss

建立一個準則,用於衡量給定輸入張量 x1x1、$x2$、$x3$ 和一個大於 $0$ 的間隔(margin)時的三元組損失。

nn.TripletMarginWithDistanceLoss

建立一個準則,用於衡量給定輸入張量 aa、$p$ 和 nn(分別代表錨點、正例和負例)以及一個非負實值函式(“距離函式”)時的三元組損失,該函式用於計算錨點與正例之間的關係(“正距離”)以及錨點與負例之間的關係(“負距離”)。

視覺層

nn.PixelShuffle

根據上取樣因子重新排列張量中的元素。

nn.PixelUnshuffle

反轉 PixelShuffle 操作。

nn.Upsample

對給定的多通道 1D(時間)、2D(空間)或 3D(體積)資料進行上取樣。

nn.UpsamplingNearest2d

對由多個輸入通道組成的輸入訊號應用 2D 最近鄰上取樣。

nn.UpsamplingBilinear2d

對由多個輸入通道組成的輸入訊號應用 2D 雙線性上取樣。

混洗層

nn.ChannelShuffle

劃分並重新排列張量中的通道。

資料並行層(多 GPU,分散式)

nn.DataParallel

在模組級別實現資料並行。

nn.parallel.DistributedDataParallel

在模組級別基於 torch.distributed 實現分散式資料並行。

工具函式

來自 torch.nn.utils 模組

用於裁剪引數梯度的工具函式。

clip_grad_norm_

裁剪可迭代引數的梯度範數。

clip_grad_norm

裁剪可迭代引數的梯度範數。

clip_grad_value_

按指定值裁剪可迭代引數的梯度。

get_total_norm

計算可迭代張量的範數。

clip_grads_with_norm_

根據預計算的總範數和期望的最大範數,縮放可迭代引數的梯度。

用於將模組引數展平為單個向量以及從單個向量還原的工具函式。

parameters_to_vector

將可迭代引數展平為單個向量。

vector_to_parameters

將向量的切片複製到可迭代引數中。

用於融合 Module 和 BatchNorm 模組的工具函式。

fuse_conv_bn_eval

將卷積模組和 BatchNorm 模組融合成一個新的卷積模組。

fuse_conv_bn_weights

將卷積模組引數和 BatchNorm 模組引數融合成新的卷積模組引數。

fuse_linear_bn_eval

將線性模組和 BatchNorm 模組融合成一個新的線性模組。

fuse_linear_bn_weights

將線性模組引數和 BatchNorm 模組引數融合成新的線性模組引數。

用於轉換模組引數記憶體格式的工具函式。

convert_conv2d_weight_memory_format

nn.Conv2d.weightmemory_format 轉換為指定的 memory_format

convert_conv3d_weight_memory_format

nn.Conv3d.weightmemory_format 轉換為指定的 memory_format。此轉換遞迴應用於巢狀的 nn.Module,包括 module

用於對模組引數應用和移除權重歸一化的工具函式。

weight_norm

對給定模組中的引數應用權重歸一化。

remove_weight_norm

從模組中移除權重歸一化重引數化。

spectral_norm

對給定模組中的引數應用譜歸一化。

remove_spectral_norm

從模組中移除譜歸一化重引數化。

用於初始化模組引數的工具函式。

skip_init

給定模組類物件和 args / kwargs,例項化模組而不初始化引數 / 緩衝區。

用於修剪模組引數的工具類和函式。

prune.BasePruningMethod

建立新修剪技術的抽象基類。

prune.PruningContainer

包含一系列修剪方法用於迭代修剪的容器。

prune.Identity

不修剪任何單元但生成帶全 1 掩碼的修剪重引數化的實用修剪方法。

prune.RandomUnstructured

隨機修剪張量中(當前未修剪的)單元。

prune.L1Unstructured

透過將 L1 範數最低的單元歸零來修剪張量中(當前未修剪的)單元。

prune.RandomStructured

隨機修剪張量中整個(當前未修剪的)通道。

prune.LnStructured

根據 Ln 範數修剪張量中整個(當前未修剪的)通道。

prune.CustomFromMask

prune.identity

應用修剪重引數化而不修剪任何單元。

prune.random_unstructured

透過移除隨機的(當前未修剪的)單元來修剪張量。

prune.l1_unstructured

透過移除 L1 範數最低的單元來修剪張量。

prune.random_structured

沿指定維度移除隨機通道來修剪張量。

prune.ln_structured

沿指定維度移除 Ln 範數最低的通道來修剪張量。

prune.global_unstructured

透過應用指定的 pruning_method,全域性修剪 parameters 中對應於所有引數的張量。

prune.custom_from_mask

透過應用 mask 中預計算的掩碼,修剪 module 中名為 name 的引數對應的張量。

prune.remove

從模組中移除修剪重引數化,並從前向鉤子中移除修剪方法。

prune.is_pruned

透過查詢修剪前置鉤子檢查模組是否已修剪。

使用 torch.nn.utils.parameterize.register_parametrization() 中的新引數化功能實現的引數化(Parametrizations)。

parametrizations.orthogonal

對矩陣或一批矩陣應用正交或酉引數化。

parametrizations.weight_norm

對給定模組中的引數應用權重歸一化。

parametrizations.spectral_norm

對給定模組中的引數應用譜歸一化。

用於對現有模組中的張量進行引數化的工具函式。請注意,這些函式可用於根據一個特定的函式對給定引數或緩衝區進行引數化,該函式將輸入空間對映到引數化空間。它們並不是將物件轉換為引數的引數化。有關如何實現您自己的引數化的更多資訊,請參閱引數化教程

parametrize.register_parametrization

在模組中向張量註冊引數化。

parametrize.remove_parametrizations

移除模組中張量上的引數化。

parametrize.cached

上下文管理器,可在使用 register_parametrization() 註冊的引數化中啟用快取系統。

parametrize.is_parametrized

確定模組是否具有引數化。

parametrize.ParametrizationList

一個序列容器,用於儲存和管理引數化 torch.nn.Module 的原始引數或緩衝區。

用於以無狀態方式呼叫給定模組的工具函式。

stateless.functional_call

透過將模組引數和緩衝區替換為提供的引數和緩衝區,對模組執行函式式呼叫。

其他模組中的工具函式

nn.utils.rnn.PackedSequence

儲存打包序列的資料和 batch_sizes 列表。

nn.utils.rnn.pack_padded_sequence

打包包含可變長度填充序列的張量。

nn.utils.rnn.pad_packed_sequence

填充打包的可變長度序列批次。

nn.utils.rnn.pad_sequence

使用 padding_value 填充可變長度張量列表。

nn.utils.rnn.pack_sequence

打包可變長度張量列表。

nn.utils.rnn.unpack_sequence

將 PackedSequence 解包為可變長度張量列表。

nn.utils.rnn.unpad_sequence

將填充張量去除填充為可變長度張量列表。

nn.Flatten

將連續的維度範圍展平為一個張量。

nn.Unflatten

還原張量維度,將其擴充套件到所需的形狀。

量化函式

量化是指以低於浮點精度的位寬執行計算和儲存張量的技術。PyTorch 支援逐張量和逐通道的非對稱線性量化。要了解如何在 PyTorch 中使用量化函式,請參閱量化文件。

惰性模組初始化

nn.modules.lazy.LazyModuleMixin

用於惰性初始化引數的模組的混合類,也稱為“惰性模組”。

別名

以下是 torch.nn 中對應部分的別名。

nn.modules.normalization.RMSNorm

對 mini-batch 輸入應用均方根層歸一化。

文件

訪問 PyTorch 的完整開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源