快捷方式

torch.library

torch.library 是一个 API 集合,用于扩展 PyTorch 的核心运算符库。它包含用于测试自定义运算符、创建新的自定义运算符以及扩展使用 PyTorch 的 C++ 运算符注册 API(例如 aten 运算符)定义的运算符的实用程序。

有关有效使用这些 API 的详细指南,请参阅 PyTorch 自定义运算符登录页面,以详细了解如何有效使用这些 API。

测试自定义运算符

使用 torch.library.opcheck() 测试自定义运算符是否错误使用了 Python torch.library 和/或 C++ TORCH_LIBRARY API。此外,如果您的运算符支持训练,请使用 torch.autograd.gradcheck() 测试梯度在数学上是否正确。

torch.library.opcheck(op, args, kwargs=None, *, test_utils=('test_schema', 'test_autograd_registration', 'test_faketensor', 'test_aot_dispatch_dynamic'), raise_exception=True)[源代码]

给定一个运算符和一些示例参数,测试该运算符是否已正确注册。

也就是说,当您使用 torch.library/TORCH_LIBRARY API 创建自定义运算符时,您指定了有关该自定义运算符的元数据(例如可变性信息),并且这些 API 要求您传递给它们的函数满足某些属性(例如,在 fake/meta/abstract 内核中没有数据指针访问)opcheck 测试这些元数据和属性。

具体而言,我们测试以下内容:- test_schema:运算符的模式是否正确。- test_autograd_registration:自动微分是否已正确注册。- test_faketensor:运算符是否具有 FakeTensor 内核(以及它是否正确)。FakeTensor 内核是运算符与 PyTorch 编译 API(torch.compile/export/FX)一起工作所必需的(但不是充分的)。- test_aot_dispatch_dynamic:运算符在使用 PyTorch 编译 API(torch.compile/export/FX)时是否具有正确的行为。这将检查在 eager-mode PyTorch 和 torch.compile 下输出(以及梯度,如果适用)是否相同。此测试是 test_faketensor 的超集。

为获得最佳结果,请使用一组具有代表性的输入多次调用 opcheck。如果您的运算符支持自动微分,请使用 requires_grad = True 的输入调用 opcheck;如果您的运算符支持多个设备(例如 CPU 和 CUDA),请对所有受支持的设备上的输入使用 opcheck

参数
  • op (Union[OpOverload, OpOverloadPacket, CustomOpDef]) – 运算符。必须是使用 torch.library.custom_op() 装饰的函数或在 torch.ops.* 中找到的 OpOverload/OpOverloadPacket(例如 torch.ops.aten.sin、torch.ops.mylib.foo)

  • args (Tuple[Any, ...]) – 运算符的参数

  • kwargs (Optional[Dict[str, Any]]) – 运算符的关键字参数

  • test_utils (Union[str, Sequence[str]]) – 我们应该运行的测试。默认值:全部。示例:("test_schema", "test_faketensor")

  • raise_exception (bool) – 如果在第一个错误时抛出异常,则为 True。 如果为 False,我们将返回一个字典,其中包含每个测试是否通过的信息。

返回类型

Dict[str, str]

警告

opcheck 和 torch.autograd.gradcheck() 测试不同的内容;opcheck 测试您对 torch.library API 的使用是否正确,而 torch.autograd.gradcheck() 测试您的自动求导公式在数学上是否正确。 使用这两者来测试支持梯度计算的自定义运算。

示例

>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_add(x: Tensor, y: float) -> Tensor:
>>>     x_np = x.numpy(force=True)
>>>     z_np = x_np + y
>>>     return torch.from_numpy(z_np).to(x.device)
>>>
>>> @numpy_sin.register_fake
>>> def _(x, y):
>>>     return torch.empty_like(x)
>>>
>>> def setup_context(ctx, inputs, output):
>>>     y, = inputs
>>>     ctx.y = y
>>>
>>> def backward(ctx, grad):
>>>     return grad * ctx.y, None
>>>
>>> numpy_sin.register_autograd(backward, setup_context=setup_context)
>>>
>>> sample_inputs = [
>>>     (torch.randn(3), 3.14),
>>>     (torch.randn(2, 3, device='cuda'), 2.718),
>>>     (torch.randn(1, 10, requires_grad=True), 1.234),
>>>     (torch.randn(64, 64, device='cuda', requires_grad=True), 90.18),
>>> ]
>>>
>>> for args in sample_inputs:
>>>     torch.library.opcheck(foo, args)

