快捷方式

模块

PyTorch 使用模块来表示神经网络。模块是

  • 有状态计算的构建块。 PyTorch 提供了一个强大的模块库,并使定义新的自定义模块变得简单,从而可以轻松构建复杂的、多层神经网络。

  • 与 PyTorch 的 autograd 系统紧密集成。 模块使为 PyTorch 的优化器指定可学习参数变得简单。

  • 易于使用和转换。 模块易于保存和恢复,在 CPU/GPU/TPU 设备之间传输,修剪,量化等等。

本说明介绍了模块,适用于所有 PyTorch 用户。由于模块对于 PyTorch 来说是如此基础,因此本说明中的许多主题将在其他说明或教程中进行详细说明,并且此处提供了指向许多这些文档的链接。

一个简单的自定义模块

首先,让我们看一下 PyTorch 的 Linear 模块的更简单、自定义版本。此模块对输入应用仿射变换。

import torch
from torch import nn

class MyLinear(nn.Module):
  def __init__(self, in_features, out_features):
    super().__init__()
    self.weight = nn.Parameter(torch.randn(in_features, out_features))
    self.bias = nn.Parameter(torch.randn(out_features))

  def forward(self, input):
    return (input @ self.weight) + self.bias

此简单模块具有模块以下基本特征

  • 它继承自基模块类。 所有模块都应该子类化 Module 以便与其他模块组合。

  • 它定义了一些用于计算的“状态”。 在这里,状态由随机初始化的 weightbias 张量组成,这些张量定义了仿射变换。由于每个张量都定义为 Parameter,因此它们会注册到模块中,并且会自动跟踪并从对 parameters() 的调用中返回。参数可以被认为是模块计算的“可学习”方面(稍后会详细介绍)。请注意,模块不需要具有状态,也可以是无状态的。

  • 它定义了一个执行计算的 forward() 函数。 对于此仿射变换模块,输入与 weight 参数进行矩阵相乘(使用 @ 简写符号)并加到 bias 参数以生成输出。更一般地说,模块的 forward() 实现可以执行涉及任意数量输入和输出的任意计算。

此简单模块演示了模块如何将状态和计算打包在一起。可以构造此模块的实例并调用它们

m = MyLinear(4, 3)
sample_input = torch.randn(4)
m(sample_input)
: tensor([-0.3037, -1.0413, -4.2057], grad_fn=<AddBackward0>)

请注意,模块本身是可调用的,调用它会调用其 forward() 函数。这个名称指的是“前向传递”和“反向传递”的概念,这些概念适用于每个模块。“前向传递”负责将模块表示的计算应用于给定的输入(如上面的代码片段所示)。“反向传递”计算模块输出相对于其输入的梯度,这可用于通过梯度下降方法“训练”参数。PyTorch 的 autograd 系统会自动处理此反向传递计算,因此无需为每个模块手动实现 backward() 函数。通过连续的前向/反向传递训练模块参数的过程将在 使用模块进行神经网络训练 中详细介绍。

可以通过调用 parameters()named_parameters() 来迭代模块注册的完整参数集,后者包含每个参数的名称

for parameter in m.named_parameters():
  print(parameter)
: ('weight', Parameter containing:
tensor([[ 1.0597,  1.1796,  0.8247],
        [-0.5080, -1.2635, -1.1045],
        [ 0.0593,  0.2469, -1.4299],
        [-0.4926, -0.5457,  0.4793]], requires_grad=True))
('bias', Parameter containing:
tensor([ 0.3634,  0.2015, -0.8525], requires_grad=True))

一般来说,模块注册的参数是模块计算中应该“学习”的方面。本说明的后面部分将展示如何使用 PyTorch 的优化器之一来更新这些参数。但是,在我们了解这些内容之前,让我们首先检查一下如何将模块相互组合。

模块作为构建块

模块可以包含其他模块,使它们成为开发更复杂功能的有用构建块。最简单的方法是使用 Sequential 模块。它允许我们将多个模块串联在一起

net = nn.Sequential(
  MyLinear(4, 3),
  nn.ReLU(),
  MyLinear(3, 1)
)

sample_input = torch.randn(4)
net(sample_input)
: tensor([-0.6749], grad_fn=<AddBackward0>)

