快捷方式

torch.nn

这些是图的基本构建块

Buffer

一种不应被视为模型参数的张量。

Parameter

一种应被视为模块参数的张量。

UninitializedParameter

未初始化的参数。

UninitializedBuffer

未初始化的缓冲区。

容器

Module

所有神经网络模块的基类。

Sequential

一个顺序容器。

ModuleList

在列表中保存子模块。

ModuleDict

在字典中保存子模块。

ParameterList

在列表中保存参数。

ParameterDict

在字典中保存参数。

模块的全局钩子

register_module_forward_pre_hook

注册所有模块共有的前向预钩子。

register_module_forward_hook

为所有模块注册一个全局前向钩子。

register_module_backward_hook

注册所有模块共有的反向钩子。

register_module_full_backward_pre_hook

注册所有模块共有的反向预钩子。

register_module_full_backward_hook

注册所有模块共有的反向钩子。

register_module_buffer_registration_hook

注册所有模块共有的缓冲区注册钩子。

register_module_module_registration_hook

注册所有模块共有的模块注册钩子。

register_module_parameter_registration_hook

注册所有模块共有的参数注册钩子。

卷积层

nn.Conv1d

对由多个输入平面组成的输入信号应用一维卷积。

nn.Conv2d

对由多个输入平面组成的输入信号应用二维卷积。

nn.Conv3d

对由多个输入平面组成的输入信号应用三维卷积。

nn.ConvTranspose1d

对由多个输入平面组成的输入图像应用一维转置卷积运算符。

nn.ConvTranspose2d

对由多个输入平面组成的输入图像应用二维转置卷积运算符。

nn.ConvTranspose3d

对由多个输入平面组成的输入图像应用三维转置卷积运算符。

nn.LazyConv1d

一个 torch.nn.Conv1d 模块,其 in_channels 参数进行延迟初始化。

nn.LazyConv2d

一个 torch.nn.Conv2d 模块,其 in_channels 参数进行延迟初始化。

nn.LazyConv3d

一个 torch.nn.Conv3d 模块,其 in_channels 参数进行延迟初始化。

nn.LazyConvTranspose1d

一个 torch.nn.ConvTranspose1d 模块,其 in_channels 参数进行延迟初始化。

nn.LazyConvTranspose2d

一个 torch.nn.ConvTranspose2d 模块,其 in_channels 参数进行延迟初始化。

nn.LazyConvTranspose3d

一个 torch.nn.ConvTranspose3d 模块,其 in_channels 参数进行延迟初始化。

nn.Unfold

从批处理输入张量中提取滑动局部块。

nn.Fold

将滑动局部块数组组合成一个大型包含张量。

池化层

nn.MaxPool1d

对由多个输入平面组成的输入信号应用一维最大池化。

nn.MaxPool2d

对由多个输入平面组成的输入信号应用二维最大池化。

nn.MaxPool3d

对由多个输入平面组成的输入信号应用三维最大池化。

nn.MaxUnpool1d

计算 MaxPool1d 的部分逆。

nn.MaxUnpool2d

计算 MaxPool2d 的部分逆。

nn.MaxUnpool3d

计算 MaxPool3d 的部分逆。

nn.AvgPool1d

对由多个输入平面组成的输入信号应用一维平均池化。

nn.AvgPool2d

对由多个输入平面组成的输入信号应用二维平均池化。

nn.AvgPool3d

对由多个输入平面组成的输入信号应用三维平均池化。

nn.FractionalMaxPool2d

对由多个输入平面组成的输入信号应用二维分数最大池化。

nn.FractionalMaxPool3d

对由多个输入平面组成的输入信号应用三维分数最大池化。

nn.LPPool1d

对由多个输入平面组成的输入信号应用一维幂平均池化。

nn.LPPool2d

对由多个输入平面组成的输入信号应用二维幂平均池化。

nn.LPPool3d

对由多个输入平面组成的输入信号应用三维幂平均池化。

nn.AdaptiveMaxPool1d

对由多个输入平面组成的输入信号应用一维自适应最大池化。

nn.AdaptiveMaxPool2d

对由多个输入平面组成的输入信号应用二维自适应最大池化。

nn.AdaptiveMaxPool3d

对由多个输入平面组成的输入信号应用三维自适应最大池化。

nn.AdaptiveAvgPool1d

