PyTorch 自定义算子¶
创建于:2024 年 6 月 18 日 | 最后更新:2025 年 1 月 6 日 | 最后验证:2024 年 11 月 5 日
PyTorch 提供了大量的算子库,用于处理张量(例如 torch.add
、torch.sum
等)。但是,您可能希望将新的自定义操作引入 PyTorch,并使其与 torch.compile
、autograd 和 torch.vmap
等子系统一起工作。为了做到这一点,您必须通过 Python torch.library 文档 或 C++ TORCH_LIBRARY
API 在 PyTorch 中注册自定义操作。
从 Python 编写自定义算子¶
请参阅 自定义 Python 算子。
在以下情况下,您可能希望从 Python(而不是 C++)编写自定义算子
您有一个 Python 函数,您希望 PyTorch 将其视为不透明的可调用对象,尤其是在
torch.compile
和torch.export
方面。您有一些 C++/CUDA 内核的 Python 绑定,并希望这些绑定与 PyTorch 子系统(如
torch.compile
或torch.autograd
)组合使用您正在使用 Python(而不是仅限 C++ 的环境,如 AOTInductor)。
将自定义 C++ 和/或 CUDA 代码与 PyTorch 集成¶
请参阅 自定义 C++ 和 CUDA 算子。
在以下情况下,您可能希望从 C++(而不是 Python)编写自定义算子
您有自定义的 C++ 和/或 CUDA 代码。
您计划将此代码与
AOTInductor
一起使用以进行无 Python 推理。
自定义算子手册¶
有关教程和本页未涵盖的信息,请参阅 自定义算子手册(我们正在努力将信息迁移到我们的文档站点)。我们建议您首先阅读上面的教程之一,然后将自定义算子手册作为参考;它不适合从头到尾阅读。
我应该何时创建自定义算子?¶
如果您的操作可以表示为内置 PyTorch 算子的组合,请将其编写为 Python 函数并调用它,而不是创建自定义算子。如果您要调用 PyTorch 不理解的库(例如,自定义 C/C++ 代码、自定义 CUDA 内核或 C/C++/CUDA 扩展的 Python 绑定),请使用算子注册 API 创建自定义算子。
我为什么要创建自定义算子?¶
可以使用 C/C++/CUDA 内核,方法是获取张量的数据指针并将其传递给 pybind 的内核。但是,这种方法不能与 PyTorch 子系统(如 autograd、torch.compile、vmap 等)组合使用。为了使操作能够与 PyTorch 子系统组合使用,必须通过算子注册 API 进行注册。