在 Python 中创建新的自定义运算

使用 torch.library.custom_op() 创建新的自定义运算。

torch.library.custom_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None)

将函数包装到自定义运算符中。

您可能想要创建自定义运算的原因包括:- 包装第三方库或自定义内核以使用 PyTorch 子系统(如 Autograd)。- 防止 torch.compile/export/FX 跟踪窥视您的函数内部。

此 API 用作函数周围的装饰器(请参阅示例)。 提供的函数必须具有类型提示;这些是与 PyTorch 的各种子系统交互所必需的。

参数
  • name (str) – 自定义运算的名称,格式为“{namespace}::{name}”,例如“mylib::my_linear”。 该名称用作 PyTorch 子系统(例如 torch.export、FX 图)中运算的稳定标识符。 为避免名称冲突,请使用您的项目名称作为命名空间;例如,pytorch/fbgemm 中的所有自定义运算都使用“fbgemm”作为命名空间。

  • mutates_args (Iterable[str]) – 函数变异的参数名称。 这必须是准确的,否则行为未定义。

  • device_types (None | str | Sequence[str]) – 函数有效的设备类型。 如果未提供设备类型,则该函数将用作所有设备类型的默认实现。 示例:“cpu”、“cuda”。

  • schema (None | str) – 运算符的架构字符串。 如果为 None(推荐),我们将从其类型注释中推断运算符的架构。 我们建议让我们推断架构,除非您有特殊原因不这样做。 示例:“(Tensor x, int y) -> (Tensor, Tensor)”。

返回类型

可调用

注意

我们建议不要传入 schema 参数,而是让我们从类型注释中推断它。 编写自己的架构很容易出错。 如果我们对类型注释的解释不是您想要的,您可能希望提供自己的架构。 有关如何编写架构字符串的更多信息,请参阅此处

示例:
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>> import numpy as np
>>>
>>> @custom_op("mylib::numpy_sin", mutates_args=())
>>> def numpy_sin(x: Tensor) -> Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> x = torch.randn(3)
>>> y = numpy_sin(x)
>>> assert torch.allclose(y, x.sin())
>>>
>>> # Example of a custom op that only works for one device type.
>>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu")
>>> def numpy_sin_cpu(x: Tensor) -> Tensor:
>>>     x_np = x.numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np)
>>>
>>> x = torch.randn(3)
>>> y = numpy_sin_cpu(x)
>>> assert torch.allclose(y, x.sin())
>>>
>>> # Example of a custom op that mutates an input
>>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu")
>>> def numpy_sin_inplace(x: Tensor) -> None:
>>>     x_np = x.numpy()
>>>     np.sin(x_np, out=x_np)
>>>
>>> x = torch.randn(3)
>>> expected = x.sin()
>>> numpy_sin_inplace(x)
>>> assert torch.allclose(x, expected)

扩展自定义运算(从 Python 或 C++ 创建)

使用 register.* 方法(例如 torch.library.register_kernel() 和 func:torch.library.register_fake)为任何运算符添加实现(它们可能是使用 torch.library.custom_op() 或通过 PyTorch 的 C++ 运算符注册 API 创建的)。

torch.library.register_kernel(op, device_types, func=None, /, *, lib=None)[源代码]

为此运算符注册设备类型的实现。

一些有效的 device_types 是:“cpu”、“cuda”、“xla”、“mps”、“ipu”、“xpu”。 此 API 可用作装饰器。

参数
  • fn (Callable) – 要注册为给定设备类型的实现的函数。

  • device_types (None | str | Sequence[str]) – 要将实现注册到的 device_types。 如果为 None,我们将注册到所有设备类型 - 仅当您的实现真正与设备类型无关时,才使用此选项。