注意,Sequential 自动将第一个 MyLinear 模块的输出作为输入馈送到 ReLU,并将该模块的输出作为输入馈送到第二个 MyLinear 模块。如所示,它仅限于对具有单个输入和输出的模块进行按顺序的串联。

一般而言,建议为超出最简单用例的任何内容定义自定义模块,因为这提供了对子模块如何用于模块计算的完全灵活性。

例如,以下是一个作为自定义模块实现的简单神经网络

import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self):
    super().__init__()
    self.l0 = MyLinear(4, 3)
    self.l1 = MyLinear(3, 1)
  def forward(self, x):
    x = self.l0(x)
    x = F.relu(x)
    x = self.l1(x)
    return x

该模块由两个“子级”或“子模块”(l0l1)组成,它们定义了神经网络的层,并在模块的 forward() 方法中用于计算。可以通过调用 children()named_children() 来迭代模块的直接子级。

net = Net()
for child in net.named_children():
  print(child)
: ('l0', MyLinear())
('l1', MyLinear())

为了比直接子级更深入地了解,modules()named_modules() 递归地遍历模块及其子模块。

class BigNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.l1 = MyLinear(5, 4)
    self.net = Net()
  def forward(self, x):
    return self.net(self.l1(x))

big_net = BigNet()
for module in big_net.named_modules():
  print(module)
: ('', BigNet(
  (l1): MyLinear()
  (net): Net(
    (l0): MyLinear()
    (l1): MyLinear()
  )
))
('l1', MyLinear())
('net', Net(
  (l0): MyLinear()
  (l1): MyLinear()
))
('net.l0', MyLinear())
('net.l1', MyLinear())

有时,模块需要动态地定义子模块。 ModuleListModuleDict 模块在此处很有用;它们从列表或字典中注册子模块。

class DynamicNet(nn.Module):
  def __init__(self, num_layers):
    super().__init__()
    self.linears = nn.ModuleList(
      [MyLinear(4, 4) for _ in range(num_layers)])
    self.activations = nn.ModuleDict({
      'relu': nn.ReLU(),
      'lrelu': nn.LeakyReLU()
    })
    self.final = MyLinear(4, 1)
  def forward(self, x, act):
    for linear in self.linears:
      x = linear(x)
    x = self.activations[act](x)
    x = self.final(x)
    return x

dynamic_net = DynamicNet(3)
sample_input = torch.randn(4)
output = dynamic_net(sample_input, 'relu')

对于任何给定的模块,其参数包括其直接参数以及所有子模块的参数。这意味着对 parameters()named_parameters() 的调用将递归地包含子参数,从而可以方便地优化网络中的所有参数。

for parameter in dynamic_net.named_parameters():
  print(parameter)
: ('linears.0.weight', Parameter containing:
tensor([[-1.2051,  0.7601,  1.1065,  0.1963],
        [ 3.0592,  0.4354,  1.6598,  0.9828],
        [-0.4446,  0.4628,  0.8774,  1.6848],
        [-0.1222,  1.5458,  1.1729,  1.4647]], requires_grad=True))
('linears.0.bias', Parameter containing:
tensor([ 1.5310,  1.0609, -2.0940,  1.1266], requires_grad=True))
('linears.1.weight', Parameter containing:
tensor([[ 2.1113, -0.0623, -1.0806,  0.3508],
        [-0.0550,  1.5317,  1.1064, -0.5562],
        [-0.4028, -0.6942,  1.5793, -1.0140],
        [-0.0329,  0.1160, -1.7183, -1.0434]], requires_grad=True))
('linears.1.bias', Parameter containing:
tensor([ 0.0361, -0.9768, -0.3889,  1.1613], requires_grad=True))
('linears.2.weight', Parameter containing:
tensor([[-2.6340, -0.3887, -0.9979,  0.0767],
        [-0.3526,  0.8756, -1.5847, -0.6016],
        [-0.3269, -0.1608,  0.2897, -2.0829],
        [ 2.6338,  0.9239,  0.6943, -1.5034]], requires_grad=True))
('linears.2.bias', Parameter containing:
tensor([ 1.0268,  0.4489, -0.9403,  0.1571], requires_grad=True))
('final.weight', Parameter containing:
tensor([[ 0.2509], [-0.5052], [ 0.3088], [-1.4951]], requires_grad=True))
('final.bias', Parameter containing:
tensor([0.3381], requires_grad=True))

