Torch 库 API¶
PyTorch C++ API 提供了扩展 PyTorch 核心运算符库的功能,可以使用用户定义的运算符和数据类型进行扩展。使用 Torch 库 API 实现的扩展可在 PyTorch 渴望 API 和 TorchScript 中使用。
有关库 API 的教程式介绍,请查看 使用自定义 C++ 运算符扩展 TorchScript 教程。
宏¶
-
TORCH_LIBRARY(ns, m)
用于定义一个将在静态初始化时运行的函数的宏,以在命名空间
ns
中定义运算符库(必须是有效的 C++ 标识符,不带引号)。当您想要定义一组在 PyTorch 中尚不存在的自定义运算符时,请使用此宏。
示例用法
TORCH_LIBRARY(myops, m) { // m is a torch::Library; methods on it will define // operators in the myops namespace m.def("add", add_impl); }
m
参数绑定到一个 torch::Library,该库用于注册运算符。对于任何给定的命名空间,只能有一个 TORCH_LIBRARY()。
-
TORCH_LIBRARY_IMPL(ns, k, m)
用于定义一个将在静态初始化时运行的函数的宏,以在命名空间
ns
中为调度键k
(必须是 c10::DispatchKey 的非限定枚举成员)定义运算符覆盖。当您想要在新的调度键上实现一组预先存在的自定义运算符时(例如,您想要提供已存在运算符的 CUDA 实现),请使用此宏。一种常见的用法模式是使用 TORCH_LIBRARY() 为您想要定义的所有新运算符定义模式,然后使用多个 TORCH_LIBRARY_IMPL() 块为 CPU、CUDA 和 Autograd 提供运算符的实现。
在某些情况下,您需要定义适用于所有命名空间而不是单个命名空间的内容(通常是回退)。在这种情况下,使用保留的命名空间 _,例如:
TORCH_LIBRARY_IMPL(_, XLA, m) { m.fallback(xla_fallback); }
示例用法
TORCH_LIBRARY_IMPL(myops, CPU, m) { // m is a torch::Library; methods on it will define // CPU implementations of operators in the myops namespace. // It is NOT valid to call torch::Library::def() // in this context. m.impl("add", add_cpu_impl); }
如果
add_cpu_impl
是一个重载函数,请使用static_cast
指定您想要的哪个重载(通过提供完整类型)。
类¶
-
class Library
此对象提供了定义运算符并在调度键处提供实现的 API。
通常,不会直接分配 torch::Library;而是由 TORCH_LIBRARY() 或 TORCH_LIBRARY_IMPL() 宏创建。
torch::Library 上的大多数方法都返回对自身的引用,支持方法链。
// Examples: TORCH_LIBRARY(torchvision, m) { // m is a torch::Library m.def("roi_align", ...); ... } TORCH_LIBRARY_IMPL(aten, XLA, m) { // m is a torch::Library m.impl("add", ...); ... }
公共函数
-
template<typename Schema>
inline Library &def(Schema &&raw_schema, const std::vector<at::Tag> &tags = {}, _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & 声明一个具有模式的运算符,但不要为其提供任何实现。
您需要使用 impl() 方法提供实现。所有模板参数都是推断的。
// Example: TORCH_LIBRARY(myops, m) { m.def("add(Tensor self, Tensor other) -> Tensor"); }
- 参数
raw_schema – 要定义的运算符的模式。通常,这是一个
const char*
字符串文字,但此处接受 torch::schema() 接受的任何类型。
-
inline Library &set_python_module(const char *pymodule, const char *context = "")
声明对于随后定义的所有运算符,其伪实现可以在给定的 Python 模块 (pymodule) 中找到。
如果找不到伪实现,则注册一些用作帮助文本的内容。
参数
pymodule: Python 模块
context: 我们可能会将其包含在错误消息中。
-
inline Library &impl_abstract_pystub(const char *pymodule, const char *context = "")
已弃用;请改用 set_python_module。
-
template<typename NameOrSchema, typename Func>
inline Library &def(NameOrSchema &&raw_name_or_schema, Func &&raw_f, const std::vector<at::Tag> &tags = {}) & 定义运算符的模式,然后注册其实现。
如果您不打算使用调度程序来构建运算符实现,通常会使用此方法。它大致相当于调用 def() 然后调用 impl(),但是如果您省略运算符的模式,我们将从您的 C++ 函数类型推断它。所有模板参数都将被推断。
// Example: TORCH_LIBRARY(myops, m) { m.def("add", add_fn); }
- 参数
raw_name_or_schema – 要定义的运算符的模式,或者如果要从
raw_f
推断模式,则只需运算符的名称。通常是const char*
字面量。raw_f – 实现此运算符的 C++ 函数。此处接受 torch::CppFunction 的任何有效构造函数;通常您提供函数指针或 lambda 表达式。
-
template<typename Name, typename Func>
inline Library &impl(Name name, Func &&raw_f, _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & 注册运算符的实现。
您可以在不同的调度键上为单个运算符注册多个实现(请参阅 torch::dispatch())。实现必须具有相应的声明(来自 def()),否则它们无效。如果您计划注册多个实现,请在 def() 运算符时不要提供函数实现。
// Example: TORCH_LIBRARY_IMPL(myops, CUDA, m) { m.impl("add", add_cuda); }
- 参数
name – 要实现的运算符的名称。此处不要提供模式。
raw_f – 实现此运算符的 C++ 函数。此处接受 torch::CppFunction 的任何有效构造函数;通常您提供函数指针或 lambda 表达式。
-
template<typename Func>
inline Library &fallback(Func &&raw_f) & 为所有运算符注册回退实现,如果当前没有可用的特定运算符实现,则使用此实现。
回退必须与 DispatchKey 相关联;例如,仅从 TORCH_LIBRARY_IMPL() 中使用命名空间
_
调用此函数。// Example: TORCH_LIBRARY_IMPL(_, AutogradXLA, m) { // If there is not a kernel explicitly registered // for AutogradXLA, fallthrough to the next // available kernel m.fallback(torch::CppFunction::makeFallthrough()); } // See aten/src/ATen/core/dispatch/backend_fallback_test.cpp // for a full example of boxed fallback
- 参数
raw_f – 实现回退的函数。未装箱的函数通常不适合用作回退函数,因为回退函数必须适用于每个运算符(即使它们具有不同的类型签名)。典型参数是 CppFunction::makeFallthrough() 或 CppFunction::makeFromBoxedFunction()
-
template<typename Schema>
-
class CppFunction
表示实现运算符的 C++ 函数。
大多数用户不会直接与该类交互,除非通过错误消息:此函数的构造函数定义了您可以通过接口绑定的允许的“函数”式内容的集合。
此类擦除了传入函数的类型,但通过函数的推断模式持久地记录了类型。
公共函数
-
template<typename Func>
inline explicit CppFunction(Func *f, std::enable_if_t<c10::guts::is_function_type<Func>::value, std::nullptr_t> = nullptr) 此重载接受函数指针,例如
CppFunction(&add_impl)
-
template<typename FuncPtr>
inline explicit CppFunction(FuncPtr f, std::enable_if_t<c10::is_compile_time_function_pointer<FuncPtr>::value, std::nullptr_t> = nullptr) 此重载接受编译时函数指针,例如:
CppFunction(TORCH_FN(add_impl))
-
template<typename Lambda>
inline explicit CppFunction(Lambda &&f, std::enable_if_t<c10::guts::is_functor<std::decay_t<Lambda>>::value, std::nullptr_t> = nullptr) 此重载接受 lambda 表达式,例如:
CppFunction([](const Tensor& self) { ...
})
公共静态函数
-
static inline CppFunction makeFallthrough()
这会创建一个贯通函数。
贯通函数会立即重新分派到下一个可用的分派键,但实现效率比以相同方式手写函数更高。
-
template<c10::BoxedKernel::BoxedKernelFunction *func>
static inline CppFunction makeFromBoxedFunction() 从具有签名
void(const OperatorHandle&, Stack*)
的盒装内核函数创建函数;即,它们以盒装调用约定而不是本机 C++ 调用约定接收参数栈。盒装函数通常仅用于通过torch::Library::fallback()注册后端回退。
-
template<class KernelFunctor>
static inline CppFunction makeFromBoxedFunctor(std::unique_ptr<KernelFunctor> kernelFunctor) 从盒装内核仿函数创建函数,该仿函数定义了
operator()(const OperatorHandle&, DispatchKeySet, Stack*)
(从盒装调用约定接收参数)并继承自c10::OperatorKernel
。与 makeFromBoxedFunction 不同,以这种方式注册的函数还可以携带由仿函数管理的其他状态;如果您正在编写某个其他实现(例如 Python 可调用对象)的适配器,该适配器与注册的内核动态关联,这将非常有用。
-
template<typename FuncPtr, std::enable_if_t<c10::guts::is_function_type<FuncPtr>::value, std::nullptr_t> = nullptr>
static inline CppFunction makeFromUnboxedFunction(FuncPtr *f) 从非盒装内核函数创建函数。
这通常用于注册通用运算符。
-
template<typename FuncPtr, std::enable_if_t<c10::is_compile_time_function_pointer<FuncPtr>::value, std::nullptr_t> = nullptr>
static inline CppFunction makeFromUnboxedFunction(FuncPtr f) 从编译时非盒装内核函数指针创建函数。
这通常用于注册通用运算符。编译时函数指针可用于允许编译器优化(例如内联)对它的调用。
-
template<typename Func>
函数¶
-
template<typename Func>
inline CppFunction dispatch(c10::DispatchKey k, Func &&raw_f)¶ 创建一个与特定调度键关联的 torch::CppFunction。
使用 c10::DispatchKey 标记的 torch::CppFunctions 只有在调度程序确定应该调度到此特定 c10::DispatchKey 时才会被调用。
通常不会直接使用此函数,而是建议使用 TORCH_LIBRARY_IMPL(),它会在其主体内部隐式设置所有注册调用的 c10::DispatchKey。
-
template<typename Func>
inline CppFunction dispatch(c10::DeviceType type, Func &&raw_f)¶ dispatch() 的便捷重载,它接受 c10::DeviceType。
-
inline c10::FunctionSchema schema(const char *str, c10::AliasAnalysisKind k, bool allow_typevars = false)¶
从字符串构造一个 c10::FunctionSchema,并显式指定 c10::AliasAnalysisKind。
通常,模式只是作为字符串传入,但如果您需要指定自定义别名分析,则可以使用对该函数的调用替换字符串。
// Default alias analysis (FROM_SCHEMA) m.def("def3(Tensor self) -> Tensor"); // Pure function alias analysis m.def(torch::schema("def3(Tensor self) -> Tensor", c10::AliasAnalysisKind::PURE_FUNCTION));
-
inline c10::FunctionSchema schema(const char *s, bool allow_typevars = false)¶
函数模式可以直接从字符串字面量构造。