对由多个输入平面组成的输入信号应用一维自适应平均池化。

nn.AdaptiveAvgPool2d

对由多个输入平面组成的输入信号应用二维自适应平均池化。

nn.AdaptiveAvgPool3d

对由多个输入平面组成的输入信号应用三维自适应平均池化。

填充层

nn.ReflectionPad1d

使用输入边界反射来填充输入张量。

nn.ReflectionPad2d

使用输入边界反射来填充输入张量。

nn.ReflectionPad3d

使用输入边界反射来填充输入张量。

nn.ReplicationPad1d

使用输入边界复制来填充输入张量。

nn.ReplicationPad2d

使用输入边界复制来填充输入张量。

nn.ReplicationPad3d

使用输入边界复制来填充输入张量。

nn.ZeroPad1d

使用零填充输入张量边界。

nn.ZeroPad2d

使用零填充输入张量边界。

nn.ZeroPad3d

使用零填充输入张量边界。

nn.ConstantPad1d

使用常量值填充输入张量边界。

nn.ConstantPad2d

使用常量值填充输入张量边界。

nn.ConstantPad3d

使用常量值填充输入张量边界。

nn.CircularPad1d

使用输入边界循环填充来填充输入张量。

nn.CircularPad2d

使用输入边界循环填充来填充输入张量。

nn.CircularPad3d

使用输入边界循环填充来填充输入张量。

非线性激活(加权求和,非线性)

nn.ELU

按元素应用指数线性单元 (ELU) 函数。

nn.Hardshrink

按元素应用硬收缩 (Hardshrink) 函数。

nn.Hardsigmoid

按元素应用硬 sigmoid 函数。

nn.Hardtanh

按元素应用硬 tanh 函数。

nn.Hardswish

按元素应用 hardswish 函数。

nn.LeakyReLU

按元素应用 LeakyReLU 函数。

nn.LogSigmoid

按元素应用 logsigmoid 函数。

nn.MultiheadAttention

允许模型共同关注来自不同表示子空间的信息。

nn.PReLU

按元素应用 PReLU 函数。

nn.ReLU

按元素应用修正线性单元函数。

nn.ReLU6

按元素应用 ReLU6 函数。

nn.RReLU

按元素应用随机 leaky 修正线性单元函数。

nn.SELU

按元素应用 SELU 函数。

nn.CELU

按元素应用 CELU 函数。

nn.GELU

应用高斯误差线性单元函数。

nn.Sigmoid

按元素应用 sigmoid 函数。

nn.SiLU

按元素应用 sigmoid 线性单元 (SiLU) 函数。

nn.Mish

按元素应用 Mish 函数。

nn.Softplus

按元素应用 softplus 函数。

nn.Softshrink

按元素应用软收缩函数。

nn.Softsign

按元素应用 softsign 函数。

nn.Tanh

按元素应用双曲正切 (Tanh) 函数。

nn.Tanhshrink

按元素应用 tanhshrink 函数。

nn.Threshold

对输入张量的每个元素进行阈值处理。

nn.GLU

应用门控线性单元函数。

非线性激活(其他)

nn.Softmin

对 n 维输入张量应用 Softmin 函数。

nn.Softmax

对 n 维输入张量应用 Softmax 函数。

nn.Softmax2d

对每个空间位置的特征应用 SoftMax。

nn.LogSoftmax

对 n 维输入张量应用 log(Softmax(x))\log(\text{Softmax}(x)) 函数。

nn.AdaptiveLogSoftmaxWithLoss

高效的 Softmax 近似。

规范化层

nn.BatchNorm1d

对二维或三维输入应用批规范化。

nn.BatchNorm2d

对四维输入应用批规范化。

nn.BatchNorm3d

对五维输入应用批规范化。

nn.LazyBatchNorm1d

一个带有延迟初始化的 torch.nn.BatchNorm1d 模块。

nn.LazyBatchNorm2d

一个带有延迟初始化的 torch.nn.BatchNorm2d 模块。

nn.LazyBatchNorm3d

一个带有延迟初始化的 torch.nn.BatchNorm3d 模块。

nn.GroupNorm

对输入的小批量应用组规范化。

nn.SyncBatchNorm

对 N 维输入应用批规范化。

nn.InstanceNorm1d

应用实例规范化。

nn.InstanceNorm2d

应用实例规范化。

nn.InstanceNorm3d

