扩展 PyTorch¶
在本说明中,我们将介绍扩展 torch.nn
、torch.autograd
、torch
以及编写自定义 C++ 扩展的方法。
添加新算子¶
PyTorch 提供了大量可在张量 (Tensor) 上工作的算子(例如 torch.add()
、torch.sum()
等)。但是,您可能希望将新的自定义操作引入 PyTorch,并使其行为与 PyTorch 的内置算子一样。为此,您必须通过 Python torch.library 或 C++ TORCH_LIBRARY API 将自定义操作注册到 PyTorch。
有关更多详细信息,请参阅 PyTorch 自定义算子登陆页。
扩展 torch.autograd
¶
将操作添加到 autograd
需要为每个操作实现一个新的 Function
子类。回想一下,Function 是 autograd
用于编码操作历史和计算梯度的对象。
本文档的第一部分侧重于反向模式自动微分 (AD),因为它是最广泛使用的特性。末尾的部分讨论了前向模式自动微分 (AD) 的扩展。
何时使用¶
通常,如果您想在模型中执行不可微分或依赖非 PyTorch 库(例如 NumPy)的计算,但仍希望您的操作能够与其他算子链式连接并与 autograd 引擎一起工作,那么请实现自定义函数。
在某些情况下,自定义函数也可用于提高性能和内存使用:如果您使用 C++ 扩展实现了前向和反向传播,您可以将它们封装在 Function
中以与 autograd 引擎交互。如果您想减少为反向传播保存的缓冲区数量,可以使用自定义函数将多个算子组合在一起。
何时不使用¶
如果您已经可以使用 PyTorch 的内置算子来编写您的函数,那么它的反向图(很可能)已经可以被 autograd 记录。在这种情况下,您无需自己实现 backward 函数。考虑使用普通的 Python 函数即可。
如果您需要维护状态,即可训练参数,您应该(也)使用自定义模块。有关扩展 torch.nn
的更多信息,请参阅下面的部分。
如果您想在反向传播期间修改梯度或执行副作用,请考虑注册一个张量或Module hook。
如何使用¶
请按照以下步骤操作:1. 子类化 Function
并实现 forward()
、(可选)setup_context()
和 backward()
方法。2. 调用 ctx 参数上的适当方法。3. 声明您的函数是否支持二次反向传播 (double backward)。4. 使用 gradcheck 验证您的梯度是否正确。
步骤 1:子类化 Function
后,您需要定义 3 个方法
forward()
是执行操作的代码。它可以接受任意数量的参数,如果您指定默认值,其中一些参数是可选的。这里接受所有类型的 Python 对象。Tensor
参数如果跟踪历史(即,requires_grad=True
),则在调用前将被转换为不跟踪历史的 Tensor,并且它们的用法将被注册到图中。请注意,此逻辑不会遍历列表/字典/任何其他数据结构,只会考虑直接作为调用参数的 Tensor。您可以返回单个Tensor
输出,或者在有多个输出时返回一个tuple
的 tensors。此外,请参阅Function
的文档以查找只能从forward()
中调用的有用方法的描述。setup_context()
(可选)。您可以编写一个接受ctx
对象的“组合式”forward()
,或者(自 PyTorch 2.0 起)编写一个不接受ctx
的单独forward()
和一个setup_context()
方法,在其中进行ctx
的修改。forward()
应该包含计算逻辑,而setup_context()
只应负责ctx
的修改(不包含任何计算)。通常,分开的forward()
和setup_context()
更接近 PyTorch 本地操作的工作方式,因此与各种 PyTorch 子系统更具组合性。有关更多详细信息,请参阅 组合式或分开式 forward() 和 setup_context()。backward()
(或vjp()
)定义了梯度公式。它将获得与输出数量相同的Tensor
参数,每个参数代表相对于相应输出的梯度。切记不要就地修改这些参数。它应该返回与输入数量相同的 tensors,每个 Tensor 包含相对于其相应输入的梯度。如果您的输入不需要梯度(needs_input_grad
是一个布尔值元组,指示每个输入是否需要计算梯度),或者是非Tensor
对象,您可以返回python:None
。此外,如果forward()
有可选参数,您可以返回比输入更多的梯度,只要它们全部为None
。
步骤 2:您有责任正确使用 ctx 中的函数,以确保新的 Function
与 autograd 引擎正常工作。
必须使用
save_for_backward()
保存要在反向传播中使用的任何张量。非张量应直接存储在 ctx 上。如果保存了既不是输入也不是输出的张量用于反向传播,您的Function
可能不支持二次反向传播(参见步骤 3)。必须使用
mark_dirty()
标记 forward 函数就地修改的任何输入。必须使用
mark_non_differentiable()
告知引擎输出是否不可微分。默认情况下,所有可微分类型的输出张量都将设置为需要梯度。不可微分类型(即整型)的张量永远不会被标记为需要梯度。set_materialize_grads()
可用于告知 autograd 引擎在输出不依赖于输入的情况下优化梯度计算,方法是不具体化 (materializing) 传递给 backward 函数的梯度张量。也就是说,如果设置为 False,Python 中的 None 对象或 C++ 中的“未定义张量”(即 x.defined() 为 False 的张量 x)将不会在调用 backward 之前被转换为填充零的张量,因此您的代码需要像处理填充零的张量一样处理这些对象。此设置的默认值为 True。
步骤 3:如果您的 Function
不支持二次反向传播,您应该通过使用 once_differentiable()
装饰器装饰 backward 来明确声明。使用此装饰器后,尝试通过您的函数进行二次反向传播将产生错误。有关二次反向传播的更多信息,请参阅我们的二次反向传播教程。
步骤 4:建议您使用 torch.autograd.gradcheck()
来检查您的 backward 函数是否通过使用 backward 函数计算雅可比矩阵,并将其值与使用有限差分法数值计算的雅可比矩阵进行逐元素比较,从而正确计算了 forward 的梯度。
示例¶
您可以在下方找到 Linear 函数的代码,并附有额外注释。
# Inherit from Function
class LinearFunction(Function):
# Note that forward, setup_context, and backward are @staticmethods
@staticmethod
def forward(input, weight, bias):
output = input.mm(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
@staticmethod
# inputs is a Tuple of all of the inputs passed to forward.
# output is the output of the forward().
def setup_context(ctx, inputs, output):
input, weight, bias = inputs
ctx.save_for_backward(input, weight, bias)
# This function has only a single output, so it gets only one gradient
@staticmethod
def backward(ctx, grad_output):
# This is a pattern that is very convenient - at the top of backward
# unpack saved_tensors and initialize all gradients w.r.t. inputs to
# None. Thanks to the fact that additional trailing Nones are
# ignored, the return statement is simple even when the function has
# optional inputs.
input, weight, bias = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
# These needs_input_grad checks are optional and there only to
# improve efficiency. If you want to make your code simpler, you can
# skip them. Returning gradients for inputs that don't require it is
# not an error.
if ctx.needs_input_grad[0]:
grad_input = grad_output.mm(weight)
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().mm(input)
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
return grad_input, grad_weight, grad_bias
现在,为了更容易使用这些自定义算子,我们建议将它们别名化或封装在一个函数中。封装在函数中可以让我们支持默认参数和关键字参数。
# Option 1: alias
linear = LinearFunction.apply
# Option 2: wrap in a function, to support default args and keyword args.
def linear(input, weight, bias=None):
return LinearFunction.apply(input, weight, bias)
这里,我们提供一个由非 Tensor 参数参数化的函数的额外示例。
class MulConstant(Function):
@staticmethod
def forward(tensor, constant):
return tensor * constant
@staticmethod
def setup_context(ctx, inputs, output):
# ctx is a context object that can be used to stash information
# for backward computation
tensor, constant = inputs
ctx.constant = constant
@staticmethod
def backward(ctx, grad_output):
# We return as many input gradients as there were arguments.
# Gradients of non-Tensor arguments to forward must be None.
return grad_output * ctx.constant, None
在这里,我们通过调用 set_materialize_grads(False) 来优化上述示例。
class MulConstant(Function):
@staticmethod
def forward(tensor, constant):
return tensor * constant
@staticmethod
def setup_context(ctx, inputs, output):
tensor, constant = inputs
ctx.set_materialize_grads(False)
ctx.constant = constant
@staticmethod
def backward(ctx, grad_output):
# Here we must handle None grad_output tensor. In this case we
# can skip unnecessary computations and just return None.
if grad_output is None:
return None, None
# We return as many input gradients as there were arguments.
# Gradients of non-Tensor arguments to forward must be None.
return grad_output * ctx.constant, None
如果在 forward() 中计算的任何“中间”张量需要被保存,则它们必须作为输出返回,或者结合使用 forward 和 setup_context()(参见 组合式或分开式 forward() 和 setup_context())。请注意,这意味着如果您希望梯度流经这些中间值,您需要为它们定义梯度公式(另请参见 二次反向传播教程)。
class MyCube(torch.autograd.Function):
@staticmethod
def forward(x):
# We wish to save dx for backward. In order to do so, it must
# be returned as an output.
dx = 3 * x ** 2
result = x ** 3
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
# In order for the autograd.Function to work with higher-order
# gradients, we must add the gradient contribution of `dx`,
# which is grad_dx * 6 * x.
result = grad_output * dx + grad_dx * 6 * x
return result
# Wrap MyCube in a function so that it is clearer what the output is
def my_cube(x):
result, dx = MyCube.apply(x)
return result
注意
backward 的输入,即 grad_output
,也可以是跟踪历史的张量。因此,如果 backward 是用可微分操作实现的(例如,调用另一个自定义 Function
),则高阶导数将起作用。在这种情况下,使用 save_for_backward
保存的张量也可以在 backward 中使用并有梯度流回,但保存在 ctx
中的张量将不会有梯度流回。如果您需要保存在 ctx
中的 Tensor 有梯度流回,您应该将其作为自定义 Function
的输出并使用 save_for_backward
保存。
您可能想检查您实现的 backward 方法是否实际计算了您的函数的导数。这可以通过与使用小有限差分法的数值逼近进行比较来实现。
from torch.autograd import gradcheck
# gradcheck takes a tuple of tensors as input, check if your gradient
# evaluated with these tensors are close enough to numerical
# approximations and returns True if they all verify this condition.
input = (torch.randn(20,20,dtype=torch.double,requires_grad=True), torch.randn(30,20,dtype=torch.double,requires_grad=True))
test = gradcheck(linear, input, eps=1e-6, atol=1e-4)
print(test)
有关有限差分梯度比较的更多详细信息,请参阅 数值梯度检查。如果您的函数用于高阶导数(对 backward 传播求导),您可以使用同一包中的 gradgradcheck
函数来检查高阶导数。
组合式或分开式 forward()
和 setup_context()
¶
定义 Function
有两种主要方式。即
我们推荐第二种选项(分开的 forward()
和 setup_context()
),因为它更接近 PyTorch 本地操作的实现方式,并且与 torch.func
转换更具组合性。然而,我们计划将来继续支持这两种方法;将 forward()
与 setup_context()
结合使用:可以提供更大的灵活性,因为您可以在不将中间值作为输出返回的情况下保存它们。
有关如何使用分开的 forward()
和 setup_context()
定义 Function
的信息,请参阅上一节。
以下是一个示例,说明如何使用组合式的 forward()
和 setup_context()
定义 Function
。
class LinearFunction(Function):
@staticmethod
# ctx is the first argument to forward
def forward(ctx, input, weight, bias=None):
# The forward pass can use ctx.
ctx.save_for_backward(input, weight, bias)
output = input.mm(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
if ctx.needs_input_grad[0]:
grad_input = grad_output.mm(weight)
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().mm(input)
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
return grad_input, grad_weight, grad_bias
前向模式自动微分 (AD)¶
重写前向模式自动微分 (AD) 公式具有非常相似的 API,但有一些细微差别。您可以实现 jvp()
函数。
它将获得与输入数量相同的 Tensor
参数,每个参数代表相对于相应输入的梯度。它应该返回与输出数量相同的 tensors,每个 Tensor 包含相对于其相应输出的梯度。jvp()
将在 forward()
方法之后、apply()
返回之前被调用。
jvp()
与 backward()
函数有一些细微差别
您可以使用 ctx 将
forward()
中的任何数据传递给jvp()
函数。如果backward()
不需要该状态,您可以在jvp()
函数末尾通过执行del ctx.foo
明确释放它。jvp()
的实现必须是 backward 可微分的,或者明确检查给定的前向模式梯度都没有设置requires_grad
。jvp()
函数必须与forward()
的 view/就地行为相匹配。例如,如果第i
个输入被就地修改,则第i
个梯度也必须就地更新。类似地,如果第j
个输出是第k
个输入的视图。那么返回的第j
个输出梯度必须是给定第k
个输入梯度的视图。由于用户无法指定需要计算哪个梯度,
jvp()
函数应该始终计算所有输出的梯度。前向模式梯度确实遵守由
set_materialize_grads()
设置的标志,并且当禁用此功能时,您可以获得 None 输入梯度。
torch.func
转换和/或 torch.vmap()
¶
有关详细信息,请参阅 使用 autograd.Function 扩展 torch.func。
扩展 torch.nn
¶
nn
导出了两种接口 - 模块 (modules) 及其函数式版本 (functional versions)。你可以通过这两种方式进行扩展,但我们推荐使用模块来实现需要持有参数或缓冲 (buffers) 的各种层,而对于激活函数、池化等无参数操作,则推荐使用函数形式。
在上面一节中已经全面介绍了如何添加一个操作的函数式版本。
添加 Module
¶
由于 nn
大量利用了 autograd
,添加一个新的 Module
需要实现一个 Function
,该函数执行操作并能计算梯度。现在我们假设要实现一个 Linear
模块,并且已经按照上面的代码清单实现了该函数。添加这个模块只需要非常少的代码。现在,需要实现两个函数:
__init__
(可选) - 接收诸如卷积核大小、特征数量等参数,并初始化参数和缓冲。
以下是 Linear
模块的实现示例:
class Linear(nn.Module):
def __init__(self, input_features, output_features, bias=True):
super().__init__()
self.input_features = input_features
self.output_features = output_features
# nn.Parameter is a special kind of Tensor, that will get
# automatically registered as Module's parameter once it's assigned
# as an attribute. Parameters and buffers need to be registered, or
# they won't appear in .parameters() (doesn't apply to buffers), and
# won't be converted when e.g. .cuda() is called. You can use
# .register_buffer() to register buffers.
# nn.Parameters require gradients by default.
self.weight = nn.Parameter(torch.empty(output_features, input_features))
if bias:
self.bias = nn.Parameter(torch.empty(output_features))
else:
# You should always register all possible parameters, but the
# optional ones can be None if you want.
self.register_parameter('bias', None)
# Not a very smart way to initialize weights
nn.init.uniform_(self.weight, -0.1, 0.1)
if self.bias is not None:
nn.init.uniform_(self.bias, -0.1, 0.1)
def forward(self, input):
# See the autograd section for explanation of what happens here.
return LinearFunction.apply(input, self.weight, self.bias)
def extra_repr(self):
# (Optional)Set the extra information about this module. You can test
# it by printing an object of this class.
return 'input_features={}, output_features={}, bias={}'.format(
self.input_features, self.output_features, self.bias is not None
)
扩展 torch
Python API¶
你可以通过定义一个自定义类,使其方法与 Tensor
相匹配,从而创建模仿 Tensor
的自定义类型。但是,如果你想将这些类型传递给顶层 torch
命名空间中接受 Tensor
操作数的函数,例如 torch.add()
,该怎么办呢?
如果你的自定义 Python 类型定义了一个名为 __torch_function__
的方法,当将你的自定义类实例传递给 torch
命名空间中的函数时,PyTorch 将调用你的 __torch_function__
实现。这使得你可以为 torch
命名空间中的任何函数定义自定义实现,你的 __torch_function__
实现可以调用这些函数,从而允许你的用户在使用 Tensor
时,在他们已有的 PyTorch 工作流中使用你的自定义类型。这既适用于与 Tensor
无关的“鸭子类型” (duck types),也适用于用户定义的 Tensor
子类。
使用类似 Tensor
的类型扩展 torch
¶
为了具体说明这一点,我们从一个简单的示例开始,它演示了 API 分派机制。我们将创建一个自定义类型,表示一个二维标量张量,该张量由阶数 N
和对角线元素的值 value
参数化。
class ScalarTensor(object):
def __init__(self, N, value):
self._N = N
self._value = value
def __repr__(self):
return "ScalarTensor(N={}, value={})".format(self._N, self._value)
def tensor(self):
return self._value * torch.eye(self._N)
设计的第一版并不是非常有用。ScalarTensor
的主要功能是提供比基本张量类更紧凑的标量张量字符串表示形式。
>>> d = ScalarTensor(5, 2)
>>> d
ScalarTensor(N=5, value=2)
>>> d.tensor()
tensor([[2., 0., 0., 0., 0.],
[0., 2., 0., 0., 0.],
[0., 0., 2., 0., 0.],
[0., 0., 0., 2., 0.],
[0., 0., 0., 0., 2.]])
如果我们尝试将此对象与 torch
API 一起使用,将会遇到问题:
>>> import torch
>>> torch.mean(d)
TypeError: mean(): argument 'input' (position 1) must be Tensor, not ScalarTensor
向 ScalarTensor
添加 __torch_function__
实现使得上述操作得以成功。让我们重新编写实现,这次添加 __torch_function__
实现:
HANDLED_FUNCTIONS = {}
class ScalarTensor(object):
def __init__(self, N, value):
self._N = N
self._value = value
def __repr__(self):
return "ScalarTensor(N={}, value={})".format(self._N, self._value)
def tensor(self):
return self._value * torch.eye(self._N)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func not in HANDLED_FUNCTIONS or not all(
issubclass(t, (torch.Tensor, ScalarTensor))
for t in types
):
return NotImplemented
return HANDLED_FUNCTIONS[func](*args, **kwargs)
__torch_function__
方法接受四个参数:func
,一个指向正在被覆盖的 torch API 函数的引用;types
,实现 __torch_function__
的 Tensor-like 类型的列表;args
,传递给函数的参数元组;kwargs
,传递给函数的关键字参数字典。它使用一个名为 HANDLED_FUNCTIONS
的全局分派表来存储自定义实现。这个字典的键是 torch
命名空间中的函数,值是针对 ScalarTensor
的实现。
注意
使用全局分派表不是 __torch_function__
API 的强制要求,它只是一个有用的设计模式,用于组织你的覆盖实现。
这个类定义不足以让 torch.mean
在我们传递 ScalarTensor
时执行正确的操作——我们还需要为 torch.mean
定义一个针对 ScalarTensor
操作数的实现,并将该实现添加到 HANDLED_FUNCTIONS
分派表字典中。一种实现方法是定义一个装饰器:
import functools
def implements(torch_function):
"""Register a torch function override for ScalarTensor"""
def decorator(func):
functools.update_wrapper(func, torch_function)
HANDLED_FUNCTIONS[torch_function] = func
return func
return decorator
然后可以将其应用于我们覆盖实现的函数:
@implements(torch.mean)
def mean(input):
return float(input._value) / input._N
通过此更改,我们现在可以使用 torch.mean
处理 ScalarTensor
:
>>> d = ScalarTensor(5, 2)
>>> torch.mean(d)
0.4
当然,torch.mean
是最简单的覆盖函数类型之一,因为它只接受一个操作数。我们可以使用相同的机制来覆盖接受多个操作数的函数,其中任何一个操作数都可能是定义了 __torch_function__
的张量或类似张量的类型,例如 torch.add()
:
def ensure_tensor(data):
if isinstance(data, ScalarTensor):
return data.tensor()
return torch.as_tensor(data)
@implements(torch.add)
def add(input, other):
try:
if input._N == other._N:
return ScalarTensor(input._N, input._value + other._value)
else:
raise ValueError("Shape mismatch!")
except AttributeError:
return torch.add(ensure_tensor(input), ensure_tensor(other))
此版本为两个操作数都是 ScalarTensor
实例的情况提供了一个快速路径,还提供了一个较慢的路径,在任一操作数不是 ScalarTensor
时会退化为将数据转换为张量。这使得覆盖函数在任一操作数是 ScalarTensor
或常规 Tensor
时都能正确工作。
>>> s = ScalarTensor(2, 2)
>>> torch.add(s, s)
ScalarTensor(N=2, value=4)
>>> t = torch.tensor([[1, 1,], [1, 1]])
>>> torch.add(s, t)
tensor([[3., 1.],
[1., 3.]])
请注意,我们的 add
实现不像 torch.add()
那样接受 alpha
或 out
作为关键字参数:
>>> torch.add(s, s, alpha=2)
TypeError: add() got an unexpected keyword argument 'alpha'
为了速度和灵活性,__torch_function__
分派机制不会检查覆盖函数的签名是否与 torch
API 中被覆盖函数的签名匹配。对于某些应用,忽略可选参数是可以接受的,但为了确保与 Tensor
的完全兼容性,用户实现的 torch API 函数应该注意完全模仿被覆盖函数的 API。
在 torch
API 中没有显式覆盖的函数将从 __torch_function__
返回 NotImplemented
。如果所有定义了 __torch_function__
的操作数都返回 NotImplemented
,PyTorch 将引发 TypeError
。这意味着在大多数情况下,当传递此类类型的实例时,没有显式覆盖的操作将引发 TypeError
。
>>> torch.mul(s, 3)
TypeError: no implementation found for 'torch.mul' on types that
implement __torch_function__: [ScalarTensor]
实际上,这意味着如果你想按照这些思路使用 __torch_function__
实现来编写覆盖,你需要显式地实现完整的 torch
API,或者至少是你用例关心的 API 子集。这可能是一个艰巨的任务,因为完整的 torch
API 相当广泛。
另一个选择是对于未处理的操作不返回 NotImplemented
,而是在没有可用覆盖时将 Tensor
传递给原始的 torch
函数。例如,如果我们将 ScalarTensor
的 __torch_function__
实现更改为以下内容:
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func not in HANDLED_FUNCTIONS or not all(
issubclass(t, (torch.Tensor, ScalarTensor))
for t in types
):
args = [a.tensor() if hasattr(a, 'tensor') else a for a in args]
return func(*args, **kwargs)
return HANDLED_FUNCTIONS[func](*args, **kwargs)
那么 torch.mul()
将正常工作,尽管返回类型始终是 Tensor
而不是 ScalarTensor
,即使两个操作数都是 ScalarTensor
实例:
>>> s = ScalarTensor(2, 2)
>>> torch.mul(s, s)
tensor([[4., 0.],
[0., 4.]])
另请参见下面的 MetadataTensor
示例,它展示了这种模式的另一种变体,但始终返回 MetadataTensor
,以便在 torch
API 的操作中传播元数据。
__torch_function__
协议旨在覆盖完整的 API,部分覆盖可能导致不良结果,特别是某些函数会引发 TypeError
。对于子类尤其如此,torch.add、torch.Tensor.__add__ 和 torch.Tensor.add 都必须被覆盖,即使它们返回完全相同的结果。未能做到这一点也可能导致无限递归。如果需要从 torch.Tensor
子类实现函数,则必须在其实现内部使用 super().__torch_function__
。
子类化 torch.Tensor
¶
自版本 1.7.0 起,应用于 torch.Tensor
子类的 torch.Tensor
方法和公共 torch.*
命名空间中的函数将返回子类实例而不是 torch.Tensor
实例:
>>> class SubTensor(torch.Tensor):
... pass
>>> type(torch.add(SubTensor([0]), SubTensor([1]))).__name__
'SubTensor'
>>> type(torch.add(SubTensor([0]), torch.tensor([1]))).__name__
'SubTensor'
如果存在多个子类,默认会选择层级最低的那个。如果无法唯一确定这种情况,则会引发 TypeError
。
>>> type(torch.add(SubTensor2([0]), SubTensor([1]))).__name__
'SubTensor2'
>>> type(torch.add(SubTensor2([0]), torch.tensor([1]))).__name__
'SubTensor2'
>>> torch.add(SubTensor([0]), OtherSubTensor([1]))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: no implementation found for 'torch.add' on types that implement __torch_function__: [SubTensor, OtherSubTensor]
如果希望对所有张量方法进行全局覆盖,可以使用 __torch_function__
。以下是记录所有函数/方法调用的示例:
class LoggingTensor(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
# NOTE: Logging calls Tensor.__repr__, so we can't log __repr__ without infinite recursion
if func is not torch.Tensor.__repr__:
logging.info(f"func: {func.__name__}, args: {args!r}, kwargs: {kwargs!r}")
if kwargs is None:
kwargs = {}
return super().__torch_function__(func, types, args, kwargs)
但是,如果希望覆盖 Tensor 子类上的方法,可以通过直接覆盖方法(为子类定义该方法)或使用 __torch_function__
并与 func
匹配来实现。
对于子类中的 __torch_function__
,应注意始终调用 super().__torch_function__(func, ...)
,而不是直接调用 func
,这与 1.7.0 版本之前的做法不同。未能做到这一点可能导致 func
递归调用回 __torch_function__
,从而导致无限递归。
使用 Tensor 包装器类型扩展 torch
¶
另一个有用的场景是包装 Tensor
的类型,无论作为属性还是通过子类化。下面我们实现这种类型的一种特殊情况:一个 MetadataTensor
,它将一个元数据字典附加到 Tensor
上,并在 torch
操作中传播。由于这是对完整 torch
API 的通用包装,我们无需单独实现每个覆盖,因此可以使 __torch_function__
实现对允许的操作更具包容性:
class MetadataTensor(object):
def __init__(self, data, metadata=None, **kwargs):
self._t = torch.as_tensor(data, **kwargs)
self._metadata = metadata
def __repr__(self):
return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
metadatas = tuple(a._metadata for a in args if hasattr(a, '_metadata'))
args = [getattr(a, '_t', a) for a in args]
assert len(metadatas) > 0
ret = func(*args, **kwargs)
return MetadataTensor(ret, metadata=metadatas[0])
这个简单的实现不一定适用于 torch
API 中的每个函数,但足以涵盖大多数常用操作。
>>> metadata = {'owner': 'Ministry of Silly Walks'}
>>> m = MetadataTensor([[1, 2], [3, 4]], metadata=metadata)
>>> t = torch.tensor([[1, 2], [1, 2]])
>>> torch.add(t, m)
Metadata:
{'owner': 'Ministry of Silly Walks'}
data:
tensor([[2, 4],
[4, 6]])
>>> torch.mul(t, m)
Metadata:
{'owner': 'Ministry of Silly Walks'}
data:
tensor([[1, 4],
[3, 8]])
操作多个定义了 __torch_function__
的类型¶
可以使用 torch API 处理多个各自拥有 __torch_function__
实现的不同类型,但需要特别注意。在这种情况下,规则如下:
分派操作会收集每个操作数上所有不同的
__torch_function__
实现,并按顺序调用它们:子类优先于超类,否则按操作表达式中的从左到右顺序。如果返回的值不是
NotImplemented
,则该值作为结果返回。实现可以通过返回NotImplemented
来表明它们不实现某个操作。如果所有
__torch_function__
实现都返回NotImplemented
,PyTorch 将引发TypeError
。
测试 PyTorch API 覆盖范围¶
实现 __torch_function__
的一个棘手方面是,如果某些操作有覆盖而另一些没有,用户充其量会遇到不一致的体验,最坏的情况下会在使用没有覆盖的函数时遇到运行时错误。为了简化这个过程,PyTorch 提供了一个面向开发者的 API,用于确保对 __torch_function__
覆盖的全面支持。此 API 是私有的,未来可能会在没有警告的情况下进行更改。
首先,要获取所有可覆盖函数的列表,请使用 torch.overrides._get_overridable_functions
。这会返回一个字典,其键是 PyTorch
Python API 中的命名空间,值是该命名空间中可以被覆盖的函数列表。例如,让我们打印 torch.nn.functional
中前 5 个可覆盖函数的名称:
>>> from torch.overrides import get_overridable_functions
>>> func_dict = get_overridable_functions()
>>> nn_funcs = func_dict[torch.nn.functional]
>>> print([f.__name__ for f in nn_funcs[:5])
['adaptive_avg_pool1d', 'adaptive_avg_pool2d', 'adaptive_avg_pool3d',
'adaptive_max_pool1d', 'adaptive_max_pool1d_with_indices']
这个函数列表使得可以迭代所有可覆盖的函数,然而实际上,如果不能费力地手动复制每个函数的签名进行测试,这不足以编写针对所有这些函数的测试。为了简化这个过程,torch.overrides._get_testing_overrides
函数返回一个字典,将 PyTorch API 中可覆盖的函数映射到具有相同签名的模拟 lambda 函数,这些函数无条件返回 -1。这些函数最适合与 inspect
一起使用,分析原始 PyTorch 函数的函数签名:
>>> import inspect
>>> from torch.overrides import get_testing_overrides
>>> override_dict = get_testing_overrides()
>>> dummy_add = override_dict[torch.add]
>>> inspect.signature(dummy_add)
<Signature (input, other, out=None)>
最后,torch.overrides.get_ignored_functions
返回一个函数元组,这些函数明确不能被 __torch_function__
覆盖。这个列表对于确认在 get_overridable_functions
返回的字典中不存在的函数无法被覆盖是很有用的。
扩展 torch
原生 API¶
虽然 __torch_function__
允许有效地扩展 PyTorch 纯 Python 组件的行为,但它不允许扩展 PyTorch 中用 C++ 实现的部分。为此,Tensor
子类也可以定义 __torch_dispatch__
,它将能够覆盖 C++ 级别的行为。
要有效使用此功能,了解 PyTorch 的原生部分是如何实现的非常重要。其中最重要的组件是我们称之为“分发器” (dispatcher) 的东西(最好的描述可以在这篇博客文章中找到,尽管它略有 outdated)。正如其名称所示,它负责为函数的特定调用调用正确的后端函数。例如,当调用 torch.add(a, b)
时,分发器会检查两个参数,确定应为此特定调用使用哪个“特性”(autograd、autocast、functionalization 等)和哪个“后端”(CPU、CUDA、MPS 等),最后调用所有正确的内核。内核经常做的一件事是“重新分派” (redispatch)。例如,当在 GPU 上使用 autocast 运行神经网络时,第一次调用将是 autocast 内核,它将处理任何潜在的 autocast 逻辑,然后向下重新分派。队列中的下一个特性将是 autograd,它将正确创建 autograd 图,然后向下重新分派。最后,我们到达 CUDA 的后端内核,它将启动正确的 CUDA 内核并返回最终结果。在返回途中,autograd 将图附加到输出,最后,autocast 将有机会在退出时进行任何所需的更新。
分发器的一种配置是所有这些特性和后端键的调用顺序。最新的列表及其顺序可以在 DispatchKey.h
中的 DispatchKey
枚举中找到。就扩展 torch 而言,本次讨论中重要的顺序子集是:
vmap -> Autocast -> Autograd -> ZeroTensor -> Neg/Conj -> Functionalize -> Python -> Backends
就本次讨论而言,最重要的键是 Python
,因为所有定义了 __torch_dispatch__
方法的 Tensor 子类都将调用此特性。用户定义的方法就是从这里调用的,并且可以在这里任意覆盖行为。从这里,再次调用提供的 func
将执行“重新分派”。
此实现的一些重要含义是:
这段代码运行在“所有特性之下”。因此,它仅负责(像常规后端一样)生成每个 Tensor 的输出值(并且可以,也应该,忽略所有高级特性,如 autograd、autocast 等)。
如果任何高级特性在不重新分派的情况下实现了给定函数,它将永远不会到达
Python
键,因此__torch_dispatch__
回调将永远不会被触发。对于 CompositeImplicitAutograd 函数尤其如此,它们在 Autograd 级别进行评估而不进行重新分派。这是因为 CompositeImplicitAutograd 函数通过隐式调用其他原生操作来指定其 autograd 公式,因此在 Autograd 级别,该函数会被分解为其原生操作并对其进行评估。在回调到 Python 并包装结果时,使用的转换与常规 PyTorch Python/C++ 绑定相同。特别地,有些对象无法在 Python 中表示,需要特殊处理(例如,未定义的 Tensors 会变成 None)。
我们的原生函数被延迟填充为
torch.ops.{namespace}.{func_name}.{overload_name}
的可调用 Python 对象,以便从 Python 轻松与其交互。传递给__torch_dispatch__
的func
对象始终是此命名空间中的一个条目。此命名空间可用于直接调用原生操作,绕过常规 Python API 和绑定代码。
与 __torch_function__
能够拦截 torch 的所有 Python API 和 Tensor 方法类似,__torch_dispatch__
能够拦截所有对 aten 原生 API 的调用。请注意,Tensor 上的所有方法在进入分发器之前都会转换为函数调用,因此会在此处显示为函数调用:torch.add(a, 2)
和 a + 2
将导致完全相同的 aten 调用。这些函数大多数定义在 native_functions.yaml
中,其中指定了这些函数的属性及其后端实现。然后,通过代码生成,它们的实现以及指定的特性会被自动注册。一些更奇特的函数或特性也在 C++ 代码库的其他地方或用户定义的 C++ 扩展中注册。
也可以使用 torch.library
添加 新的 原生函数。这个 Python 特性允许定义和/或向原生函数添加新的实现。这可用于添加缺失的内核、替换现有内核或定义全新的原生函数。
您可以在 subclass zoo 仓库中找到许多基于 __torch_dispatch__
的子类示例。
__torch_dispatch__
调用约定¶
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
pass
当用户调用带有定义了 __torch_dispatch__
的输入的运算符时,该调用可能会被转发到 __torch_dispatch__
。在调用 __torch_dispatch__
之前,args 和 kwargs 会被标准化,也就是说:
kwargs
由运算符 schema 中的仅关键字参数组成。如果某个关键字参数等于其默认值 (在 schema 中),则不会传递它。args
由所有其他参数组成,无论它们如何传递给运算符(位置参数 vs 关键字参数)。如果某个参数等于其默认值,并且它是最右边的位置参数,或者其右边的所有参数都没有传递,则不会传递它。
使用模式 (Modes) 扩展所有 torch
API¶
不幸的是,有些函数不接受 Tensor 输入。这意味着上面描述的子类方法无法用于覆盖 PyTorch 所有函数的行为。此外,如果用例需要拦截每个函数调用,将每个 Tensor 更改为子类可能会过于侵入性。
为了解决这个用例,我们引入了“模式”(Mode) 的概念。它们存在于 __torch_function__
和 __torch_dispatch__
的覆盖中,分别通过子类化 torch.overrides.TorchFunctionMode
和 torch.utils._python_dispatch.TorchDispatchMode
创建,并作为上下文管理器使用。
为了简化其与子类和其他模式交互的描述,每当模式的上下文管理器被进入时,每个函数都表现得好像参数列表开头有一个额外的 Tensor 参数,该 Tensor 的子类就是该模式。这意味着所有模式处理程序将先于任何子类处理程序被调用,并且与内部上下文管理器对应的模式将始终首先运行。
同样重要的是要注意,在给定的模式处理程序内,此特定模式被禁用,可以通过执行 with self:
手动重新启用。
这是一个显示不同类型模式日志记录的示例:
import torch
from torch.overrides import TorchFunctionMode, resolve_name
from torch.utils._python_dispatch import TorchDispatchMode
class FunctionLog(TorchFunctionMode):
def __torch_function__(self, func, types, args, kwargs=None):
print(f"Function Log: {resolve_name(func)}(*{args}, **{kwargs})")
return func(*args, **(kwargs or {}))
class DispatchLog(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args, kwargs=None):
print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
return func(*args, **(kwargs or {}))
def f():
a = torch.rand(10, requires_grad=True)
b = a * 2
b.sum().backward()
print("TorchFunctionMode logging:")
with FunctionLog():
f()
print("TorchDispatchMode logging:")
with DispatchLog():
f()
打印以下内容,附带额外的注释:
TorchFunctionMode logging:
Function Log: torch.rand(*(10,), **{'requires_grad': True})
Function Log: torch.Tensor.mul(*(tensor([0.7164, 0.9897, 0.1745, 0.9336, 0.4287, 0.7989, 0.2169, 0.7474, 0.5624,
0.5970], requires_grad=True), 2), **None)
Function Log: torch.Tensor.sum(*(tensor([1.4328, 1.9794, 0.3490, 1.8671, 0.8573, 1.5977, 0.4338, 1.4948, 1.1249,
1.1939], grad_fn=<MulBackward0>),), **None)
# Note that at the python level, we only see the call to backward but not what happens in the autograd engine.
Function Log: torch.Tensor.backward(*(tensor(12.3307, grad_fn=<SumBackward0>),), **{'gradient': None, 'retain_graph': None, 'create_graph': False, 'inputs': None})
TorchDispatchMode logging:
# Here the requires_grad flag from autograd is removed while default arguments were populated.
Dispatch Log: aten.rand.default(*([10],), **{'device': device(type='cpu'), 'pin_memory': False})
Dispatch Log: aten.mul.Tensor(*(tensor([0.2151, 0.6018, 0.8415, 0.9060, 0.2974, 0.7708, 0.6668, 0.0352, 0.7948,
0.6023], requires_grad=True), 2), **{})
Dispatch Log: aten.sum.default(*(tensor([0.4303, 1.2036, 1.6831, 1.8120, 0.5949, 1.5416, 1.3335, 0.0705, 1.5897,
1.2046], grad_fn=<MulBackward0>),), **{})
# Here we don't see the call to backward itself, but its constituents. Starting here with the factory function that creates the initial gradient.
Dispatch Log: aten.ones_like.default(*(tensor(11.4637, grad_fn=<SumBackward0>),), **{'pin_memory': False, 'memory_format': torch.preserve_format})
# This is the backward of the sum
Dispatch Log: aten.expand.default(*(tensor(1.), [10]), **{})
Dispatch Log: aten.mul.Tensor(*(tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), 2), **{})
Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{})
Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{})