快捷方式

OP Lowering 指南

PyTorch 封装了 C++ ATen 张量库,该库提供了在 GPU 和 CPU 上实现的广泛操作。PyTorch/XLA 是 PyTorch 的一个扩展;其目的之一是将 PyTorch 操作转换为 XLA 操作。Lowering 定义了将更高级表示转换为更低级表示的过程。在本文档中,我将把 PyTorch 操作转换为 XLA 操作的过程称为 lowering。XLA 编译器也会将 XlaOp 降低(lower)为 HLO,但这超出了本文档的范围。对于尚未提供 XLA lowering 的操作,我们会将其转发到 CPU 并调用 ATen 实现。转发到 CPU 的操作会导致显著的性能下降。为了获得最佳性能,我们必须对模型中使用的所有操作进行 lowering。

以下是 PyTorch/XLA 调试工具针对尚未进行 lowering 的操作可能显示的内容示例

pt-xla-profiler: Op(s) not lowered: aten::_ctc_loss, aten::_ctc_loss_backward,  Please open a GitHub issue with the above op lowering requests.

开始之前

您应该按照 贡献 PyTorch/XLA 中的说明安装所需的依赖项,并从源代码构建 PyTorch 和 PyTorch/XLA。实现 lowering 不需要访问 TPU。建议在工作站上进行实验,并将其配置为使用 XLA:CPU。您可以通过运行以下命令将 PyTorch/XLA 配置为使用 XLA:CPU:

export PJRT_DEVICE=CPU

理解操作

您可以在 native_functions.yaml 中找到 C++ ATen 操作的定义。从源代码构建 PyTorch/XLA 后,您还将在 xla/torch_xla/csrc/aten_fallback.h/cpp 中找到我们的默认实现(一个将调用转发给 PyTorch 原生内核的 boxed kernel)。PyTorch 操作通常可以轻松映射到 PyTorch 张量 API。如果不是这种情况,建议在 PyTorch repo 下搜索 PyTorch 原生实现。目标是将 PyTorch 操作降低为 XLA operation semantics 中定义的一系列 XLA 操作。

文件结构

下面提到的所有文件都位于 xla/torch_xla/csrc 文件夹下,但 codegen/xla_native_functions.yaml 除外

  1. xla_native_functions.yaml 包含所有被显式 lowering 的操作符列表(来自 Core Aten list)。复合操作符不在此列出。此处每个操作符名称必须直接匹配 native_functions.yaml 中列出的 PyTorch 操作符。此文件是添加新 XLA 操作符的接口,并且是 PyTorch 代码生成机制 的输入。它会生成以下 3 个文件:XLANativeFunctions.hRegisterXLA.cppRegisterAutogradXLA.cpp

  2. XLANativeFunctions.haten_xla_type.cpp 是 PyTorch 进入 pytorch_xla 世界的入口点,包含为每个操作符手动编写的 XLA lowerings。XLANativeFunctions.h 是通过结合使用 xla_native_functions.yaml 和 PyTorch 核心 native_functions.yaml 文件自动生成的,包含需要在 aten_xla_type.cpp 中定义的内核声明。在此处编写的内核需要使用输入的 at::Tensor 和其他参数来构造 'XLATensor'。生成的 XLATensor 在返回 PyTorch 世界之前需要转换回 at::Tensor

  3. RegisterXLA.cppRegisterAutogradXLA.cpp 是自动生成的文件,用于向 PyTorch Dispatcher 注册所有 lowerings。它们还包括 out=inplace 操作符的自动生成包装实现。

  4. aten_fallback.h/.cpp 包含我们的 boxed fallback 实现。如果在 xla_native_functions.yaml + aten_xla_type.cpp 中没有显式定义某个操作符的 lowering,并且该操作符不是复合操作符,则会使用 boxed fallback 内核。

  5. tensor_methods.h 包含 XLATensor 的声明。这些声明通常与我们在 XLANativeFunctions.h 中声明的 at::Tensor 节点一一对应。

  6. tensor_methods.cpp 包含 tensor_methods.h 中定义的 XLATensor node 的实现。我们使用参数的 ir::Value 构造相应的 ir::op 并将其封装在 XLATensor 中。Ir 是中间表示(intermediate representation)的缩写。

  7. ops/ 目录包含所有 ir::ops 的声明和定义。较小的节点可以放在 ops/ops.h/.cpp 中。更复杂的节点可以放在单独的文件中。所有 ops 都继承自 ir::ops::Node,并提供一种将输入 ir::Value 降低为一系列 XlaOp 的方法。