应用实例规范化。

nn.LazyInstanceNorm1d

一个带有 num_features 参数延迟初始化的 torch.nn.InstanceNorm1d 模块。

nn.LazyInstanceNorm2d

一个带有 num_features 参数延迟初始化的 torch.nn.InstanceNorm2d 模块。

nn.LazyInstanceNorm3d

一个带有 num_features 参数延迟初始化的 torch.nn.InstanceNorm3d 模块。

nn.LayerNorm

对输入的小批量应用层规范化。

nn.LocalResponseNorm

对输入信号应用局部响应规范化。

nn.RMSNorm

对输入的小批量应用均方根层规范化。

循环层

nn.RNNBase

RNN 模块(RNN、LSTM、GRU)的基类。

nn.RNN

对输入序列应用具有 tanh\tanhReLU\text{ReLU} 非线性的多层 Elman RNN。

nn.LSTM

对输入序列应用多层长短期记忆 (LSTM) RNN。

nn.GRU

对输入序列应用多层门控循环单元 (GRU) RNN。

nn.RNNCell

具有 tanh 或 ReLU 非线性的 Elman RNN 单元。

nn.LSTMCell

长短期记忆 (LSTM) 单元。

nn.GRUCell

门控循环单元 (GRU) 单元。

Transformer 层

nn.Transformer

一个 Transformer 模型。

nn.TransformerEncoder

TransformerEncoder 是 N 个编码器层的堆栈。

nn.TransformerDecoder

TransformerDecoder 是 N 个解码器层的堆栈。

nn.TransformerEncoderLayer

TransformerEncoderLayer 由自注意力和前馈网络组成。

nn.TransformerDecoderLayer

TransformerDecoderLayer 由自注意力、多头注意力和前馈网络组成。

线性层

nn.Identity

一个占位符标识运算符,它对参数不敏感。

nn.Linear

对传入的数据应用仿射线性变换:y=xAT+by = xA^T + b.

nn.Bilinear

对输入数据应用双线性变换:y=x1TAx2+by = x_1^T A x_2 + b.

nn.LazyLinear

一个torch.nn.Linear 模块,其中in_features 是推断出来的。

Dropout 层

nn.Dropout

在训练期间,以概率p随机将输入张量中的某些元素置零。

nn.Dropout1d

随机清零整个通道。

nn.Dropout2d

随机清零整个通道。

nn.Dropout3d

随机清零整个通道。

nn.AlphaDropout

对输入应用 Alpha Dropout。

nn.FeatureAlphaDropout

随机屏蔽整个通道。

稀疏层

nn.Embedding

一个简单的查找表,存储固定字典和大小的嵌入。

nn.EmbeddingBag

计算“包”嵌入的总和或平均值,而无需实例化中间嵌入。

距离函数

nn.CosineSimilarity

返回x1x_1x2x_2 之间的余弦相似度,沿着dim计算。

nn.PairwiseDistance

计算输入向量之间的成对距离,或输入矩阵的列之间的成对距离。

损失函数

nn.L1Loss

创建一个准则,测量输入 xx 和目标 yy 中每个元素之间的平均绝对误差 (MAE)。

nn.MSELoss

创建一个准则,测量输入 xx 和目标 yy 中每个元素之间的均方误差(平方 L2 范数)。

nn.CrossEntropyLoss

此准则计算输入 logits 和目标之间的交叉熵损失。

nn.CTCLoss

连接主义时间分类损失。

nn.NLLLoss

负对数似然损失。

nn.PoissonNLLLoss

目标泊松分布的负对数似然损失。

nn.GaussianNLLLoss

高斯负对数似然损失。

nn.KLDivLoss

Kullback-Leibler 散度损失。

nn.BCELoss

创建一个准则,测量目标和输入概率之间的二元交叉熵。

nn.BCEWithLogitsLoss

此损失将Sigmoid 层和BCELoss 结合在一个类中。

nn.MarginRankingLoss

创建一个准则,测量给定输入 x1x1x2x2,两个 1D 小批量或 0D Tensor,以及标签 1D 小批量或 0D Tensor yy(包含 1 或 -1)时的损失。

nn.HingeEmbeddingLoss

测量给定输入张量 xx 和标签张量 yy(包含 1 或 -1)时的损失。

nn.MultiLabelMarginLoss

