快捷方式

模块

PyTorch 使用模块来表示神经网络。模块是

  • 有状态计算的构建块。 PyTorch 提供了一个强大的模块库,并且可以轻松定义新的自定义模块,从而可以轻松构建精细的多层神经网络。

  • 与 PyTorch 的 自动微分 系统紧密集成。 模块使得为 PyTorch 的优化器指定可学习参数变得简单。

  • 易于使用和转换。 模块易于保存和恢复、在 CPU/GPU/TPU 设备之间传输、剪枝、量化等等。

本说明描述了模块,面向所有 PyTorch 用户。由于模块是 PyTorch 的基础,因此本说明中的许多主题将在其他说明或教程中详细阐述,并且此处也提供了许多此类文档的链接。

一个简单的自定义模块

首先,让我们看一个更简单的自定义版本的 PyTorch Linear 模块。此模块对其输入应用仿射变换。

import torch
from torch import nn

class MyLinear(nn.Module):
  def __init__(self, in_features, out_features):
    super().__init__()
    self.weight = nn.Parameter(torch.randn(in_features, out_features))
    self.bias = nn.Parameter(torch.randn(out_features))

  def forward(self, input):
    return (input @ self.weight) + self.bias

这个简单的模块具有以下模块的基本特征

  • 它继承自基类 Module。 所有模块都应继承自 Module,以便与其他模块组合。

  • 它定义了一些用于计算的“状态”。 在这里,状态由随机初始化的 weightbias 张量组成,它们定义了仿射变换。因为它们都被定义为 Parameter,所以它们被注册到模块中,并且将被自动跟踪并从对 parameters() 的调用中返回。参数可以被认为是模块计算的“可学习”方面(稍后会详细介绍)。请注意,模块不需要具有状态,也可以是无状态的。

  • 它定义了一个执行计算的 forward() 函数。 对于此仿射变换模块,输入与 weight 参数进行矩阵乘法运算(使用 @ 简写符号),然后添加到 bias 参数以生成输出。更一般地说,模块的 forward() 实现可以执行涉及任意数量输入和输出的任意计算。

这个简单的模块演示了模块如何将状态和计算打包在一起。可以构造并调用此模块的实例

m = MyLinear(4, 3)
sample_input = torch.randn(4)
m(sample_input)
: tensor([-0.3037, -1.0413, -4.2057], grad_fn=<AddBackward0>)

请注意,模块本身是可调用的,调用它会调用其 forward() 函数。此名称指的是“前向传递”和“反向传递”的概念,它们适用于每个模块。“前向传递”负责将模块表示的计算应用于给定的输入(如上面的代码段所示)。“反向传递”计算模块输出相对于其输入的梯度,这些梯度可用于通过梯度下降方法“训练”参数。PyTorch 的自动微分系统会自动处理此反向传递计算,因此不需要为每个模块手动实现 backward() 函数。通过连续的前向/反向传递来训练模块参数的过程将在 使用模块进行神经网络训练 中详细介绍。

可以通过调用 parameters()named_parameters() 来迭代模块注册的全部参数集,其中后者包括每个参数的名称

for parameter in m.named_parameters():
  print(parameter)
: ('weight', Parameter containing:
tensor([[ 1.0597,  1.1796,  0.8247],
        [-0.5080, -1.2635, -1.1045],
        [ 0.0593,  0.2469, -1.4299],
        [-0.4926, -0.5457,  0.4793]], requires_grad=True))
('bias', Parameter containing:
tensor([ 0.3634,  0.2015, -0.8525], requires_grad=True))

通常,模块注册的参数是模块计算中应该“学习”的方面。本说明的后面部分将展示如何使用 PyTorch 的优化器之一来更新这些参数。但是,在我们介绍之前,让我们先看看如何将模块相互组合。

模块作为构建块

模块可以包含其他模块,这使得它们成为开发更复杂功能的有用构建块。最简单的方法是使用 Sequential 模块。它允许我们将多个模块链接在一起

net = nn.Sequential(
  MyLinear(4, 3),
  nn.ReLU(),
  MyLinear(3, 1)
)

sample_input = torch.randn(4)
net(sample_input)
: tensor([-0.6749], grad_fn=<AddBackward0>)