使用 to() 也很容易将所有参数移动到不同的设备或更改其精度。

# Move all parameters to a CUDA device
dynamic_net.to(device='cuda')

# Change precision of all parameters
dynamic_net.to(dtype=torch.float64)

dynamic_net(torch.randn(5, device='cuda', dtype=torch.float64))
: tensor([6.5166], device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)

更一般地说,可以使用 apply() 函数递归地将任意函数应用于模块及其子模块。例如,要将自定义初始化应用于模块及其子模块的参数

# Define a function to initialize Linear weights.
# Note that no_grad() is used here to avoid tracking this computation in the autograd graph.
@torch.no_grad()
def init_weights(m):
  if isinstance(m, nn.Linear):
    nn.init.xavier_normal_(m.weight)
    m.bias.fill_(0.0)

# Apply the function recursively on the module and its submodules.
dynamic_net.apply(init_weights)

这些示例展示了如何通过模块组合形成复杂的网络,并可以方便地进行操作。为了能够快速轻松地构建神经网络,同时最小化样板代码,PyTorch 在 torch.nn 命名空间中提供了大量性能优异的模块库,这些模块执行常见的网络操作,例如池化、卷积、损失函数等。

在下一节中,我们将提供一个完整的训练神经网络的示例。

有关更多信息,请查看

使用模块训练神经网络

构建网络后,需要对其进行训练,并且可以使用 PyTorch 中 torch.optim 的优化器之一轻松优化其参数。

# Create the network (from previous section) and optimizer
net = Net()
optimizer = torch.optim.SGD(net.parameters(), lr=1e-4, weight_decay=1e-2, momentum=0.9)

# Run a sample training loop that "teaches" the network
# to output the constant zero function
for _ in range(10000):
  input = torch.randn(4)
  output = net(input)
  loss = torch.abs(output)
  net.zero_grad()
  loss.backward()
  optimizer.step()

# After training, switch the module to eval mode to do inference, compute performance metrics, etc.
# (see discussion below for a description of training and evaluation modes)
...
net.eval()
...

在此简化示例中,网络学习简单地输出零,因为任何非零输出都将根据其绝对值被“惩罚”,方法是采用 torch.abs() 作为损失函数。虽然这不是一项非常有趣的任务,但训练的关键部分都存在。

  • 创建一个网络。

  • 创建一个优化器(在本例中,一个随机梯度下降优化器),并将网络的参数与其关联。

  • 一个训练循环…
    • 获取输入,

    • 运行网络,

    • 计算损失,

    • 将网络参数的梯度清零,

    • 调用 loss.backward() 更新参数的梯度,

    • 调用 optimizer.step() 将梯度应用于参数。

在运行完上面的代码段后,请注意网络的参数已经发生了变化。特别是,检查 l1weight 参数的值表明其值现在更接近于 0(如预期的那样)。

print(net.l1.weight)
: Parameter containing:
tensor([[-0.0013],
        [ 0.0030],
        [-0.0008]], requires_grad=True)

请注意,上述过程完全在网络模块处于“训练模式”时完成。模块默认情况下处于训练模式,可以使用 train()eval() 在训练模式和评估模式之间切换。它们的行为可能因所处的模式而异。例如,BatchNorm 模块在训练期间维护一个正在运行的均值和方差,这些值在模块处于评估模式时不会更新。通常,模块在训练期间应处于训练模式,并且仅在推断或评估时切换到评估模式。以下是一个自定义模块的示例,它在两种模式之间的行为有所不同

class ModalModule(nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, x):
    if self.training:
      # Add a constant only in training mode.
      return x + 1.
    else:
      return x


m = ModalModule()
x = torch.randn(4)

print('training mode output: {}'.format(m(x)))
: tensor([1.6614, 1.2669, 1.0617, 1.6213, 0.5481])

m.eval()
print('evaluation mode output: {}'.format(m(x)))
: tensor([ 0.6614,  0.2669,  0.0617,  0.6213, -0.4519])

训练神经网络通常很棘手。有关更多信息,请查看

模块状态

在上一节中,我们演示了训练模块的“参数”或可学习的计算方面。现在,如果要将训练后的模型保存到磁盘,可以通过保存其 state_dict(即“状态字典”)来实现。

# Save the module
torch.save(net.state_dict(), 'net.pt')