创建一个准则,优化输入 xx(一个 2D 小批量 Tensor)和输出 yy(这是一个 2D Tensor 的目标类索引)之间的多类多分类铰链损失(基于边际的损失)。

nn.HuberLoss

创建一个准则,如果绝对逐元素误差低于 delta,则使用平方项;否则使用 delta 缩放的 L1 项。

nn.SmoothL1Loss

创建一个准则,如果绝对逐元素误差低于 beta,则使用平方项;否则使用 L1 项。

nn.SoftMarginLoss

创建一个准则,优化输入张量 xx 和目标张量 yy(包含 1 或 -1)之间的两类分类逻辑损失。

nn.MultiLabelSoftMarginLoss

创建一个准则,优化基于最大熵的多标签一对多损失,在输入 xx 和大小为 (N,C)(N, C) 的目标 yy 之间。

nn.CosineEmbeddingLoss

创建了一个标准,用于衡量输入张量 x1x_1, x2x_2Tensor 标签 yy 的损失,其值为 1 或 -1。

nn.MultiMarginLoss

创建了一个标准,用于优化输入 xx(一个 2D 小型批次 Tensor)和输出 yy(它是一个目标类别索引的 1D 张量,0yx.size(1)10 \leq y \leq \text{x.size}(1)-1)

nn.TripletMarginLoss

创建一个标准,用于衡量给定输入张量 x1x1, x2x2, x3x3 和大于 00 的值的 margin 的三元组损失。

nn.TripletMarginWithDistanceLoss

创建一个标准,用于衡量给定输入张量 aa, ppnn(分别代表锚点、正例和负例),以及一个非负的实值函数(“距离函数”),用于计算锚点和正例之间的关系(“正例距离”)和锚点和负例之间的关系(“负例距离”)的三元组损失。

视觉层

nn.PixelShuffle

根据上采样因子重新排列张量中的元素。

nn.PixelUnshuffle

反转 PixelShuffle 操作。

nn.Upsample

上采样给定的多通道 1D(时间)、2D(空间)或 3D(体积)数据。

nn.UpsamplingNearest2d

对由多个输入通道组成的输入信号应用 2D 最近邻上采样。

nn.UpsamplingBilinear2d

对由多个输入通道组成的输入信号应用 2D 双线性上采样。

混洗层

nn.ChannelShuffle

划分并重新排列张量中的通道。

数据并行层(多 GPU、分布式)

nn.DataParallel

在模块级别实现数据并行。

nn.parallel.DistributedDataParallel

基于 torch.distributed 在模块级别实现分布式数据并行。

实用程序

来自 torch.nn.utils 模块

用于剪切参数梯度的实用程序函数。

clip_grad_norm_

剪切参数可迭代对象的参数梯度范数。

clip_grad_norm

剪切参数可迭代对象的参数梯度范数。

clip_grad_value_

在指定值处剪切参数可迭代对象的参数梯度。

用于将模块参数展平并取消展平成单个向量的实用程序函数。

parameters_to_vector

将参数可迭代对象展平成单个向量。

vector_to_parameters

将向量的切片复制到参数可迭代对象中。

用于将模块与 BatchNorm 模块融合的实用程序函数。

fuse_conv_bn_eval

将卷积模块和 BatchNorm 模块融合成一个新的卷积模块。

fuse_conv_bn_weights

将卷积模块参数和 BatchNorm 模块参数融合成新的卷积模块参数。

fuse_linear_bn_eval

将线性模块和 BatchNorm 模块融合成一个新的线性模块。

fuse_linear_bn_weights

将线性模块参数和 BatchNorm 模块参数融合成新的线性模块参数。

用于转换模块参数内存格式的实用程序函数。

convert_conv2d_weight_memory_format

nn.Conv2d.weightmemory_format 转换为 memory_format

convert_conv3d_weight_memory_format

nn.Conv3d.weightmemory_format 转换为 memory_format 该转换递归地应用于嵌套的 nn.Module,包括 module

用于对模块参数应用和移除权重归一化的实用程序函数。

weight_norm

对给定模块中的参数应用权重归一化。

remove_weight_norm

从模块中移除权重归一化重新参数化。

spectral_norm

对给定模块中的参数应用谱归一化。

remove_spectral_norm

从模块中移除谱归一化重新参数化。

用于初始化模块参数的实用程序函数。

skip_init