请注意,Sequential 会自动将第一个 MyLinear 模块的输出作为输入馈送到 ReLU 中,并将该输出作为输入馈送到第二个 MyLinear 模块中。如上所示,它仅限于具有单个输入和输出的模块的顺序链接。

一般来说,建议为最简单用例之外的任何内容定义自定义模块,因为这可以完全灵活地决定如何将子模块用于模块的计算。

例如,下面是一个实现为自定义模块的简单神经网络

import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self):
    super().__init__()
    self.l0 = MyLinear(4, 3)
    self.l1 = MyLinear(3, 1)
  def forward(self, x):
    x = self.l0(x)
    x = F.relu(x)
    x = self.l1(x)
    return x

此模块由两个“子级”或“子模块”(l0l1)组成,它们定义了神经网络的层并在模块的 forward() 方法中用于计算。可以通过调用 children()named_children() 迭代模块的直接子级。

net = Net()
for child in net.named_children():
  print(child)
: ('l0', MyLinear())
('l1', MyLinear())

要深入到直接子级之外,modules()named_modules() 会_递归地_迭代模块及其子模块。

class BigNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.l1 = MyLinear(5, 4)
    self.net = Net()
  def forward(self, x):
    return self.net(self.l1(x))

big_net = BigNet()
for module in big_net.named_modules():
  print(module)
: ('', BigNet(
  (l1): MyLinear()
  (net): Net(
    (l0): MyLinear()
    (l1): MyLinear()
  )
))
('l1', MyLinear())
('net', Net(
  (l0): MyLinear()
  (l1): MyLinear()
))
('net.l0', MyLinear())
('net.l1', MyLinear())

有时,模块需要动态定义子模块。此时,ModuleListModuleDict 模块就很有用;它们从列表或字典中注册子模块。

class DynamicNet(nn.Module):
  def __init__(self, num_layers):
    super().__init__()
    self.linears = nn.ModuleList(
      [MyLinear(4, 4) for _ in range(num_layers)])
    self.activations = nn.ModuleDict({
      'relu': nn.ReLU(),
      'lrelu': nn.LeakyReLU()
    })
    self.final = MyLinear(4, 1)
  def forward(self, x, act):
    for linear in self.linears:
      x = linear(x)
    x = self.activations[act](x)
    x = self.final(x)
    return x

dynamic_net = DynamicNet(3)
sample_input = torch.randn(4)
output = dynamic_net(sample_input, 'relu')

对于任何给定的模块,其参数包括其直接参数以及所有子模块的参数。这意味着调用 parameters()named_parameters() 将递归地包含子级参数,从而可以方便地优化网络中的所有参数。

for parameter in dynamic_net.named_parameters():
  print(parameter)
: ('linears.0.weight', Parameter containing:
tensor([[-1.2051,  0.7601,  1.1065,  0.1963],
        [ 3.0592,  0.4354,  1.6598,  0.9828],
        [-0.4446,  0.4628,  0.8774,  1.6848],
        [-0.1222,  1.5458,  1.1729,  1.4647]], requires_grad=True))
('linears.0.bias', Parameter containing:
tensor([ 1.5310,  1.0609, -2.0940,  1.1266], requires_grad=True))
('linears.1.weight', Parameter containing:
tensor([[ 2.1113, -0.0623, -1.0806,  0.3508],
        [-0.0550,  1.5317,  1.1064, -0.5562],
        [-0.4028, -0.6942,  1.5793, -1.0140],
        [-0.0329,  0.1160, -1.7183, -1.0434]], requires_grad=True))
('linears.1.bias', Parameter containing:
tensor([ 0.0361, -0.9768, -0.3889,  1.1613], requires_grad=True))
('linears.2.weight', Parameter containing:
tensor([[-2.6340, -0.3887, -0.9979,  0.0767],
        [-0.3526,  0.8756, -1.5847, -0.6016],
        [-0.3269, -0.1608,  0.2897, -2.0829],
        [ 2.6338,  0.9239,  0.6943, -1.5034]], requires_grad=True))
('linears.2.bias', Parameter containing:
tensor([ 1.0268,  0.4489, -0.9403,  0.1571], requires_grad=True))
('final.weight', Parameter containing:
tensor([[ 0.2509], [-0.5052], [ 0.3088], [-1.4951]], requires_grad=True))
('final.bias', Parameter containing:
tensor([0.3381], requires_grad=True))

