• 文档 >
  • 自定义硬件插件
快捷方式

自定义硬件插件

PyTorch/XLA 通过 OpenXLA 的 PJRT C API 支持自定义硬件。PyTorch/XLA 团队直接支持 Cloud TPU (libtpu) 和 GPU (OpenXLA) 的插件。JAX 和 TF 也可以使用相同的插件。

实现 PJRT 插件

PJRT C API 插件可以是闭源或开源的。它们包含两个部分

  1. 暴露 PJRT C API 实现的二进制文件。这部分可以与 JAX 和 TensorFlow 共享。

  2. 包含上述二进制文件的 Python 包,以及我们的 DevicePlugin Python 接口的实现,该接口处理额外的设置。

PJRT C API 实现

简而言之,您必须实现一个 PjRtClient,其中包含适用于您设备的 XLA 编译器和运行时。PJRT C++ 接口在 PJRT_Api 中以 C 语言镜像。最直接的选择是在 C++ 中实现您的插件,并将其 包装 为 C API 实现。OpenXLA 的文档中详细解释了此过程。

有关具体示例,请参阅示例实现

PyTorch/XLA 插件包

此时,您应该有一个功能正常的 PJRT 插件二进制文件,您可以使用占位符 LIBRARY 设备类型进行测试。例如

$ PJRT_DEVICE=LIBRARY PJRT_LIBRARY_PATH=/path/to/your/plugin.so python
>>> import torch_xla
>>> torch_xla.devices()
# Assuming there are 4 devices. Your hardware may differ.
[device(type='xla', index=0), device(type='xla', index=1), device(type='xla', index=2), device(type='xla', index=3)]

为了为用户自动注册您的设备类型,并处理例如多进程的额外设置,您可以实现 DevicePlugin Python API。PyTorch/XLA 插件包包含两个关键组件

  1. DevicePlugin 的实现,它(至少)提供您的插件二进制文件的路径。例如

class CpuPlugin(plugins.DevicePlugin):

  def library_path(self) -> str:
    return os.path.join(
        os.path.dirname(__file__), 'lib', 'pjrt_c_api_cpu_plugin.so')
  1. torch_xla.plugins 入口点,用于标识您的 DevicePlugin。例如,要在 pyproject.toml 中注册 EXAMPLE 设备类型

<!-- -->
[project.entry-points."torch_xla.plugins"]
example = "torch_xla_cpu_plugin:CpuPlugin"

安装您的软件包后,您可以直接使用您的 EXAMPLE 设备

$ PJRT_DEVICE=EXAMPLE python
>>> import torch_xla
>>> torch_xla.devices()
[device(type='xla', index=0), device(type='xla', index=1), device(type='xla', index=2), device(type='xla', index=3)]

DevicePlugin 为多进程初始化和客户端选项提供了额外的扩展点。该 API 目前处于实验状态,但预计在未来版本中将变得稳定。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源