OP Lowering 指南¶
PyTorch 封装了 C++ ATen 张量库,该库提供了在 GPU 和 CPU 上实现的广泛操作。PyTorch/XLA 是 PyTorch 的一个扩展;其目的之一是将 PyTorch 操作转换为 XLA 操作。Lowering 定义了将更高级表示转换为更低级表示的过程。在本文档中,我将把 PyTorch 操作转换为 XLA 操作的过程称为 lowering。XLA 编译器也会将 XlaOp 降低(lower)为 HLO,但这超出了本文档的范围。对于尚未提供 XLA lowering 的操作,我们会将其转发到 CPU 并调用 ATen 实现。转发到 CPU 的操作会导致显著的性能下降。为了获得最佳性能,我们必须对模型中使用的所有操作进行 lowering。
以下是 PyTorch/XLA 调试工具针对尚未进行 lowering 的操作可能显示的内容示例
pt-xla-profiler: Op(s) not lowered: aten::_ctc_loss, aten::_ctc_loss_backward, Please open a GitHub issue with the above op lowering requests.
开始之前¶
您应该按照 贡献 PyTorch/XLA 中的说明安装所需的依赖项,并从源代码构建 PyTorch 和 PyTorch/XLA。实现 lowering 不需要访问 TPU。建议在工作站上进行实验,并将其配置为使用 XLA:CPU。您可以通过运行以下命令将 PyTorch/XLA 配置为使用 XLA:CPU:
export PJRT_DEVICE=CPU
理解操作¶
您可以在 native_functions.yaml 中找到 C++ ATen 操作的定义。从源代码构建 PyTorch/XLA 后,您还将在 xla/torch_xla/csrc/aten_fallback.h/cpp
中找到我们的默认实现(一个将调用转发给 PyTorch 原生内核的 boxed kernel)。PyTorch 操作通常可以轻松映射到 PyTorch 张量 API。如果不是这种情况,建议在 PyTorch repo 下搜索 PyTorch 原生实现。目标是将 PyTorch 操作降低为 XLA operation semantics 中定义的一系列 XLA 操作。
文件结构¶
下面提到的所有文件都位于 xla/torch_xla/csrc
文件夹下,但 codegen/xla_native_functions.yaml
除外
xla_native_functions.yaml
包含所有被显式 lowering 的操作符列表(来自 Core Aten list)。复合操作符不在此列出。此处每个操作符名称必须直接匹配 native_functions.yaml 中列出的 PyTorch 操作符。此文件是添加新 XLA 操作符的接口,并且是 PyTorch 代码生成机制 的输入。它会生成以下 3 个文件:XLANativeFunctions.h
、RegisterXLA.cpp
和RegisterAutogradXLA.cpp
XLANativeFunctions.h
和aten_xla_type.cpp
是 PyTorch 进入 pytorch_xla 世界的入口点,包含为每个操作符手动编写的 XLA lowerings。XLANativeFunctions.h
是通过结合使用xla_native_functions.yaml
和 PyTorch 核心native_functions.yaml
文件自动生成的,包含需要在aten_xla_type.cpp
中定义的内核声明。在此处编写的内核需要使用输入的at::Tensor
和其他参数来构造 'XLATensor'。生成的XLATensor
在返回 PyTorch 世界之前需要转换回at::Tensor
。RegisterXLA.cpp
和RegisterAutogradXLA.cpp
是自动生成的文件,用于向 PyTorch Dispatcher 注册所有 lowerings。它们还包括out=
和inplace
操作符的自动生成包装实现。aten_fallback.h/.cpp
包含我们的 boxed fallback 实现。如果在xla_native_functions.yaml
+aten_xla_type.cpp
中没有显式定义某个操作符的 lowering,并且该操作符不是复合操作符,则会使用 boxed fallback 内核。tensor_methods.h
包含XLATensor
的声明。这些声明通常与我们在XLANativeFunctions.h
中声明的at::Tensor
节点一一对应。tensor_methods.cpp
包含tensor_methods.h
中定义的XLATensor node
的实现。我们使用参数的ir::Value
构造相应的ir::op
并将其封装在XLATensor
中。Ir 是中间表示(intermediate representation)的缩写。ops/
目录包含所有ir::ops
的声明和定义。较小的节点可以放在ops/ops.h/.cpp
中。更复杂的节点可以放在单独的文件中。所有 ops 都继承自ir::ops::Node
,并提供一种将输入ir::Value
降低为一系列XlaOp
的方法。
单元测试¶
我们的 CI 每天都会针对每次更改运行 PyTorch 原生 Python 测试。如果我们提供了 lowering,这些测试将使用 XLA 实现。通常,我们不需要为 PyTorch/XLA 添加额外的 Python 测试,除非我们想验证某些 XLA 行为(例如动态形状)或由于某些原因跳过了 PyTorch 原生测试。如果需要,应将 Python 测试添加到 xla/test/test_operations.py
。我们还需要在 xla/test/cpp/test_aten_xla_tensor.cpp
中添加 CPP 测试。此测试应调用 PyTorch C++ API,并验证我们的实现是否与 PyTorch 原生实现产生相同的结果。我们还需要通过检查 aten::op
和 xla::op
计数器来验证当张量是 XLA 张量时是否调用了 XLA 实现。
技巧¶
lowering 的过程是将 PyTorch 操作分解为一系列 XlaOp。要为 PyTorch 操作提供良好的 lowering,需要很好地掌握 XLA 的能力。阅读 XlaOp 文档并查看类似操作是如何进行 lowering 的是实现此目的的最佳方法。您可以在 这个 Op lowering PR 中找到一个最小的 Op lowering 示例。您还可以在 这个 backward lowering PR 中找到一个包含 backward lowering 的稍微复杂一些的示例。
我们在 RegisterXLA.cpp
中为某些操作符提供了 out=
和 inplace
操作符的自动生成包装实现。在这种情况下,我们只需对 vanilla op 进行 lowering。例如 lerp
操作符,它在 native_functions.yaml
中有 6 个变体,分别是
- lerp_.Scalar
- lerp_.Tensor
- lerp.Scalar_out
- lerp.Tensor_out
- lerp.Scalar
- lerp.Tensor
并将生成函数原型
at::Tensor lerp(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight);
at::Tensor & lerp_(at::Tensor & self, const at::Tensor & end, const at::Scalar & weight);
at::Tensor lerp(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight);
at::Tensor & lerp_out(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight, at::Tensor & out);
at::Tensor & lerp_(at::Tensor & self, const at::Tensor & end, const at::Tensor & weight);
at::Tensor & lerp_out(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight, at::Tensor & out);
如果我们把所有这些变体都添加到 xla_native_functions.yaml
中,则会在 XLANativeFunctions.h
中。但是,如果我们在 xla_native_functions.yaml
中只对 lerp.Scalar
和 lerp.Tensor
进行 lowering 并查看 RegisterXLA.cpp
,我们将看到
namespace {
at::Tensor wrapper_Scalar_lerp(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) {
// No device check
// DeviceGuard omitted
return torch_xla::lerp(self, end, weight);
}
} // anonymous namespace
at::Tensor & wrapper_Scalar_lerp_(at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) {
auto wrapper_Scalar_lerp__tmp = wrapper_Scalar_lerp(self, end, weight);
at::_copy_from(wrapper_Scalar_lerp__tmp, self);
return self;
}
...
m.impl("lerp_.Scalar",
TORCH_FN(wrapper_Scalar_lerp_));
代码生成器将自动为使用我们 lerp.Scalar
实现的 lerp_.Scalar
和 lerp.Scalar_out
生成 lowerings,而无需我们提供显式的 lowering。
一般来说,如果 PyTorch 核心中有一个同时具有 out-of-place 和 out= 变体的操作符,最好为 out-of-place 变体编写 lowering,因为您将免费获得一个由代码生成的 out= lowering。
对于每个节点,我们需要传递一个 ir::OpKind
。这里是一个 (示例)。您可以在 interned_strings.h 中找到 OpKind
的定义。如果缺少 aten 符号,您可以提交一个像 这个 这样的 PR。