使用 to() 可以轻松地将所有参数移动到不同的设备或更改其精度。

# Move all parameters to a CUDA device
dynamic_net.to(device='cuda')

# Change precision of all parameters
dynamic_net.to(dtype=torch.float64)

dynamic_net(torch.randn(5, device='cuda', dtype=torch.float64))
: tensor([6.5166], device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)

更一般地说,可以通过使用 apply() 函数将任意函数递归地应用于模块及其子模块。例如,要将自定义初始化应用于模块及其子模块的参数,可以这样做:

# Define a function to initialize Linear weights.
# Note that no_grad() is used here to avoid tracking this computation in the autograd graph.
@torch.no_grad()
def init_weights(m):
  if isinstance(m, nn.Linear):
    nn.init.xavier_normal_(m.weight)
    m.bias.fill_(0.0)

# Apply the function recursively on the module and its submodules.
dynamic_net.apply(init_weights)

这些示例展示了如何通过模块组合形成复杂的神经网络并方便地对其进行操作。为了能够以最少的样板代码快速轻松地构建神经网络,PyTorch 在 torch.nn 命名空间中提供了一个大型的高性能模块库,这些模块执行常见的的操作,如池化、卷积、损失函数等。

在下一节中,我们将给出一个训练神经网络的完整示例。

有关更多信息,请查看:

使用模块进行神经网络训练

一旦构建了网络,就必须对其进行训练,并且可以使用 PyTorch 的 torch.optim 中的一个优化器轻松地优化其参数。

# Create the network (from previous section) and optimizer
net = Net()
optimizer = torch.optim.SGD(net.parameters(), lr=1e-4, weight_decay=1e-2, momentum=0.9)

# Run a sample training loop that "teaches" the network
# to output the constant zero function
for _ in range(10000):
  input = torch.randn(4)
  output = net(input)
  loss = torch.abs(output)
  net.zero_grad()
  loss.backward()
  optimizer.step()

# After training, switch the module to eval mode to do inference, compute performance metrics, etc.
# (see discussion below for a description of training and evaluation modes)
...
net.eval()
...

在这个简化的示例中,网络学习简单地输出零,因为任何非零输出都会根据其绝对值(通过使用 torch.abs() 作为损失函数)而受到“惩罚”。虽然这不是一个非常有趣的任务,但训练的关键部分都已存在。

  • 创建网络。

  • 创建一个优化器(在本例中为随机梯度下降优化器),并将网络的参数与其关联。

  • 一个训练循环...
    • 获取输入,

    • 运行网络,

    • 计算损失,

    • 将网络参数的梯度归零,

    • 调用 loss.backward() 来更新参数的梯度,

    • 调用 optimizer.step() 将梯度应用于参数。

运行上述代码段后,请注意网络的参数已更改。特别是,检查 l1weight 参数的值会发现其值现在更接近于 0(正如预期的那样)。

print(net.l1.weight)
: Parameter containing:
tensor([[-0.0013],
        [ 0.0030],
        [-0.0008]], requires_grad=True)

请注意,上述过程完全是在网络模块处于“训练模式”时完成的。默认情况下,模块处于训练模式,可以使用 train()eval() 在训练和评估模式之间切换。它们的行为可能会有所不同,具体取决于它们所处的模式。例如,BatchNorm 模块在训练期间维护一个运行均值和方差,这些值在模块处于评估模式时不会更新。通常,模块在训练期间应处于训练模式,并且仅在进行推理或评估时才切换到评估模式。下面是一个自定义模块的示例,该模块在两种模式下的行为不同:

class ModalModule(nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, x):
    if self.training:
      # Add a constant only in training mode.
      return x + 1.
    else:
      return x


m = ModalModule()
x = torch.randn(4)

print('training mode output: {}'.format(m(x)))
: tensor([1.6614, 1.2669, 1.0617, 1.6213, 0.5481])

m.eval()
print('evaluation mode output: {}'.format(m(x)))
: tensor([ 0.6614,  0.2669,  0.0617,  0.6213, -0.4519])

训练神经网络通常很棘手。有关更多信息,请查看:

模块状态

在上一节中,我们演示了如何训练模块的“参数”或计算的可学习方面。现在,如果我们想将训练好的模型保存到磁盘,我们可以通过保存其 state_dict(即“状态字典”)来实现。

# Save the module
torch.save(net.state_dict(), 'net.pt')

