• 教程 >
  • 通过PrivateUse1促进新后端集成
快捷方式

通过PrivateUse1促进新后端集成

在本教程中,我们将介绍一些必要的步骤,以便将一个新的后端(位于pytorch/pytorch仓库之外)通过PrivateUse1进行集成。注意,本教程假设你已经对PyTorch有了基本的了解,并是PyTorch的高级用户。

注意

本教程仅涉及与PrivateUse1机制相关的部分,该机制用于促进新设备的集成,其他部分将不会介绍。同时,本教程中涉及的所有模块并非必需,您可以根据自己的实际需求选择对您有帮助的模块。

什么是PrivateUse1?

在Pytorch 2.0之前,PyTorch提供了三个保留的分派键(及其对应的自动微分键),用于对树外后端扩展进行原型设计,这三个分派键如下

  • PrivateUse1/AutogradPrivateUse1

  • PrivateUse2/AutogradPrivateUse2

  • PrivateUse3/AutogradPrivateUse3

在原型验证通过后,您可以为新的后端申请私有密钥,例如CUDA、XLA、MPS等。

然而,随着PyTorch的快速发展,越来越多的硬件制造商试图将他们的后端集成到PyTorch中,这可能会导致以下问题

  • 每个新的后端集成都涉及大量的文件修改

  • 目前对分派键的数量( DispatchKeySet 64位限制) 存在硬性限制

注意

将新的后端集成到 PyTorch 中,通过 PrivateUse1 密钥,也存在一个问题,因为不可能同时集成多个后端。幸运的是,这些 out-of-tree 后端很少同时使用。

鉴于上述原因,社区开始建议将新的后端集成到 PyTorch 中,通过 PrivateUse1

然而,之前的 PrivateUse1 机制并不完全能够与新的后端集成,因为它在某些模块中缺乏一些相关的支持,例如 Storage、AMP、Distributed 等。

随着 Pytorch 2.1.0 的到来,在新的后端集成方面,PrivateUse1 进行了一系列优化和增强,现在可以快速高效地支持新设备的集成。

如何通过 PrivateUse1 集成新的后端

在本节中,我们将讨论将新的后端集成到 Pytorch 中的细节,通过 PrivateUse1,主要包括以下部分

  1. 为新的后端注册内核。

  2. 为新的后端注册生成器。

  3. 为新的后端注册设备保护器。

  4. 为新的后端元数据注册序列化和反序列化函数。

  5. 其他模块。

为新的后端注册内核

新的后端可能对算子有一些高性能的实现,这些实现可以通过 TORCH_LIBRARY_IMPL API 注册到调度器中,如 Registering a Dispatched Operator in C++ 中所述。这涉及到几种情况

  1. 将新的后端支持的所有正向算子注册到调度器中,并同时注册回退,以便当新的后端不支持某些算子时,这些算子可以回退到 CPU 上执行,以确保功能的可用性。

at::Tensor wrapper_Custom_Tensor_add(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
  // Implementation of add kernel in new backend
  ...
}

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  ...
  m.impl("add.Tensor", TORCH_FN(wrapper_Custom_Tensor_add));
  ...
}

void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  // Add some hints about new devices that do not support and need to fall back to cpu
  at::native::cpu_fallback(op, stack);
}

TORCH_LIBRARY_IMPL(_, PrivateUse1, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
}
  1. 如果新的后端需要覆盖 PyTorch Autograd ,则可以通过 AutogradPrivateUse1 将来自 torch::autograd::Function 的内核注册到调度器中,调度器和自动微分系统将自动调用这些算子的正向和反向实现。

class CumtomSeluFunction : public torch::autograd::Function<CumtomSeluFunction> {
  // Implementation of selu kernel in new backend
}

at::Tensor wrapper_AutogradCumstom__selu(const at::Tensor & self) {
  return CumtomSeluFunction::apply(self);
}

TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
  ...
  m.impl("selu", TORCH_FN(wrapper_AutogradCustom__selu));
  ...
}
  1. 通过 AutocastPrivateUse1 将想要支持 自动混合精度 (AMP) 和回退机制的内核注册到调度器中,自动类型转换系统将在需要时自动调用这些内核。

TORCH_LIBRARY_IMPL(aten, AutocastPrivateUse1, m) {
  ...
  KERNEL_PRIVATEUSEONE(<operator>, <policy>)
  ...
}

TORCH_LIBRARY_IMPL(_, AutocastPrivateUse1, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}

需要注意的是,如果您想在新的后端中支持 AMP,则需要通过 torch._register_device_module("backend_name", BackendModule) 注册一个新的 BackendModule,并且 BackendModule 需要具有以下 API

  • get_amp_supported_dtype() -> List[torch.dtype]

    获取新的后端在 AMP 中支持的数据类型,它可能支持一种更多的数据类型。

  • is_autocast_enabled() -> bool

    检查新的后端上是否启用了 AMP。

  • get_autocast_dtype() -> torch.dtype

    获取新的后端在 AMP 中支持的数据类型,它由 set_autocast_dtype 设置或为默认数据类型,默认数据类型是 torch.float16

  • set_autocast_enabled(bool) -> None

    在新的后端上启用或禁用 AMP。

  • set_autocast_dtype(dtype) -> None

    设置新的后端在 AMP 中支持的数据类型,并且数据类型必须包含在从 get_amp_supported_dtype 获取的数据类型中。