单元测试

我们的 CI 每天都会针对每次更改运行 PyTorch 原生 Python 测试。如果我们提供了 lowering,这些测试将使用 XLA 实现。通常,我们不需要为 PyTorch/XLA 添加额外的 Python 测试,除非我们想验证某些 XLA 行为(例如动态形状)或由于某些原因跳过了 PyTorch 原生测试。如果需要,应将 Python 测试添加到 xla/test/test_operations.py。我们还需要在 xla/test/cpp/test_aten_xla_tensor.cpp 中添加 CPP 测试。此测试应调用 PyTorch C++ API,并验证我们的实现是否与 PyTorch 原生实现产生相同的结果。我们还需要通过检查 aten::opxla::op 计数器来验证当张量是 XLA 张量时是否调用了 XLA 实现。

技巧

lowering 的过程是将 PyTorch 操作分解为一系列 XlaOp。要为 PyTorch 操作提供良好的 lowering,需要很好地掌握 XLA 的能力。阅读 XlaOp 文档并查看类似操作是如何进行 lowering 的是实现此目的的最佳方法。您可以在 这个 Op lowering PR 中找到一个最小的 Op lowering 示例。您还可以在 这个 backward lowering PR 中找到一个包含 backward lowering 的稍微复杂一些的示例。

我们在 RegisterXLA.cpp 中为某些操作符提供了 out=inplace 操作符的自动生成包装实现。在这种情况下,我们只需对 vanilla op 进行 lowering。例如 lerp 操作符,它在 native_functions.yaml 中有 6 个变体,分别是

- lerp_.Scalar
- lerp_.Tensor
- lerp.Scalar_out
- lerp.Tensor_out
- lerp.Scalar
- lerp.Tensor

并将生成函数原型

at::Tensor lerp(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight);
at::Tensor & lerp_(at::Tensor & self, const at::Tensor & end, const at::Scalar & weight);
at::Tensor lerp(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight);
at::Tensor & lerp_out(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight, at::Tensor & out);
at::Tensor & lerp_(at::Tensor & self, const at::Tensor & end, const at::Tensor & weight);
at::Tensor & lerp_out(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight, at::Tensor & out);

如果我们把所有这些变体都添加到 xla_native_functions.yaml 中,则会在 XLANativeFunctions.h 中。但是,如果我们在 xla_native_functions.yaml 中只对 lerp.Scalarlerp.Tensor 进行 lowering 并查看 RegisterXLA.cpp,我们将看到

namespace {

at::Tensor wrapper_Scalar_lerp(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) {
    // No device check


  // DeviceGuard omitted
  return torch_xla::lerp(self, end, weight);
}

} // anonymous namespace

at::Tensor & wrapper_Scalar_lerp_(at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) {
  auto wrapper_Scalar_lerp__tmp = wrapper_Scalar_lerp(self, end, weight);
  at::_copy_from(wrapper_Scalar_lerp__tmp, self);
  return self;
}

...
  m.impl("lerp_.Scalar",
  TORCH_FN(wrapper_Scalar_lerp_));

代码生成器将自动为使用我们 lerp.Scalar 实现的 lerp_.Scalarlerp.Scalar_out 生成 lowerings,而无需我们提供显式的 lowering。

一般来说,如果 PyTorch 核心中有一个同时具有 out-of-place 和 out= 变体的操作符,最好为 out-of-place 变体编写 lowering,因为您将免费获得一个由代码生成的 out= lowering。

对于每个节点,我们需要传递一个 ir::OpKind。这里是一个 (示例)。您可以在 interned_strings.h 中找到 OpKind 的定义。如果缺少 aten 符号,您可以提交一个像 这个 这样的 PR。

文档

访问 PyTorch 的完整开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源