快捷方式

PyTorch 自定义算子

创建于:2024 年 6 月 18 日 | 最后更新:2025 年 1 月 6 日 | 最后验证:2024 年 11 月 5 日

PyTorch 提供了大量的算子库,可以在张量 (例如 torch.add, torch.sum 等) 上工作。但是,您可能希望将一个新的自定义操作引入 PyTorch,并使其与 torch.compile、自动求导和 torch.vmap 等子系统协同工作。为此,您必须通过 Python torch.library 文档 或 C++ TORCH_LIBRARY API 向 PyTorch 注册自定义操作。

从 Python 创建自定义算子

请参阅 自定义 Python 算子

您可能希望从 Python (而非 C++) 创建自定义算子,如果

  • 您有一个 Python 函数,希望 PyTorch 将其视为一个不透明的可调用对象,尤其是在 torch.compiletorch.export 方面。

  • 您有一些连接 C++/CUDA 内核的 Python 绑定,并希望这些绑定与 PyTorch 子系统 (如 torch.compiletorch.autograd) 协同工作

  • 您正在使用 Python (而不是像 AOTInductor 这样的纯 C++ 环境)。

将自定义 C++ 和/或 CUDA 代码与 PyTorch 集成

请参阅 自定义 C++ 和 CUDA 算子

您可能希望从 C++ (而非 Python) 创建自定义算子,如果

  • 您有自定义的 C++ 和/或 CUDA 代码。

  • 您计划将此代码与 AOTInductor 一起使用以进行无 Python 推理。

自定义算子手册

对于教程和本页未涵盖的信息,请参阅 自定义算子手册 (我们正在将这些信息迁移到我们的文档网站)。我们建议您先阅读上面的一个教程,然后将自定义算子手册作为参考;它不是用来从头读到尾的。

何时应该创建自定义算子?

如果您的操作可以表示为内置 PyTorch 算子的组合,那么请将其编写为一个 Python 函数并调用它,而不是创建自定义算子。如果您正在调用 PyTorch 无法理解的某个库 (例如自定义 C/C++ 代码、自定义 CUDA 内核或 C/C++/CUDA 扩展的 Python 绑定),请使用算子注册 API 创建自定义算子。

为什么应该创建自定义算子?

可以通过获取张量的数据指针并将其传递给 pybind 绑定的内核来使用 C/C++/CUDA 内核。但是,这种方法无法与 PyTorch 子系统 (如自动求导、torch.compile、vmap 等) 协同工作。为了使操作能够与 PyTorch 子系统协同工作,必须通过算子注册 API 进行注册。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源