给定模块类对象和参数 / 关键字参数,实例化模块而不初始化参数 / 缓冲区。

用于修剪模块参数的实用程序类和函数。

prune.BasePruningMethod

用于创建新的修剪技术的抽象基类。

prune.PruningContainer

包含一系列修剪方法的容器,用于迭代修剪。

prune.Identity

实用修剪方法,不修剪任何单元,但使用全为一的掩码生成修剪参数化。

prune.RandomUnstructured

随机修剪张量中(当前未修剪的)单元。

prune.L1Unstructured

通过将 L1 范数最低的单元置零来修剪张量中(当前未修剪的)单元。

prune.RandomStructured

随机修剪张量中所有(当前未修剪的)通道。

prune.LnStructured

根据其 Ln 范数修剪张量中所有(当前未修剪的)通道。

prune.CustomFromMask

prune.identity

应用修剪重新参数化而不修剪任何单元。

prune.random_unstructured

通过移除随机(当前未修剪的)单元来修剪张量。

prune.l1_unstructured

通过移除 L1 范数最低的单元来修剪张量。

prune.random_structured

通过移除指定维度上的随机通道来修剪张量。

prune.ln_structured

通过移除指定维度上 Ln 范数最低的通道来修剪张量。

prune.global_unstructured

通过应用指定的 pruning_method,全局修剪对应于 parameters 中所有参数的张量。

prune.custom_from_mask

通过应用 mask 中的预先计算的掩码,修剪对应于 module 中名为 name 的参数的张量。

prune.remove

从模块中移除修剪重新参数化,并从正向钩子中移除修剪方法。

prune.is_pruned

通过查找修剪预钩子来检查模块是否已修剪。

使用 torch.nn.utils.parameterize.register_parametrization() 中的新参数化功能实现的参数化。

parametrizations.orthogonal

对矩阵或矩阵批次应用正交或酉参数化。

parametrizations.weight_norm

对给定模块中的参数应用权重归一化。

parametrizations.spectral_norm

对给定模块中的参数应用谱归一化。

用于对现有模块上的张量进行参数化的实用函数。请注意,这些函数可以用于对给定的参数或缓冲区进行参数化,前提是提供一个将输入空间映射到参数化空间的特定函数。它们不是将对象转换为参数的参数化。有关如何实现自己的参数化的更多信息,请参阅参数化教程

parametrize.register_parametrization

将参数化注册到模块中的张量。

parametrize.remove_parametrizations

删除模块中张量上的参数化。

parametrize.cached

上下文管理器,用于启用使用 register_parametrization() 注册的参数化中的缓存系统。

parametrize.is_parametrized

确定模块是否具有参数化。

parametrize.ParametrizationList

一个顺序容器,用于保存和管理参数化 torch.nn.Module 的原始参数或缓冲区。

用于以无状态方式调用给定模块的实用函数。

stateless.functional_call

通过用提供的参数和缓冲区替换模块参数和缓冲区来对模块执行函数调用。

其他模块中的实用函数

nn.utils.rnn.PackedSequence

保存打包序列的数据和 batch_sizes 列表。

nn.utils.rnn.pack_padded_sequence

打包包含可变长度填充序列的张量。

nn.utils.rnn.pad_packed_sequence

填充可变长度序列的打包批次。

nn.utils.rnn.pad_sequence

使用 padding_value 填充可变长度张量的列表。

nn.utils.rnn.pack_sequence

打包可变长度张量的列表。

nn.utils.rnn.unpack_sequence

将 PackedSequence 解包到可变长度张量的列表中。

nn.utils.rnn.unpad_sequence

将填充的张量解包到可变长度张量的列表中。

nn.Flatten

将一系列连续的维数展平为一个张量。

nn.Unflatten

展开张量维数,将其扩展到所需的形状。

量化函数

量化是指使用比浮点精度更低的位宽执行计算和存储张量的技术。PyTorch 支持每张量和每通道非对称线性量化。有关如何在 PyTorch 中使用量化函数的更多信息,请参阅量化 文档。

延迟模块初始化

nn.modules.lazy.LazyModuleMixin

用于延迟初始化参数的模块的 mixin,也称为“延迟模块”。

别名

以下别名与其在 torch.nn 中的对应项相同

nn.modules.normalization.RMSNorm

对输入的小批量应用均方根层规范化。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源