torch.library¶
torch.library 是一个用于扩展 PyTorch 核心运算符库的 API 集合。它包含用于创建新的自定义运算符以及扩展使用 PyTorch 的 C++ 运算符注册 API(例如 aten 运算符)定义的运算符的实用程序。
有关有效使用这些 API 的详细指南,请参阅 此 gdoc
使用 torch.library.define()
定义新的自定义运算符。使用 impl 方法,例如 torch.library.impl()
和 func:torch.library.impl_abstract,为任何运算符添加实现(它们可能是使用 torch.library.define()
或通过 PyTorch 的 C++ 运算符注册 API 创建的)。
- torch.library.define(qualname, schema, *, lib=None, tags=())[source]¶
- torch.library.define(lib, schema, alias_analysis='')
定义一个新的运算符。
在 PyTorch 中,定义一个运算符(简称“操作符”)是一个两步过程: - 我们需要定义运算符(通过提供运算符名称和模式) - 我们需要实现运算符如何与各种 PyTorch 子系统(如 CPU/CUDA 张量、Autograd 等)交互的行为。
此入口点定义了自定义运算符(第一步),然后您必须通过调用各种
impl_*
API(如torch.library.impl()
或torch.library.impl_abstract()
)来执行第二步。- 参数
qualname (str) – 运算符的限定名称。应该是一个类似于“namespace::name”的字符串,例如“aten::sin”。PyTorch 中的运算符需要一个命名空间来避免名称冲突;给定运算符只能创建一次。如果您正在编写 Python 库,我们建议命名空间为您的顶级模块的名称。
schema (str) – 运算符的模式。例如,对于接受一个张量并返回一个张量的运算符,其模式为“(Tensor x) -> Tensor”。它不包含运算符名称(该名称在
qualname
中传递)。lib (Optional[Library]) – 如果提供,此运算符的生命周期将与 Library 对象的生命周期绑定。
tags (Tag | Sequence[Tag]) – 应用于此运算符的一个或多个 torch.Tag。标记运算符会更改运算符在各种 PyTorch 子系统下的行为;在应用它之前,请仔细阅读 torch.Tag 的文档。
- 示例:
>>> import torch >>> import numpy as np >>> >>> # Define the operator >>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor") >>> >>> # Add implementations for the operator >>> @torch.library.impl("mylibrary::sin", "cpu") >>> def f(x): >>> return torch.from_numpy(np.sin(x.numpy())) >>> >>> # Call the new operator from torch.ops. >>> x = torch.randn(3) >>> y = torch.ops.mylib.sin(x) >>> assert torch.allclose(y, x)
- torch.library.impl(qualname, types, func=None, *, lib=None)[source]¶
- torch.library.impl(lib, name, dispatch_key='')
为该操作符的设备类型注册一个实现。
您可以将“default”传递给
types
以将此实现注册为所有设备类型的默认实现。请仅在实现真正支持所有设备类型时使用此方法;例如,如果它是内置 PyTorch 操作符的组合,则为真。一些有效的类型是:“cpu”、“cuda”、“xla”、“mps”、“ipu”、“xpu”。
- 参数
示例
>>> import torch >>> import numpy as np >>> >>> # Define the operator >>> torch.library.define("mylibrary::sin", "(Tensor x) -> Tensor") >>> >>> # Add implementations for the cpu device >>> @torch.library.impl("mylibrary::sin", "cpu") >>> def f(x): >>> return torch.from_numpy(np.sin(x.numpy())) >>> >>> x = torch.randn(3) >>> y = torch.ops.mylibrary.sin(x) >>> assert torch.allclose(y, x.sin())
- torch.library.impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1)[source]¶
为该运算符注册抽象实现。
“抽象实现”指定了该运算符对不包含数据的张量的行为。给定一些具有特定属性(大小/步幅/存储偏移量/设备)的输入张量,它指定了输出张量的属性。
抽象实现与运算符具有相同的签名。它在 FakeTensors 和元张量上运行。要编写抽象实现,假设运算符的所有张量输入都是常规的 CPU/CUDA/Meta 张量,但它们没有存储,并且您尝试返回常规的 CPU/CUDA/Meta 张量作为输出。抽象实现必须仅包含 PyTorch 操作(并且可能不会直接访问任何输入或中间张量的存储或数据)。
此 API 可用作装饰器(请参阅示例)。
有关自定义操作的详细指南,请参阅 https://docs.google.com/document/d/1W–T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ/edit
示例
>>> import torch >>> import numpy as np >>> from torch import Tensor >>> >>> # Example 1: an operator without data-dependent output shape >>> torch.library.define( >>> "mylib::custom_linear", >>> "(Tensor x, Tensor weight, Tensor bias) -> Tensor") >>> >>> @torch.library.impl_abstract("mylib::custom_linear") >>> def custom_linear_abstract(x, weight): >>> assert x.dim() == 2 >>> assert weight.dim() == 2 >>> assert bias.dim() == 1 >>> assert x.shape[1] == weight.shape[1] >>> assert weight.shape[0] == bias.shape[0] >>> assert x.device == weight.device >>> >>> return (x @ weight.t()) + bias >>> >>> # Example 2: an operator with data-dependent output shape >>> torch.library.define("mylib::custom_nonzero", "(Tensor x) -> Tensor") >>> >>> @torch.library.impl_abstract("mylib::custom_nonzero") >>> def custom_nonzero_abstract(x): >>> # Number of nonzero-elements is data-dependent. >>> # Since we cannot peek at the data in an abstract impl, >>> # we use the ctx object to construct a new symint that >>> # represents the data-dependent size. >>> ctx = torch.library.get_ctx() >>> nnz = ctx.new_dynamic_size() >>> shape = [nnz, x.dim()] >>> result = x.new_empty(shape, dtype=torch.int64) >>> return result >>> >>> @torch.library.impl("mylib::custom_nonzero", "cpu") >>> def custom_nonzero_cpu(x): >>> x_np = x.numpy() >>> res = np.stack(np.nonzero(x_np), axis=1) >>> return torch.tensor(res, device=x.device)
- torch.library.get_ctx()[source]¶
get_ctx() 返回当前的 AbstractImplCtx 对象。
仅在抽象实现内部调用
get_ctx()
有效(有关更多使用细节,请参阅torch.library.impl_abstract()
)。- 返回类型
AbstractImplCtx
低级 API¶
以下 API 是 PyTorch 的 C++ 低级运算符注册 API 的直接绑定。
警告
低级运算符注册 API 和 PyTorch 调度器是 PyTorch 中一个复杂的概念。我们建议您尽可能使用上面的高级 API(不需要 torch.library.Library 对象)。这篇博文 <http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/>`_ 是学习 PyTorch 调度器的良好起点。
一个关于如何使用此 API 的示例教程可在 Google Colab 上获得。
- class torch.library.Library(ns, kind, dispatch_key='')[source]¶
一个用于创建库的类,该库可用于从 Python 注册新运算符或覆盖现有库中的运算符。用户可以选择传入一个调度键名,如果他们只想注册对应于特定调度键的内核。
要创建一个库来覆盖现有库(名称为 ns)中的运算符,请将 kind 设置为“IMPL”。要创建一个新库(名称为 ns)来注册新运算符,请将 kind 设置为“DEF”。要创建一个可能存在的库的片段来注册运算符(并绕过只有一个库用于给定命名空间的限制),请将 kind 设置为“FRAGMENT”。
- 参数
ns – 库名称
kind – “DEF”, “IMPL” (默认: “IMPL”), “FRAGMENT”
dispatch_key – PyTorch 调度键 (默认: “”)
- define(schema, alias_analysis='', *, tags=())[source]¶
在 ns 命名空间中定义一个新的运算符及其语义。
- 参数
- 返回值
从模式推断出的运算符名称。
- 示例:
>>> my_lib = Library("foo", "DEF") >>> my_lib.define("sum(Tensor self) -> Tensor")
- impl(op_name, fn, dispatch_key='')[source]¶
注册在库中定义的运算符的函数实现。
- 参数
op_name – 操作符名称(以及重载)或 OpOverload 对象。
fn – 用于输入调度键的操作符实现函数,或
fallthrough_kernel()
用于注册透传。dispatch_key – 输入函数应注册的调度键。默认情况下,它使用创建库时使用的调度键。
- 示例:
>>> my_lib = Library("aten", "IMPL") >>> def div_cpu(self, other): >>> return self * (1 / other) >>> my_lib.impl("div.Tensor", div_cpu, "CPU")