...

# Load the module later on
new_net = Net()
new_net.load_state_dict(torch.load('net.pt'))
: <All keys matched successfully>

模块的 state_dict 包含影响其计算的状态。这包括但不限于模块的参数。对于某些模块,拥有超出参数之外影响模块计算但不可学习的状态可能很有用。针对这种情况,PyTorch 提供了“缓冲区”的概念,包括“持久”缓冲区和“非持久”缓冲区。以下是模块可以具有的各种状态类型的概述

  • 参数:可学习的计算方面;包含在 state_dict

  • 缓冲区:不可学习的计算方面

    • 持久缓冲区:包含在 state_dict 中(即在保存和加载时进行序列化)

    • 非持久缓冲区:不包含在 state_dict 中(即在序列化时排除)

作为使用缓冲区的示例,考虑一个维护运行均值的简单模块。我们希望将运行均值的当前值视为模块的 state_dict 的一部分,以便在加载模块的序列化形式时对其进行还原,但我们不希望它可学习。此代码段展示了如何使用 register_buffer() 来实现此目的

class RunningMean(nn.Module):
  def __init__(self, num_features, momentum=0.9):
    super().__init__()
    self.momentum = momentum
    self.register_buffer('mean', torch.zeros(num_features))
  def forward(self, x):
    self.mean = self.momentum * self.mean + (1.0 - self.momentum) * x
    return self.mean

现在,运行均值的当前值被视为模块 state_dict 的一部分,并将从磁盘加载模块时正确还原。

m = RunningMean(4)
for _ in range(10):
  input = torch.randn(4)
  m(input)

print(m.state_dict())
: OrderedDict([('mean', tensor([ 0.1041, -0.1113, -0.0647,  0.1515]))]))

# Serialized form will contain the 'mean' tensor
torch.save(m.state_dict(), 'mean.pt')

m_loaded = RunningMean(4)
m_loaded.load_state_dict(torch.load('mean.pt'))
assert(torch.all(m.mean == m_loaded.mean))

如前所述,缓冲区可以通过将其标记为非持久来从模块的 state_dict 中排除。

self.register_buffer('unserialized_thing', torch.randn(5), persistent=False)

持久缓冲区和非持久缓冲区都受到使用 to() 应用于模型范围的设备/数据类型更改的影响。

# Moves all module parameters and buffers to the specified device / dtype
m.to(device='cuda', dtype=torch.float64)

可以使用 buffers()named_buffers() 迭代模块的缓冲区。

for buffer in m.named_buffers():
  print(buffer)

以下类演示了在模块中注册参数和缓冲区的各种方法

class StatefulModule(nn.Module):
  def __init__(self):
    super().__init__()
    # Setting a nn.Parameter as an attribute of the module automatically registers the tensor
    # as a parameter of the module.
    self.param1 = nn.Parameter(torch.randn(2))

    # Alternative string-based way to register a parameter.
    self.register_parameter('param2', nn.Parameter(torch.randn(3)))

    # Reserves the "param3" attribute as a parameter, preventing it from being set to anything
    # except a parameter. "None" entries like this will not be present in the module's state_dict.
    self.register_parameter('param3', None)

    # Registers a list of parameters.
    self.param_list = nn.ParameterList([nn.Parameter(torch.randn(2)) for i in range(3)])

    # Registers a dictionary of parameters.
    self.param_dict = nn.ParameterDict({
      'foo': nn.Parameter(torch.randn(3)),
      'bar': nn.Parameter(torch.randn(4))
    })

    # Registers a persistent buffer (one that appears in the module's state_dict).
    self.register_buffer('buffer1', torch.randn(4), persistent=True)

    # Registers a non-persistent buffer (one that does not appear in the module's state_dict).
    self.register_buffer('buffer2', torch.randn(5), persistent=False)

    # Reserves the "buffer3" attribute as a buffer, preventing it from being set to anything
    # except a buffer. "None" entries like this will not be present in the module's state_dict.
    self.register_buffer('buffer3', None)

    # Adding a submodule registers its parameters as parameters of the module.
    self.linear = nn.Linear(2, 3)

m = StatefulModule()

# Save and load state_dict.
torch.save(m.state_dict(), 'state.pt')
m_loaded = StatefulModule()
m_loaded.load_state_dict(torch.load('state.pt'))