...

# Load the module later on
new_net = Net()
new_net.load_state_dict(torch.load('net.pt'))
: <All keys matched successfully>

模块的 state_dict 包含影响其计算的状态。这包括但不限于模块的参数。对于某些模块,拥有超出参数的状态可能很有用,这些状态会影响模块计算但不可学习。对于此类情况,PyTorch 提供了“缓冲区”的概念,包括“持久性”和“非持久性”缓冲区。以下是模块可以具有的各种状态类型的概述:

  • **参数**:计算的可学习方面;包含在 state_dict 中。

  • **缓冲区**:计算的不可学习方面。

    • **持久性**缓冲区:包含在 state_dict 中(即在保存和加载时序列化)。

    • **非持久性**缓冲区:未包含在 state_dict 中(即在序列化中被忽略)。

作为使用缓冲区的示例,请考虑一个维护运行均值的简单模块。我们希望将运行均值的当前值视为模块 state_dict 的一部分,以便在加载模块的序列化形式时恢复它,但我们不希望它是可学习的。此代码段显示了如何使用 register_buffer() 来实现此目的:

class RunningMean(nn.Module):
  def __init__(self, num_features, momentum=0.9):
    super().__init__()
    self.momentum = momentum
    self.register_buffer('mean', torch.zeros(num_features))
  def forward(self, x):
    self.mean = self.momentum * self.mean + (1.0 - self.momentum) * x
    return self.mean

现在,运行均值的当前值被视为模块 state_dict 的一部分,并且在从磁盘加载模块时将正确恢复。

m = RunningMean(4)
for _ in range(10):
  input = torch.randn(4)
  m(input)

print(m.state_dict())
: OrderedDict([('mean', tensor([ 0.1041, -0.1113, -0.0647,  0.1515]))]))

# Serialized form will contain the 'mean' tensor
torch.save(m.state_dict(), 'mean.pt')

m_loaded = RunningMean(4)
m_loaded.load_state_dict(torch.load('mean.pt'))
assert(torch.all(m.mean == m_loaded.mean))

如前所述,可以通过将缓冲区标记为非持久性来将它们排除在模块的 state_dict 之外。

self.register_buffer('unserialized_thing', torch.randn(5), persistent=False)

持久性和非持久性缓冲区都会受到使用 to() 应用的模型范围设备/dtype 更改的影响。

# Moves all module parameters and buffers to the specified device / dtype
m.to(device='cuda', dtype=torch.float64)

可以使用 buffers()named_buffers() 迭代模块的缓冲区。

for buffer in m.named_buffers():
  print(buffer)

以下类演示了在模块中注册参数和缓冲区的各种方法:

class StatefulModule(nn.Module):
  def __init__(self):
    super().__init__()
    # Setting a nn.Parameter as an attribute of the module automatically registers the tensor
    # as a parameter of the module.
    self.param1 = nn.Parameter(torch.randn(2))

    # Alternative string-based way to register a parameter.
    self.register_parameter('param2', nn.Parameter(torch.randn(3)))

    # Reserves the "param3" attribute as a parameter, preventing it from being set to anything
    # except a parameter. "None" entries like this will not be present in the module's state_dict.
    self.register_parameter('param3', None)

    # Registers a list of parameters.
    self.param_list = nn.ParameterList([nn.Parameter(torch.randn(2)) for i in range(3)])

    # Registers a dictionary of parameters.
    self.param_dict = nn.ParameterDict({
      'foo': nn.Parameter(torch.randn(3)),
      'bar': nn.Parameter(torch.randn(4))
    })

    # Registers a persistent buffer (one that appears in the module's state_dict).
    self.register_buffer('buffer1', torch.randn(4), persistent=True)

    # Registers a non-persistent buffer (one that does not appear in the module's state_dict).
    self.register_buffer('buffer2', torch.randn(5), persistent=False)

    # Reserves the "buffer3" attribute as a buffer, preventing it from being set to anything
    # except a buffer. "None" entries like this will not be present in the module's state_dict.
    self.register_buffer('buffer3', None)

    # Adding a submodule registers its parameters as parameters of the module.
    self.linear = nn.Linear(2, 3)

m = StatefulModule()

# Save and load state_dict.
torch.save(m.state_dict(), 'state.pt')
m_loaded = StatefulModule()
m_loaded.load_state_dict(torch.load('state.pt'))