示例:
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>> import numpy as np
>>>
>>> # Create a custom op that works on cpu
>>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
>>> def numpy_sin(x: Tensor) -> Tensor:
>>>     x_np = x.numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np)
>>>
>>> # Add implementations for the cuda device
>>> @torch.library.register_kernel("mylib::numpy_sin", "cuda")
>>> def _(x):
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> x_cpu = torch.randn(3)
>>> x_cuda = x_cpu.cuda()
>>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
>>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())
torch.library.register_autograd(op, backward, /, *, setup_context=None, lib=None)[源代码]

为此自定义运算注册反向公式。

为了使运算符与 autograd 一起工作,您需要注册一个反向公式:1. 您必须通过向我们提供一个“backward”函数来告诉我们如何在反向传递期间计算梯度。 2. 如果您需要来自正向的任何值来计算梯度,则可以使用 setup_context 保存值以供反向使用。

backward 在反向传递期间运行。 它接受 (ctx, *grads):- grads 是一个或多个梯度。 梯度的数量与运算符的输出数量相匹配。 ctx 对象是torch.autograd.Function 使用的相同的 ctx 对象。 backward_fn 的语义与 torch.autograd.Function.backward() 相同。

setup_context(ctx, inputs, output) 在正向传递期间运行。 请通过 torch.autograd.function.FunctionCtx.save_for_backward() 将反向所需的量保存到 ctx 对象上,或将它们分配为 ctx 的属性。 如果您的自定义运算具有仅限关键字的参数,我们希望 setup_context 的签名为 setup_context(ctx, inputs, keyword_only_inputs, output)

setup_context_fnbackward_fn 都必须是可跟踪的。 也就是说,它们不能直接访问 torch.Tensor.data_ptr(),并且它们不能依赖或改变全局状态。 如果您需要不可跟踪的反向,则可以将其设为在 backward_fn 内部调用的单独 custom_op。

示例

>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>>
>>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
>>> def numpy_sin(x: Tensor) -> Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> def setup_context(ctx, inputs, output) -> Tensor:
>>>     x, = inputs
>>>     ctx.save_for_backward(x)
>>>
>>> def backward(ctx, grad):
>>>     x, = ctx.saved_tensors
>>>     return grad * x.cos()
>>>
>>> torch.library.register_autograd("mylib::numpy_sin", backward, setup_context=setup_context)
>>>
>>> x = torch.randn(3, requires_grad=True)
>>> y = numpy_sin(x)
>>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y))
>>> assert torch.allclose(grad_x, x.cos())
>>>
>>> # Example with a keyword-only arg
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, *, val: float) -> Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = x_np * val
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor:
>>>     ctx.val = keyword_only_inputs["val"]
>>>
>>> def backward(ctx, grad):
>>>     return grad * ctx.val
>>>
>>> torch.library.register_autograd("mylib::numpy_mul", backward, setup_context=setup_context)
>>>
>>> x = torch.randn(3, requires_grad=True)
>>> y = numpy_mul(x, val=3.14)
>>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y))
>>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
torch.library.register_fake(op, func=None, /, *, lib=None, _stacklevel=1)[源代码]

为此运算符注册 FakeTensor 实现(“fake impl”)。

有时也称为“元内核”、“抽象实现”。

“FakeTensor 实现”指定此运算符在不携带数据(“FakeTensor”)的 Tensor 上的行为。 给定一些具有某些属性(大小/步幅/存储偏移量/设备)的输入 Tensor,它指定输出 Tensor 的属性是什么。

FakeTensor 实现与运算符具有相同的签名。 它针对 FakeTensor 和元张量运行。 要编写 FakeTensor 实现,请假设运算符的所有 Tensor 输入都是常规 CPU/CUDA/Meta 张量,但它们没有存储,并且您尝试返回常规 CPU/CUDA/Meta 张量作为输出。 FakeTensor 实现必须仅包含 PyTorch 运算(并且不能直接访问任何输入或中间 Tensor 的存储或数据)。

此 API 可用作装饰器(请参阅示例)。

有关自定义运算的详细指南,请参阅 https://pytorch.ac.cn/tutorials/advanced/custom_ops_landing_page.html

示例