# Note that non-persistent buffer "buffer2" and reserved attributes "param3" and "buffer3" do
# not appear in the state_dict.
print(m_loaded.state_dict())
: OrderedDict([('param1', tensor([-0.0322,  0.9066])),
               ('param2', tensor([-0.4472,  0.1409,  0.4852])),
               ('buffer1', tensor([ 0.6949, -0.1944,  1.2911, -2.1044])),
               ('param_list.0', tensor([ 0.4202, -0.1953])),
               ('param_list.1', tensor([ 1.5299, -0.8747])),
               ('param_list.2', tensor([-1.6289,  1.4898])),
               ('param_dict.bar', tensor([-0.6434,  1.5187,  0.0346, -0.4077])),
               ('param_dict.foo', tensor([-0.0845, -1.4324,  0.7022])),
               ('linear.weight', tensor([[-0.3915, -0.6176],
                                         [ 0.6062, -0.5992],
                                         [ 0.4452, -0.2843]])),
               ('linear.bias', tensor([-0.3710, -0.0795, -0.3947]))])

有关更多信息,请查看

模块初始化

默认情况下,由 torch.nn 提供的模块的参数和浮点缓冲区在模块实例化期间初始化为 CPU 上的 32 位浮点值,使用历史上针对该模块类型表现良好的初始化方案。对于某些用例,可能需要使用不同的数据类型、设备(例如 GPU)或初始化技术进行初始化。

示例

# Initialize module directly onto GPU.
m = nn.Linear(5, 3, device='cuda')

# Initialize module with 16-bit floating point parameters.
m = nn.Linear(5, 3, dtype=torch.half)

# Skip default parameter initialization and perform custom (e.g. orthogonal) initialization.
m = torch.nn.utils.skip_init(nn.Linear, 5, 3)
nn.init.orthogonal_(m.weight)

请注意,上面演示的设备和数据类型选项也适用于为模块注册的任何浮点缓冲区。

m = nn.BatchNorm2d(3, dtype=torch.half)
print(m.running_mean)
: tensor([0., 0., 0.], dtype=torch.float16)

虽然模块编写者可以在其自定义模块中使用任何设备或数据类型来初始化参数,但最佳实践是默认情况下使用 dtype=torch.floatdevice='cpu'。可选地,您可以通过遵循上面演示的约定为您的自定义模块提供这些方面的完全灵活性,所有 torch.nn 模块都遵循该约定。

  • 提供一个 device 构造函数关键字参数,该参数适用于模块注册的任何参数/缓冲区。

  • 提供一个 dtype 构造函数关键字参数,该参数适用于模块注册的任何参数/浮点缓冲区。

  • 仅在模块的构造函数中对参数和缓冲区使用初始化函数(即来自 torch.nn.init 的函数)。请注意,这仅是使用 skip_init() 所必需的;有关说明,请参见 此页面

有关更多信息,请查看

模块钩子

使用模块进行神经网络训练 中,我们演示了模块的训练过程,该过程迭代地执行前向和反向传递,并在每次迭代中更新模块参数。为了更有效地控制此过程,PyTorch 提供了“钩子”,这些钩子可以在前向或反向传递期间执行任意计算,甚至根据需要修改传递方式。此功能的一些有用示例包括调试、可视化激活、深入检查梯度等。钩子可以添加到您没有自己编写的模块中,这意味着此功能可以应用于第三方模块或 PyTorch 提供的模块。

PyTorch 为模块提供了两种类型的钩子

所有钩子都允许用户返回一个更新的值,该值将在剩余的计算中使用。因此,这些钩子可用于在常规模块正向/反向传递期间执行任意代码,或修改某些输入/输出,而无需更改模块的 forward() 函数。

以下是一个演示正向和反向钩子用法的示例

torch.manual_seed(1)

def forward_pre_hook(m, inputs):
  # Allows for examination and modification of the input before the forward pass.
  # Note that inputs are always wrapped in a tuple.
  input = inputs[0]
  return input + 1.

def forward_hook(m, inputs, output):
  # Allows for examination of inputs / outputs and modification of the outputs
  # after the forward pass. Note that inputs are always wrapped in a tuple while outputs
  # are passed as-is.

  # Residual computation a la ResNet.
  return output + inputs[0]