# Note that non-persistent buffer "buffer2" and reserved attributes "param3" and "buffer3" do
# not appear in the state_dict.
print(m_loaded.state_dict())
: OrderedDict([('param1', tensor([-0.0322,  0.9066])),
               ('param2', tensor([-0.4472,  0.1409,  0.4852])),
               ('buffer1', tensor([ 0.6949, -0.1944,  1.2911, -2.1044])),
               ('param_list.0', tensor([ 0.4202, -0.1953])),
               ('param_list.1', tensor([ 1.5299, -0.8747])),
               ('param_list.2', tensor([-1.6289,  1.4898])),
               ('param_dict.bar', tensor([-0.6434,  1.5187,  0.0346, -0.4077])),
               ('param_dict.foo', tensor([-0.0845, -1.4324,  0.7022])),
               ('linear.weight', tensor([[-0.3915, -0.6176],
                                         [ 0.6062, -0.5992],
                                         [ 0.4452, -0.2843]])),
               ('linear.bias', tensor([-0.3710, -0.0795, -0.3947]))])

有关更多信息,请查看:

模块初始化

默认情况下,torch.nn 提供的模块的参数和浮点缓冲区在模块实例化期间初始化为 CPU 上的 32 位浮点值,使用的初始化方案经历史证明对该模块类型表现良好。对于某些用例,可能需要使用不同的 dtype、设备(例如 GPU)或初始化技术进行初始化。

示例:

# Initialize module directly onto GPU.
m = nn.Linear(5, 3, device='cuda')

# Initialize module with 16-bit floating point parameters.
m = nn.Linear(5, 3, dtype=torch.half)

# Skip default parameter initialization and perform custom (e.g. orthogonal) initialization.
m = torch.nn.utils.skip_init(nn.Linear, 5, 3)
nn.init.orthogonal_(m.weight)

请注意,上面演示的设备和 dtype 选项也适用于为模块注册的任何浮点缓冲区。

m = nn.BatchNorm2d(3, dtype=torch.half)
print(m.running_mean)
: tensor([0., 0., 0.], dtype=torch.float16)

虽然模块编写者可以使用任何设备或 dtype 来初始化其自定义模块中的参数,但最佳实践是默认也使用 dtype=torch.floatdevice='cpu'。或者,您可以通过遵循上面演示的惯例(所有 torch.nn 模块都遵循的惯例)为自定义模块在这些方面提供完全的灵活性。

  • 提供一个 device 构造函数关键字参数,该参数适用于模块注册的任何参数/缓冲区。

  • 提供一个适用于模块注册的任何参数/浮点数缓冲区的 dtype 构造函数关键字参数。

  • 仅在模块构造函数中对参数和缓冲区使用初始化函数(即来自 torch.nn.init 的函数)。请注意,这仅是使用 skip_init() 所必需的;有关解释,请参阅此页面

有关更多信息,请查看:

模块钩子

使用模块进行神经网络训练中,我们演示了模块的训练过程,该过程迭代地执行前向和反向传递,并在每次迭代中更新模块参数。为了更好地控制此过程,PyTorch 提供了“钩子”,可以在前向或反向传递期间执行任意计算,甚至可以根据需要修改传递的方式。此功能的一些有用示例包括调试、可视化激活、深入检查梯度等。可以将钩子添加到您自己未编写的模块中,这意味着此功能可以应用于第三方或 PyTorch 提供的模块。

PyTorch 为模块提供了两种类型的钩子

所有钩子都允许用户返回一个更新后的值,该值将在剩余的计算中使用。因此,这些钩子可用于沿着常规模块前向/反向传递执行任意代码,或修改某些输入/输出,而无需更改模块的 forward() 函数。

以下是一个演示前向和反向钩子用法的示例

torch.manual_seed(1)

def forward_pre_hook(m, inputs):
  # Allows for examination and modification of the input before the forward pass.
  # Note that inputs are always wrapped in a tuple.
  input = inputs[0]
  return input + 1.

def forward_hook(m, inputs, output):
  # Allows for examination of inputs / outputs and modification of the outputs
  # after the forward pass. Note that inputs are always wrapped in a tuple while outputs
  # are passed as-is.

  # Residual computation a la ResNet.
  return output + inputs[0]