>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>>
>>> # Example 1: an operator without data-dependent output shape
>>> @torch.library.custom_op("mylib::custom_linear", mutates_args=())
>>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
>>>     raise NotImplementedError("Implementation goes here")
>>>
>>> @torch.library.register_fake("mylib::custom_linear")
>>> def _(x, weight, bias):
>>>     assert x.dim() == 2
>>>     assert weight.dim() == 2
>>>     assert bias.dim() == 1
>>>     assert x.shape[1] == weight.shape[1]
>>>     assert weight.shape[0] == bias.shape[0]
>>>     assert x.device == weight.device
>>>
>>>     return (x @ weight.t()) + bias
>>>
>>> with torch._subclasses.fake_tensor.FakeTensorMode():
>>>     x = torch.randn(2, 3)
>>>     w = torch.randn(3, 3)
>>>     b = torch.randn(3)
>>>     y = torch.ops.mylib.custom_linear(x, w, b)
>>>
>>> assert y.shape == (2, 3)
>>>
>>> # Example 2: an operator with data-dependent output shape
>>> @torch.library.custom_op("mylib::custom_nonzero", mutates_args=())
>>> def custom_nonzero(x: Tensor) -> Tensor:
>>>     x_np = x.numpy(force=True)
>>>     res = np.stack(np.nonzero(x_np), axis=1)
>>>     return torch.tensor(res, device=x.device)
>>>
>>> @torch.library.register_fake("mylib::custom_nonzero")
>>> def _(x):
>>>     # Number of nonzero-elements is data-dependent.
>>>     # Since we cannot peek at the data in an fake impl,
>>>     # we use the ctx object to construct a new symint that
>>>     # represents the data-dependent size.
>>>     ctx = torch.library.get_ctx()
>>>     nnz = ctx.new_dynamic_size()
>>>     shape = [nnz, x.dim()]
>>>     result = x.new_empty(shape, dtype=torch.int64)
>>>     return result
>>>
>>> from torch.fx.experimental.proxy_tensor import make_fx
>>>
>>> x = torch.tensor([0, 1, 2, 3, 4, 0])
>>> trace = make_fx(torch.ops.mylib.custom_nonzero, tracing_mode="symbolic")(x)
>>> trace.print_readable()
>>>
>>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))
torch.library.impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1)[source]

此 API 在 PyTorch 2.4 中已重命名为 torch.library.register_fake()。请改用它。

torch.library.get_ctx()[source]

get_ctx() 返回当前的 AbstractImplCtx 对象。

仅在伪实现内部调用 get_ctx() 才有效(有关更多使用细节,请参阅 torch.library.register_fake())。

返回类型

AbstractImplCtx

底层 API

以下 API 是与 PyTorch 的 C++ 底层算子注册 API 的直接绑定。

警告

