torch.library¶
torch.library 是用于扩展 PyTorch 核心算子库的 API 集合。它包含用于测试自定义算子、创建新的自定义算子以及扩展使用 PyTorch C++ 算子注册 API(例如 aten 算子)定义的算子的实用程序。
有关有效使用这些 API 的详细指南,请参阅 请参阅 PyTorch 自定义算子着陆页,以详细了解如何有效使用这些 API。
测试自定义算子¶
使用 torch.library.opcheck()
测试自定义算子在 Python torch.library 和/或 C++ TORCH_LIBRARY API 中的错误用法。此外,如果您的算子支持训练,请使用 torch.autograd.gradcheck()
测试梯度在数学上是否正确。
- torch.library.opcheck(op, args, kwargs=None, *, test_utils=('test_schema', 'test_autograd_registration', 'test_faketensor', 'test_aot_dispatch_dynamic'), raise_exception=True)[源代码]¶
给定一个算子和一些示例参数,测试该算子是否已正确注册。
也就是说,当您使用 torch.library/TORCH_LIBRARY API 创建自定义算子时,您指定了有关自定义算子的元数据(例如可变性信息),并且这些 API 要求您传递给它们的函数满足某些属性(例如,在伪/元/抽象内核中没有数据指针访问)
opcheck
测试这些元数据和属性。具体来说,我们测试以下内容
test_schema:架构是否与算子的实现匹配。例如:如果架构指定张量被修改,那么我们检查实现是否修改了张量。如果架构指定我们返回一个新的张量,那么我们检查实现是否返回一个新的张量(而不是现有的张量或现有张量的视图)。
test_autograd_registration:如果算子支持训练(autograd):我们检查其 autograd 公式是否通过 torch.library.register_autograd 或手动注册到一个或多个 DispatchKey::Autograd 键中。任何其他基于 DispatchKey 的注册都可能导致未定义的行为。
test_faketensor:如果算子具有 FakeTensor 内核(以及它是否正确)。FakeTensor 内核对于算子与 PyTorch 编译 API(torch.compile/export/FX)一起使用是必要的(但不是充分的)。我们检查是否为算子注册了 FakeTensor 内核(有时也称为元内核),以及它是否正确。此测试获取在真实张量上运行算子的结果和在 FakeTensor 上运行算子的结果,并检查它们是否具有相同的张量元数据(大小/步幅/数据类型/设备/等)。
test_aot_dispatch_dynamic:如果算子在 PyTorch 编译 API(torch.compile/export/FX)中具有正确的行为。这检查输出(以及适用的梯度)在急切模式 PyTorch 和 torch.compile 下是否相同。此测试是
test_faketensor
的超集,并且是一个端到端测试;它测试的其他内容包括算子是否支持功能化以及反向传递(如果存在)是否也支持 FakeTensor 和功能化。
为了获得最佳结果,请使用具有代表性输入集多次调用
opcheck
。如果您的算子支持 autograd,请使用requires_grad = True
的输入调用opcheck
;如果您的算子支持多个设备(例如 CPU 和 CUDA),请使用所有受支持设备上的输入调用opcheck
。- 参数
op (Union[OpOverload, OpOverloadPacket, CustomOpDef]) – 操作符。必须是使用
torch.library.custom_op()
装饰的函数,或者是在 torch.ops.* 中找到的 OpOverload/OpOverloadPacket(例如 torch.ops.aten.sin,torch.ops.mylib.foo)。test_utils (Union[str, Sequence[str]]) – 我们应该运行的测试。默认:所有测试。示例:(“test_schema”, “test_faketensor”)
raise_exception (bool) – 是否在第一个错误时抛出异常。如果为 False,我们将返回一个字典,其中包含每个测试是否通过的信息。
- 返回类型
警告
opcheck 和
torch.autograd.gradcheck()
测试的是不同的东西;opcheck 测试您对 torch.library API 的使用是否正确,而torch.autograd.gradcheck()
测试您的自动微分公式在数学上是否正确。对于支持梯度计算的自定义操作符,请同时使用这两种方法进行测试。示例
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) >>> def numpy_add(x: Tensor, y: float) -> Tensor: >>> x_np = x.numpy(force=True) >>> z_np = x_np + y >>> return torch.from_numpy(z_np).to(x.device) >>> >>> @numpy_sin.register_fake >>> def _(x, y): >>> return torch.empty_like(x) >>> >>> def setup_context(ctx, inputs, output): >>> y, = inputs >>> ctx.y = y >>> >>> def backward(ctx, grad): >>> return grad * ctx.y, None >>> >>> numpy_sin.register_autograd(backward, setup_context=setup_context) >>> >>> sample_inputs = [ >>> (torch.randn(3), 3.14), >>> (torch.randn(2, 3, device='cuda'), 2.718), >>> (torch.randn(1, 10, requires_grad=True), 1.234), >>> (torch.randn(64, 64, device='cuda', requires_grad=True), 90.18), >>> ] >>> >>> for args in sample_inputs: >>> torch.library.opcheck(foo, args)
在 Python 中创建新的自定义操作符¶
使用 torch.library.custom_op()
创建新的自定义操作符。
- torch.library.custom_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None)¶
将函数包装成自定义操作符。
您可能希望创建自定义操作符的原因包括:- 将第三方库或自定义内核包装起来,以便与 PyTorch 子系统(如 Autograd)一起使用。- 防止 torch.compile/export/FX 跟踪窥视您的函数内部。
此 API 用作函数的装饰器(请参阅示例)。提供的函数必须具有类型提示;这些提示是与 PyTorch 的各种子系统交互所必需的。
- 参数
name (str) – 自定义操作符的名称,格式为“{namespace}::{name}”,例如“mylib::my_linear”。此名称用作操作符在 PyTorch 子系统(例如 torch.export、FX 图)中的稳定标识符。为了避免名称冲突,请使用您的项目名称作为命名空间;例如,pytorch/fbgemm 中的所有自定义操作符都使用“fbgemm”作为命名空间。
mutates_args (Iterable[str] or "unknown") – 函数修改的参数名称。这**必须**准确,否则行为未定义。如果为“unknown”,则悲观地假设操作符的所有输入都将被修改。
device_types (None | str | Sequence[str]) – 函数有效的设备类型。如果没有提供设备类型,则该函数将用作所有设备类型的默认实现。示例:“cpu”、“cuda”。当为不接受张量的操作符注册特定于设备的实现时,我们需要该操作符具有“device: torch.device”参数。
schema (None | str) – 操作符的模式字符串。如果为 None(推荐),我们将根据其类型注释推断操作符的模式。除非您有特定原因,否则我们建议您让我们推断模式。示例:“(Tensor x, int y) -> (Tensor, Tensor)”。
- 返回类型
注意
我们建议不要传入
schema
参数,而是让我们根据类型注释推断它。编写您自己的模式很容易出错。如果您不希望我们对类型注释的解释,则可能希望提供您自己的模式。有关如何编写模式字符串的更多信息,请参阅 此处- 示例:
>>> import torch >>> from torch import Tensor >>> from torch.library import custom_op >>> import numpy as np >>> >>> @custom_op("mylib::numpy_sin", mutates_args=()) >>> def numpy_sin(x: Tensor) -> Tensor: >>> x_np = x.cpu().numpy() >>> y_np = np.sin(x_np) >>> return torch.from_numpy(y_np).to(device=x.device) >>> >>> x = torch.randn(3) >>> y = numpy_sin(x) >>> assert torch.allclose(y, x.sin()) >>> >>> # Example of a custom op that only works for one device type. >>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu") >>> def numpy_sin_cpu(x: Tensor) -> Tensor: >>> x_np = x.numpy() >>> y_np = np.sin(x_np) >>> return torch.from_numpy(y_np) >>> >>> x = torch.randn(3) >>> y = numpy_sin_cpu(x) >>> assert torch.allclose(y, x.sin()) >>> >>> # Example of a custom op that mutates an input >>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu") >>> def numpy_sin_inplace(x: Tensor) -> None: >>> x_np = x.numpy() >>> np.sin(x_np, out=x_np) >>> >>> x = torch.randn(3) >>> expected = x.sin() >>> numpy_sin_inplace(x) >>> assert torch.allclose(x, expected) >>> >>> # Example of a factory function >>> @torch.library.custom_op("mylib::bar", mutates_args={}, device_types="cpu") >>> def bar(device: torch.device) -> Tensor: >>> return torch.ones(3) >>> >>> bar("cpu")
扩展自定义操作符(由 Python 或 C++ 创建)¶
使用 register.* 方法,例如 torch.library.register_kernel()
和 func:torch.library.register_fake,为任何操作符添加实现(它们可能是使用 torch.library.custom_op()
或通过 PyTorch 的 C++ 操作符注册 API 创建的)。
- torch.library.register_kernel(op, device_types, func=None, /, *, lib=None)[source]¶
为该操作符的设备类型注册一个实现。
一些有效的 device_types 为:“cpu”、“cuda”、“xla”、“mps”、“ipu”、“xpu”。此 API 可以用作装饰器。
- 参数
- 示例:
>>> import torch >>> from torch import Tensor >>> from torch.library import custom_op >>> import numpy as np >>> >>> # Create a custom op that works on cpu >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu") >>> def numpy_sin(x: Tensor) -> Tensor: >>> x_np = x.numpy() >>> y_np = np.sin(x_np) >>> return torch.from_numpy(y_np) >>> >>> # Add implementations for the cuda device >>> @torch.library.register_kernel("mylib::numpy_sin", "cuda") >>> def _(x): >>> x_np = x.cpu().numpy() >>> y_np = np.sin(x_np) >>> return torch.from_numpy(y_np).to(device=x.device) >>> >>> x_cpu = torch.randn(3) >>> x_cuda = x_cpu.cuda() >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin()) >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())
- torch.library.register_autograd(op, backward, /, *, setup_context=None, lib=None)[source]¶
为该自定义操作符注册反向传播公式。
为了使操作符能够与自动微分一起使用,您需要注册一个反向传播公式:1. 您必须通过提供“backward”函数来告诉我们如何在反向传播过程中计算梯度。2. 如果您需要来自正向传播的任何值来计算梯度,您可以使用setup_context保存用于反向传播的值。
backward
在反向传播过程中运行。它接受(ctx, *grads)
:-grads
是一个或多个梯度。梯度的数量与操作符输出的数量匹配。ctx
对象是 与torch.autograd.Function
使用的相同 ctx 对象。backward_fn
的语义与torch.autograd.Function.backward()
相同。setup_context(ctx, inputs, output)
在正向传播过程中运行。请使用torch.autograd.function.FunctionCtx.save_for_backward()
或将它们分配为ctx
的属性,将反向传播所需的数量保存到ctx
对象中。如果您的自定义操作符具有仅限关键字参数,我们期望setup_context
的签名为setup_context(ctx, inputs, keyword_only_inputs, output)
。setup_context_fn
和backward_fn
都必须是可追踪的。也就是说,它们不能直接访问torch.Tensor.data_ptr()
,并且不能依赖或修改全局状态。如果您需要一个不可追踪的反向传播,您可以将其作为一个单独的自定义操作,并在backward_fn
中调用它。示例
>>> import torch >>> import numpy as np >>> from torch import Tensor >>> >>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=()) >>> def numpy_sin(x: Tensor) -> Tensor: >>> x_np = x.cpu().numpy() >>> y_np = np.sin(x_np) >>> return torch.from_numpy(y_np).to(device=x.device) >>> >>> def setup_context(ctx, inputs, output) -> Tensor: >>> x, = inputs >>> ctx.save_for_backward(x) >>> >>> def backward(ctx, grad): >>> x, = ctx.saved_tensors >>> return grad * x.cos() >>> >>> torch.library.register_autograd( ... "mylib::numpy_sin", backward, setup_context=setup_context ... ) >>> >>> x = torch.randn(3, requires_grad=True) >>> y = numpy_sin(x) >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) >>> assert torch.allclose(grad_x, x.cos()) >>> >>> # Example with a keyword-only arg >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) >>> def numpy_mul(x: Tensor, *, val: float) -> Tensor: >>> x_np = x.cpu().numpy() >>> y_np = x_np * val >>> return torch.from_numpy(y_np).to(device=x.device) >>> >>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor: >>> ctx.val = keyword_only_inputs["val"] >>> >>> def backward(ctx, grad): >>> return grad * ctx.val >>> >>> torch.library.register_autograd( ... "mylib::numpy_mul", backward, setup_context=setup_context ... ) >>> >>> x = torch.randn(3, requires_grad=True) >>> y = numpy_mul(x, val=3.14) >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
- torch.library.register_fake(op, func=None, /, *, lib=None, _stacklevel=1)[source]¶
为该算子注册一个 FakeTensor 实现(“伪实现”)。
有时也称为“元内核”、“抽象实现”。
“FakeTensor 实现”指定了该算子在不携带数据的张量(“FakeTensor”)上的行为。给定一些具有特定属性(大小/步长/存储偏移量/设备)的输入张量,它指定了输出张量的属性是什么。
FakeTensor 实现具有与算子相同的签名。它在 FakeTensor 和元张量上都运行。要编写 FakeTensor 实现,假设算子所有 Tensor 输入都是常规的 CPU/CUDA/Meta 张量,但它们没有存储,并且您尝试返回常规的 CPU/CUDA/Meta 张量作为输出。FakeTensor 实现必须仅由 PyTorch 操作组成(并且不能直接访问任何输入或中间张量的存储或数据)。
此 API 可用作装饰器(参见示例)。
有关自定义操作的详细指南,请参阅 https://pytorch.ac.cn/tutorials/advanced/custom_ops_landing_page.html
示例
>>> import torch >>> import numpy as np >>> from torch import Tensor >>> >>> # Example 1: an operator without data-dependent output shape >>> @torch.library.custom_op("mylib::custom_linear", mutates_args=()) >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor: >>> raise NotImplementedError("Implementation goes here") >>> >>> @torch.library.register_fake("mylib::custom_linear") >>> def _(x, weight, bias): >>> assert x.dim() == 2 >>> assert weight.dim() == 2 >>> assert bias.dim() == 1 >>> assert x.shape[1] == weight.shape[1] >>> assert weight.shape[0] == bias.shape[0] >>> assert x.device == weight.device >>> >>> return (x @ weight.t()) + bias >>> >>> with torch._subclasses.fake_tensor.FakeTensorMode(): >>> x = torch.randn(2, 3) >>> w = torch.randn(3, 3) >>> b = torch.randn(3) >>> y = torch.ops.mylib.custom_linear(x, w, b) >>> >>> assert y.shape == (2, 3) >>> >>> # Example 2: an operator with data-dependent output shape >>> @torch.library.custom_op("mylib::custom_nonzero", mutates_args=()) >>> def custom_nonzero(x: Tensor) -> Tensor: >>> x_np = x.numpy(force=True) >>> res = np.stack(np.nonzero(x_np), axis=1) >>> return torch.tensor(res, device=x.device) >>> >>> @torch.library.register_fake("mylib::custom_nonzero") >>> def _(x): >>> # Number of nonzero-elements is data-dependent. >>> # Since we cannot peek at the data in an fake impl, >>> # we use the ctx object to construct a new symint that >>> # represents the data-dependent size. >>> ctx = torch.library.get_ctx() >>> nnz = ctx.new_dynamic_size() >>> shape = [nnz, x.dim()] >>> result = x.new_empty(shape, dtype=torch.int64) >>> return result >>> >>> from torch.fx.experimental.proxy_tensor import make_fx >>> >>> x = torch.tensor([0, 1, 2, 3, 4, 0]) >>> trace = make_fx(torch.ops.mylib.custom_nonzero, tracing_mode="symbolic")(x) >>> trace.print_readable() >>> >>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))
- torch.library.register_vmap(op, func=None, /, *, lib=None)[source]¶
注册一个 vmap 实现以支持
torch.vmap()
用于此自定义操作。此 API 可用作装饰器(参见示例)。
为了使算子能够与
torch.vmap()
一起使用,您可能需要注册以下签名的 vmap 实现vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)
,其中
*args
和**kwargs
是op
的参数和关键字参数。我们不支持仅关键字参数的 Tensor 参数。它指定了如何在给定具有额外维度(由
in_dims
指定)的输入的情况下计算op
的批处理版本。对于
args
中的每个参数,in_dims
都有一个对应的Optional[int]
。如果参数不是 Tensor 或参数没有被 vmap 遍历,则为None
,否则为一个整数,指定 Tensor 的哪个维度正在被 vmap 遍历。info
是一个包含其他元数据的集合,这些元数据可能会有所帮助:info.batch_size
指定正在被 vmap 遍历的维度的尺寸,而info.randomness
是传递给torch.vmap()
的randomness
选项。函数
func
的返回值是一个(output, out_dims)
元组。类似于in_dims
,out_dims
应该与output
具有相同的结构,并且每个输出都包含一个out_dim
,用于指定输出是否具有 vmap 维度以及该维度在什么索引位置。示例
>>> import torch >>> import numpy as np >>> from torch import Tensor >>> from typing import Tuple >>> >>> def to_numpy(tensor): >>> return tensor.cpu().numpy() >>> >>> lib = torch.library.Library("mylib", "FRAGMENT") >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=()) >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]: >>> x_np = to_numpy(x) >>> dx = torch.tensor(3 * x_np ** 2, device=x.device) >>> return torch.tensor(x_np ** 3, device=x.device), dx >>> >>> def numpy_cube_vmap(info, in_dims, x): >>> result = numpy_cube(x) >>> return result, (in_dims[0], in_dims[0]) >>> >>> torch.library.register_vmap(numpy_cube, numpy_cube_vmap) >>> >>> x = torch.randn(3) >>> torch.vmap(numpy_cube)(x) >>> >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor: >>> return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) >>> >>> @torch.library.register_vmap("mylib::numpy_mul") >>> def numpy_mul_vmap(info, in_dims, x, y): >>> x_bdim, y_bdim = in_dims >>> x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) >>> y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) >>> result = x * y >>> result = result.movedim(-1, 0) >>> return result, 0 >>> >>> >>> x = torch.randn(3) >>> y = torch.randn(3) >>> torch.vmap(numpy_mul)(x, y)
注意
vmap 函数应旨在保留整个自定义操作的语义。也就是说,
grad(vmap(op))
应该可以用grad(map(op))
替换。如果您的自定义操作在反向传播过程中有任何自定义行为,请记住这一点。
- torch.library.impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1)[source]¶
此 API 在 PyTorch 2.4 中已重命名为
torch.library.register_fake()
。请改用它。
- torch.library.get_ctx()[source]¶
get_ctx()
返回当前的 AbstractImplCtx 对象。仅在伪实现内部调用
get_ctx()
是有效的(有关更多用法细节,请参阅torch.library.register_fake()
)。- 返回类型
FakeImplCtx
- torch.library.register_torch_dispatch(op, torch_dispatch_class, func=None, /, *, lib=None)[source]¶
为给定的算子和
torch_dispatch_class
注册一个 torch_dispatch 规则。这允许开放注册以指定算子和
torch_dispatch_class
之间行为,而无需直接修改torch_dispatch_class
或算子。torch_dispatch_class
是一个具有__torch_dispatch__
的 Tensor 子类或 TorchDispatchMode。如果它是 Tensor 子类,我们期望
func
具有以下签名:(cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any
如果它是 TorchDispatchMode,我们期望
func
具有以下签名:(mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any
args
和kwargs
将以与__torch_dispatch__
中相同的方式进行规范化(参见 __torch_dispatch__ 调用约定)。示例
>>> import torch >>> >>> @torch.library.custom_op("mylib::foo", mutates_args={}) >>> def foo(x: torch.Tensor) -> torch.Tensor: >>> return x.clone() >>> >>> class MyMode(torch.utils._python_dispatch.TorchDispatchMode): >>> def __torch_dispatch__(self, func, types, args=(), kwargs=None): >>> return func(*args, **kwargs) >>> >>> @torch.library.register_torch_dispatch("mylib::foo", MyMode) >>> def _(mode, func, types, args, kwargs): >>> x, = args >>> return x + 1 >>> >>> x = torch.randn(3) >>> y = foo(x) >>> assert torch.allclose(y, x) >>> >>> with MyMode(): >>> y = foo(x) >>> assert torch.allclose(y, x + 1)
- torch.library.infer_schema(prototype_function, /, *, mutates_args, op_name=None)¶
使用类型提示解析给定函数的模式。模式是从函数的类型提示中推断出来的,可用于定义新的算子。
我们做出以下假设
没有一个输出与任何输入或彼此重叠。
- 没有库规范的字符串类型注释“device、dtype、Tensor、types”被假定为 torch.*。类似地,没有库规范的字符串类型注释“Optional、List、Sequence、Union”被假定为 typing.*。
- 只有
mutates_args
中列出的参数正在被修改。如果mutates_args
为“unknown”,则假定算子的所有输入都正在被修改。
调用者(例如自定义操作 API)负责检查这些假设。
- 参数
- 返回值
推断出的模式。
- 返回类型
示例
>>> def foo_impl(x: torch.Tensor) -> torch.Tensor: >>> return x.sin() >>> >>> infer_schema(foo_impl, op_name="foo", mutates_args={}) foo(Tensor x) -> Tensor >>> >>> infer_schema(foo_impl, mutates_args={}) (Tensor x) -> Tensor
- class torch._library.custom_ops.CustomOpDef(namespace, name, schema, fn)[source]¶
CustomOpDef 是一个函数的包装器,将其转换为自定义操作。
它有各种方法来注册此自定义操作的其他行为。
您不应该直接实例化 CustomOpDef;而是使用
torch.library.custom_op()
API。- set_kernel_enabled(device_type, enabled=True)[source]¶
禁用或重新启用已为此自定义运算符注册的内核。
如果内核已禁用/启用,则此操作无效。
注意
如果内核先被禁用然后注册,则它将保持禁用状态,直到再次启用。
示例
>>> inp = torch.randn(1) >>> >>> # define custom op `f`. >>> @custom_op("mylib::f", mutates_args=()) >>> def f(x: Tensor) -> Tensor: >>> return torch.zeros(1) >>> >>> print(f(inp)) # tensor([0.]), default kernel >>> >>> @f.register_kernel("cpu") >>> def _(x): >>> return torch.ones(1) >>> >>> print(f(inp)) # tensor([1.]), CPU kernel >>> >>> # temporarily disable the CPU kernel >>> with f.set_kernel_enabled("cpu", enabled = False): >>> print(f(inp)) # tensor([0.]) with CPU kernel disabled
低级 API¶
以下 API 是 PyTorch 的 C++ 低级运算符注册 API 的直接绑定。
警告
低级运算符注册 API 和 PyTorch 调度程序是 PyTorch 中一个复杂的概念。我们建议您尽可能使用上面更高层次的 API(不需要 torch.library.Library 对象)。这篇博文 <http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/>`_ 是了解 PyTorch 调度程序的一个很好的起点。
在 Google Colab 上提供了一个教程,其中介绍了如何使用此 API 的一些示例。
- class torch.library.Library(ns, kind, dispatch_key='')[source]¶
一个用于创建库的类,该库可用于从 Python 注册新的运算符或覆盖现有库中的运算符。如果用户只想注册对应于一个特定调度键的内核,则可以可选地传入一个调度键名。
要创建一个库来覆盖现有库(名称为 ns)中的运算符,请将 kind 设置为“IMPL”。要创建一个新库(名称为 ns)来注册新运算符,请将 kind 设置为“DEF”。要创建一个可能存在的库片段来注册运算符(并绕过每个命名空间只有一个库的限制),请将 kind 设置为“FRAGMENT”。
- 参数
ns – 库名称
kind – “DEF”、“IMPL”(默认值:“IMPL”)、“FRAGMENT”
dispatch_key – PyTorch 调度键(默认值:“”)
- define(schema, alias_analysis='', *, tags=())[source]¶
在 ns 命名空间中定义一个新的运算符及其语义。
- 参数
- 返回值
从模式中推断出的运算符名称。
- 示例:
>>> my_lib = Library("mylib", "DEF") >>> my_lib.define("sum(Tensor self) -> Tensor")
- fallback(fn, dispatch_key='', *, with_keyset=False)[source]¶
将函数实现注册为给定键的回退。
此函数仅适用于具有全局命名空间 (“_”) 的库。
- 参数
fn – 用作给定调度键的回退的函数或
fallthrough_kernel()
以注册回退。dispatch_key – 输入函数应为其注册的调度键。默认情况下,它使用创建库时使用的调度键。
with_keyset – 控制是否应将当前调度程序调用键集作为第一个参数传递给
fn
时调用的标志。这应该用于为重新调度调用创建适当的键集。
- 示例:
>>> my_lib = Library("_", "IMPL") >>> def fallback_kernel(op, *args, **kwargs): >>> # Handle all autocast ops generically >>> # ... >>> my_lib.fallback(fallback_kernel, "Autocast")
- impl(op_name, fn, dispatch_key='', *, with_keyset=False)[source]¶
注册库中定义的运算符的函数实现。
- 参数
op_name – 运算符名称(以及重载)或 OpOverload 对象。
fn – 作为输入调度键的运算符实现的函数或
fallthrough_kernel()
以注册回退。dispatch_key – 输入函数应为其注册的调度键。默认情况下,它使用创建库时使用的调度键。
with_keyset – 控制是否应将当前调度程序调用键集作为第一个参数传递给
fn
时调用的标志。这应该用于为重新调度调用创建适当的键集。
- 示例:
>>> my_lib = Library("aten", "IMPL") >>> def div_cpu(self, other): >>> return self * (1 / other) >>> my_lib.impl("div.Tensor", div_cpu, "CPU")
- torch.library.define(qualname, schema, *, lib=None, tags=())[source]¶
- torch.library.define(lib, schema, alias_analysis='')
定义一个新的运算符。
在 PyTorch 中,定义一个操作(简称“运算符”)是一个两步过程: - 我们需要定义操作(通过提供运算符名称和模式) - 我们需要实现操作与各种 PyTorch 子系统(如 CPU/CUDA 张量、Autograd 等)交互的行为。
此入口点定义自定义运算符(第一步),然后您必须通过调用各种
impl_*
API(如torch.library.impl()
或torch.library.register_fake()
)来执行第二步。- 参数
qualname (str) – 运算符的限定名称。应是一个类似于“namespace::name”的字符串,例如“aten::sin”。PyTorch 中的运算符需要一个命名空间来避免名称冲突;给定的运算符只能创建一次。如果您正在编写 Python 库,我们建议命名空间为顶级模块的名称。
schema (str) – 运算符的模式。例如,对于接受一个张量并返回一个张量的操作,“(Tensor x) -> Tensor”。它不包含运算符名称(该名称在
qualname
中传递)。lib (Optional[Library]) – 如果提供,则此运算符的生命周期将与 Library 对象的生命周期绑定。
tags (Tag | Sequence[Tag]) – 应用于此运算符的一个或多个 torch.Tag。标记运算符会更改运算符在各种 PyTorch 子系统下的行为;在应用它之前,请仔细阅读 torch.Tag 的文档。
- 示例:
>>> import torch >>> import numpy as np >>> >>> # Define the operator >>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor") >>> >>> # Add implementations for the operator >>> @torch.library.impl("mylib::sin", "cpu") >>> def f(x): >>> return torch.from_numpy(np.sin(x.numpy())) >>> >>> # Call the new operator from torch.ops. >>> x = torch.randn(3) >>> y = torch.ops.mylib.sin(x) >>> assert torch.allclose(y, x.sin())
- torch.library.impl(qualname, types, func=None, *, lib=None)[source]¶
- torch.library.impl(lib, name, dispatch_key='')
为该操作符的设备类型注册一个实现。
您可以为
types
传递“default”,以将此实现注册为所有设备类型的默认实现。请仅当实现真正支持所有设备类型时才使用此方法;例如,如果它是内置 PyTorch 运算符的组合,则为真。一些有效的类型包括:“cpu”、“cuda”、“xla”、“mps”、“ipu”、“xpu”。
- 参数
示例
>>> import torch >>> import numpy as np >>> >>> # Define the operator >>> torch.library.define("mylib::mysin", "(Tensor x) -> Tensor") >>> >>> # Add implementations for the cpu device >>> @torch.library.impl("mylib::mysin", "cpu") >>> def f(x): >>> return torch.from_numpy(np.sin(x.numpy())) >>> >>> x = torch.randn(3) >>> y = torch.ops.mylib.mysin(x) >>> assert torch.allclose(y, x.sin())