def backward_hook(m, grad_inputs, grad_outputs):
  # Allows for examination of grad_inputs / grad_outputs and modification of
  # grad_inputs used in the rest of the backwards pass. Note that grad_inputs and
  # grad_outputs are always wrapped in tuples.
  new_grad_inputs = [torch.ones_like(gi) * 42. for gi in grad_inputs]
  return new_grad_inputs

# Create sample module & input.
m = nn.Linear(3, 3)
x = torch.randn(2, 3, requires_grad=True)

# ==== Demonstrate forward hooks. ====
# Run input through module before and after adding hooks.
print('output with no forward hooks: {}'.format(m(x)))
: output with no forward hooks: tensor([[-0.5059, -0.8158,  0.2390],
                                        [-0.0043,  0.4724, -0.1714]], grad_fn=<AddmmBackward>)

# Note that the modified input results in a different output.
forward_pre_hook_handle = m.register_forward_pre_hook(forward_pre_hook)
print('output with forward pre hook: {}'.format(m(x)))
: output with forward pre hook: tensor([[-0.5752, -0.7421,  0.4942],
                                        [-0.0736,  0.5461,  0.0838]], grad_fn=<AddmmBackward>)

# Note the modified output.
forward_hook_handle = m.register_forward_hook(forward_hook)
print('output with both forward hooks: {}'.format(m(x)))
: output with both forward hooks: tensor([[-1.0980,  0.6396,  0.4666],
                                          [ 0.3634,  0.6538,  1.0256]], grad_fn=<AddBackward0>)

# Remove hooks; note that the output here matches the output before adding hooks.
forward_pre_hook_handle.remove()
forward_hook_handle.remove()
print('output after removing forward hooks: {}'.format(m(x)))
: output after removing forward hooks: tensor([[-0.5059, -0.8158,  0.2390],
                                               [-0.0043,  0.4724, -0.1714]], grad_fn=<AddmmBackward>)

# ==== Demonstrate backward hooks. ====
m(x).sum().backward()
print('x.grad with no backwards hook: {}'.format(x.grad))
: x.grad with no backwards hook: tensor([[ 0.4497, -0.5046,  0.3146],
                                         [ 0.4497, -0.5046,  0.3146]])

# Clear gradients before running backward pass again.
m.zero_grad()
x.grad.zero_()

m.register_full_backward_hook(backward_hook)
m(x).sum().backward()
print('x.grad with backwards hook: {}'.format(x.grad))
: x.grad with backwards hook: tensor([[42., 42., 42.],
                                      [42., 42., 42.]])

高级功能

PyTorch 还提供了一些更高级的功能,旨在与模块一起使用。所有这些功能都可用于自定义编写的模块,一个小小的注意事项是某些功能可能需要模块符合特定约束才能得到支持。有关这些功能和相应要求的深入讨论,请参见下面的链接。

分布式训练

PyTorch 中存在多种分布式训练方法,既可用于使用多个 GPU 扩展训练,也可用于跨多台机器进行训练。查看 分布式训练概述页面,了解有关如何利用这些方法的详细信息。

性能分析

PyTorch Profiler 可用于识别模型中的性能瓶颈。它测量并输出内存使用量和花费时间这两个方面的性能特征。

使用量化提高性能

将量化技术应用于模块可以通过使用低于浮点精度的位宽来提高性能和内存使用量。查看此处提供的各种 PyTorch 量化机制 此处

使用剪枝改进内存使用量

大型深度学习模型通常参数过多,导致内存使用量过高。为了解决这个问题,PyTorch 提供了模型剪枝机制,这有助于在保持任务精度的同时减少内存使用量。 剪枝教程 描述了如何利用 PyTorch 提供的剪枝技术或根据需要定义自定义剪枝技术。

参数化

对于某些应用程序,在模型训练期间约束参数空间可能是有益的。例如,强制学习到的参数正交性可以提高 RNN 的收敛速度。PyTorch 提供了一种机制来应用 参数化(如这种参数化),并且还允许定义自定义约束。

使用 FX 变换模块

PyTorch 的 FX 组件提供了一种灵活的方式来通过直接操作模块计算图来变换模块。这可用于以编程方式生成或操作模块,以满足各种用例。要探索 FX,请查看这些使用 FX 进行 卷积 + 批归一化融合CPU 性能分析 的示例。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源