底层算子注册 API 和 PyTorch Dispatcher 是一个复杂的 PyTorch 概念。我们建议您尽可能使用上面的高级 API(不需要 torch.library.Library 对象)。这篇博文 <http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/>`_ 是了解 PyTorch Dispatcher 的一个很好的起点。

您可以在 Google Colab 上找到一个教程,该教程将引导您完成有关如何使用此 API 的一些示例。

class torch.library.Library(ns, kind, dispatch_key='')[source]

一个用于创建库的类,这些库可用于注册新的算子或从 Python 覆盖现有库中的算子。如果用户只想注册与一个特定调度键相对应的内核,则可以选择传入调度键名。

要创建一个库来覆盖现有库(名称为 ns)中的算子,请将类型设置为“IMPL”。要创建一个新库(名称为 ns)来注册新的算子,请将类型设置为“DEF”。要创建一个可能存在的库的片段来注册算子(并绕过给定命名空间只有一个库的限制),请将类型设置为“FRAGMENT”。

参数
  • ns – 库名称

  • kind – “DEF”、“IMPL”(默认值:“IMPL”)、“FRAGMENT”

  • dispatch_key – PyTorch 调度键(默认值:“”)

define(schema, alias_analysis='', *, tags=())[source]

在 ns 命名空间中定义一个新的算子及其语义。

参数
  • schema – 用于定义新算子的函数模式。

  • alias_analysis (可选) – 指示算子参数的别名属性是否可以从模式(默认行为)或“CONSERVATIVE”推断出来。

  • tags (Tag | Sequence[Tag]) – 要应用于此算子的一个或多个 torch.Tag。标记算子会改变算子在各种 PyTorch 子系统下的行为;请在应用之前仔细阅读 torch.Tag 的文档。

返回值

从模式推断出的算子名称。

示例:
>>> my_lib = Library("mylib", "DEF")
>>> my_lib.define("sum(Tensor self) -> Tensor")
impl(op_name, fn, dispatch_key='', *, with_keyset=False)[source]

注册库中定义的算子的函数实现。

参数
  • op_name – 算子名称(以及重载)或 OpOverload 对象。

  • fn – 作为输入调度键的算子实现的函数,或用于注册回退的 fallthrough_kernel()

  • dispatch_key – 应为其注册输入函数的调度键。默认情况下,它使用创建库时使用的调度键。

示例:
>>> my_lib = Library("aten", "IMPL")
>>> def div_cpu(self, other):
>>>     return self * (1 / other)
>>> my_lib.impl("div.Tensor", div_cpu, "CPU")
torch.library.fallthrough_kernel()[source]

一个传递给 Library.impl 以注册回退的虚拟函数。

torch.library.define(qualname, schema, *, lib=None, tags=())[source]
torch.library.define(lib, schema, alias_analysis='')

定义一个新的算子。

在 PyTorch 中,定义运算符(简称“算子”)是一个两步过程: - 我们需要定义算子(通过提供算子名称和模式) - 我们需要实现算子如何与各种 PyTorch 子系统(如 CPU/CUDA 张量、Autograd 等)交互的行为。

此入口点定义了自定义算子(第一步),然后您必须通过调用各种 impl_* API(如 torch.library.impl()torch.library.register_fake())来执行第二步。

参数
  • qualname (str) – 算子的限定名称。应该是一个类似于“命名空间::名称”的字符串,例如“aten::sin”。PyTorch 中的算子需要一个命名空间来避免名称冲突;一个给定的算子只能创建一次。如果您正在编写 Python 库,我们建议将命名空间设置为顶级模块的名称。

  • schema (str) – 算子的模式。例如,对于接受一个张量并返回一个张量的算子,其模式为“(张量 x) -> 张量”。它不包含算子名称(在 qualname 中传递)。

  • lib (可选[Library]) – 如果提供,则此算子的生命周期将与 Library 对象的生命周期绑定。

  • tags (Tag | Sequence[Tag]) – 要应用于此算子的一个或多个 torch.Tag。标记算子会改变算子在各种 PyTorch 子系统下的行为;请在应用之前仔细阅读 torch.Tag 的文档。

示例:
>>> import torch
>>> import numpy as np
>>>
>>> # Define the operator
>>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor")
>>>
>>> # Add implementations for the operator
>>> @torch.library.impl("mylib::sin", "cpu")
>>> def f(x):
>>>     return torch.from_numpy(np.sin(x.numpy()))
>>>
>>> # Call the new operator from torch.ops.
>>> x = torch.randn(3)
>>> y = torch.ops.mylib.sin(x)
>>> assert torch.allclose(y, x.sin())
torch.library.impl(qualname, types, func=None, *, lib=None)[source]
torch.library.impl(lib, name, dispatch_key='')

为此运算符注册设备类型的实现。

您可以为 types 传递“default”,以将此实现注册为所有设备类型的默认实现。仅当实现真正支持所有设备类型时才使用此选项;例如,如果它是内置 PyTorch 算子的组合,则为 true。

一些有效的类型是:“cpu”、“cuda”、“xla”、“mps”、“ipu”、“xpu”。

参数
  • qualname (str) – 应该是一个类似于“命名空间::算子名称”的字符串。

  • types (str | Sequence[str]) – 要将实现注册到的设备类型。

  • lib可选[Library])- 如果提供,则此注册的生命周期将与 Library 对象的生命周期绑定。

示例

>>> import torch
>>> import numpy as np
>>>
>>> # Define the operator
>>> torch.library.define("mylib::mysin", "(Tensor x) -> Tensor")
>>>
>>> # Add implementations for the cpu device
>>> @torch.library.impl("mylib::mysin", "cpu")
>>> def f(x):
>>>     return torch.from_numpy(np.sin(x.numpy()))
>>>
>>> x = torch.randn(3)
>>> y = torch.ops.mylib.mysin(x)
>>> assert torch.allclose(y, x.sin())

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源