为新的后端注册生成器

有必要支持对应于新设备的生成器。目前,PrivateUse1 可以动态注册自定义生成器,主要分为以下步骤。

  1. 继承 GeneratorImpl 类以实现对应于新的后端的生成器类,并实现各种通用方法。

  2. 使用单个参数定义一个新的后端 builderdevice index

  3. 调用 REGISTER_GENERATOR_PRIVATEUSE1 宏以完成动态注册。

struct CustomGeneratorImpl : public c10::GeneratorImpl {
  // Implementation of generator in new backend
}

at::Generator make_custom_generator(c10::DeviceIndex device_index) {
  return at::make_generator<CustomGeneratorImpl>(device_index);
}

REGISTER_GENERATOR_PRIVATEUSE1(make_cumstom_generator)

为新的后端注册设备保护器

PyTorch 通过 DeviceGuard 提供与设备、流和事件切换相关的功能。此功能也适用于 PrivateUse1 密钥。

  1. 继承 DeviceGuardImplInterface 类以实现对应于新的后端的各种通用方法。

  2. 调用 C10_REGISTER_GUARD_IMPL 宏以完成动态注册。

struct CustomGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  // Implementation of guard in new backend
}

C10_REGISTER_GUARD_IMPL(PrivateUse1, CustomGuardImpl);

为新的后端元数据注册序列化和反序列化函数

PyTorch 目前能够动态注册序列化/反序列化函数,以支持在类 TensorImpl.ExtraMeta 中对名为 backend_meta_ 的新的后端附加元数据的序列化和反序列化。您可以参考以下步骤

  1. 继承 BackendMeta 类以实现对应于新的后端的 CustomBackendMetadata,并且新的后端的各个字段可以在类中进行自定义。

  2. 实现新的后端的序列化和反序列化函数,函数签名为 void(const at::Tensor&, std::unordered_map<std::string, bool>&)

  3. 调用 TensorBackendMetaRegistry 宏以完成动态注册。

struct CustomBackendMetadata : public c10::BackendMeta {
  // Implementation of backend metadata in new backend
}

void for_serialization(const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
  // Implementation of serialization
}

void for_deserialization(const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
  // Implementation of deserialization
}

TensorBackendMetaRegistry(c10::DeviceType::PrivateUse1, &for_serialization, &for_deserialization);

其他模块

除了上述部分之外,还有一些其他模块可以通过 PrivateUse1 进行扩展,例如 distributed collective communicationbenchmark timer 等,这些将在未来添加。关于 PrivateUse1 集成的例子是 Ascend NPU

如何通过 Privateuse1 改善用户体验

通过 PrivateUse1 集成新设备的主要目标是满足基本的功能需求,下一步是要改进可用性,主要涉及以下方面。

  1. 将新的后端模块注册到 Pytorch。

  2. 将 PrivateUse1 重命名为新的后端的自定义名称。

  3. 生成与新的后端相关的函数和属性。

将新的后端模块注册到 Pytorch

PyTorch 中一些与 CUDA 相关的接口可以通过以下形式调用:torch.cuda.xxx。因此,为了符合用户的习惯,通过 PrivateUse1 机制实现的新后端也应该提供类似的接口。

例如,使用 Ascend NPU

torch._register_device_module('npu', torch_npu.npu)

完成上述操作后,用户可以通过 torch.npu.xxx 调用 Ascend NPU 的一些专用 API

将 PrivateUse1 重命名为新的后端的自定义名称

PrivateUse1 密钥是新的后端集成到 PyTorch 中的内部机制。对于用户来说,与 PrivateUse1 相比,与新的后端密切相关的自定义名称应该更加友好。

Ascend NPU 为例,第一种用法将更加人性化。

torch.rand((2,2),device='npu:0')
torch.rand((2,2),device='privateuse1:0')

现在,PyTorch 为自命名 PrivateUse1 后端提供了一个新的 C++/Python API,使用起来非常简单。

torch.rename_privateuse1_backend("npu")
c10::register_privateuse1_backend("npu")

未来工作

PrivateUse1 机制的改进仍在进行中,因此将依次添加新的模块的 PrivateUse1 集成方法。以下列出了一些我们正在积极进行的工作

  • 添加 distributed collective communication 的集成方法。

  • 添加 benchmark timer 的集成方法。

结论

本教程介绍了通过 PrivateUse1 将新的后端集成到 PyTorch 中的过程,包括但不限于算子注册、生成器注册、设备保护器注册等。同时,还介绍了一些改进用户体验的方法。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的答案

查看资源