def backward_hook(m, grad_inputs, grad_outputs):
  # Allows for examination of grad_inputs / grad_outputs and modification of
  # grad_inputs used in the rest of the backwards pass. Note that grad_inputs and
  # grad_outputs are always wrapped in tuples.
  new_grad_inputs = [torch.ones_like(gi) * 42. for gi in grad_inputs]
  return new_grad_inputs

# Create sample module & input.
m = nn.Linear(3, 3)
x = torch.randn(2, 3, requires_grad=True)

# ==== Demonstrate forward hooks. ====
# Run input through module before and after adding hooks.
print('output with no forward hooks: {}'.format(m(x)))
: output with no forward hooks: tensor([[-0.5059, -0.8158,  0.2390],
                                        [-0.0043,  0.4724, -0.1714]], grad_fn=<AddmmBackward>)

# Note that the modified input results in a different output.
forward_pre_hook_handle = m.register_forward_pre_hook(forward_pre_hook)
print('output with forward pre hook: {}'.format(m(x)))
: output with forward pre hook: tensor([[-0.5752, -0.7421,  0.4942],
                                        [-0.0736,  0.5461,  0.0838]], grad_fn=<AddmmBackward>)

# Note the modified output.
forward_hook_handle = m.register_forward_hook(forward_hook)
print('output with both forward hooks: {}'.format(m(x)))
: output with both forward hooks: tensor([[-1.0980,  0.6396,  0.4666],
                                          [ 0.3634,  0.6538,  1.0256]], grad_fn=<AddBackward0>)

# Remove hooks; note that the output here matches the output before adding hooks.
forward_pre_hook_handle.remove()
forward_hook_handle.remove()
print('output after removing forward hooks: {}'.format(m(x)))
: output after removing forward hooks: tensor([[-0.5059, -0.8158,  0.2390],
                                               [-0.0043,  0.4724, -0.1714]], grad_fn=<AddmmBackward>)

# ==== Demonstrate backward hooks. ====
m(x).sum().backward()
print('x.grad with no backwards hook: {}'.format(x.grad))
: x.grad with no backwards hook: tensor([[ 0.4497, -0.5046,  0.3146],
                                         [ 0.4497, -0.5046,  0.3146]])

# Clear gradients before running backward pass again.
m.zero_grad()
x.grad.zero_()

m.register_full_backward_hook(backward_hook)
m(x).sum().backward()
print('x.grad with backwards hook: {}'.format(x.grad))
: x.grad with backwards hook: tensor([[42., 42., 42.],
                                      [42., 42., 42.]])

高级功能

PyTorch 还提供了一些更高级的功能,旨在与模块配合使用。所有这些功能都可用于自定义编写的模块,但需要注意的是,某些功能可能需要模块符合特定的约束才能得到支持。有关这些功能和相应要求的深入讨论,请参阅以下链接。

分布式训练

PyTorch 中存在多种分布式训练方法,既可以使用多个 GPU 扩展训练,也可以跨多台机器进行训练。有关如何利用这些方法的详细信息,请查看分布式训练概述页面

分析性能

PyTorch Profiler 可用于识别模型中的性能瓶颈。它可以测量并输出内存使用量和时间消耗方面的性能特征。

通过量化提高性能

对模块应用量化技术可以通过使用比浮点精度更低的位宽来提高性能和内存使用率。请此处查看 PyTorch 提供的各种量化机制。

通过剪枝提高内存使用率

大型深度学习模型通常参数过多,导致内存使用率很高。为了解决这个问题,PyTorch 提供了模型剪枝机制,可以在保持任务准确性的同时帮助减少内存使用。 剪枝教程介绍了如何利用 PyTorch 提供的剪枝技术或根据需要定义自定义剪枝技术。

参数化

对于某些应用,在模型训练期间约束参数空间可能是有益的。例如,强制学习到的参数的正交性可以提高 RNN 的收敛性。PyTorch 提供了一种应用参数化(例如 this)的机制,并进一步允许定义自定义约束。

使用 FX 转换模块

PyTorch 的FX 组件提供了一种通过直接操作模块计算图来灵活地转换模块的方法。这可用于以编程方式生成或操作模块,以满足广泛的用例。要探索 FX,请查看以下使用 FX 进行卷积 + 批量归一化融合CPU 性能分析的示例。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得您的问题的答案

查看资源