自定义硬件插件¶
PyTorch/XLA 通过 OpenXLA 的 PJRT C API 支持自定义硬件。PyTorch/XLA 团队直接支持 Cloud TPU (libtpu
) 和 GPU (OpenXLA) 的插件。JAX 和 TF 也可以使用相同的插件。
实现 PJRT 插件¶
PJRT C API 插件可以是闭源或开源的。它们包含两个部分
暴露 PJRT C API 实现的二进制文件。这部分可以与 JAX 和 TensorFlow 共享。
包含上述二进制文件的 Python 包,以及我们的
DevicePlugin
Python 接口的实现,该接口处理额外的设置。
PJRT C API 实现¶
简而言之,您必须实现一个 PjRtClient,其中包含适用于您设备的 XLA 编译器和运行时。PJRT C++ 接口在 PJRT_Api 中以 C 语言镜像。最直接的选择是在 C++ 中实现您的插件,并将其 包装 为 C API 实现。OpenXLA 的文档中详细解释了此过程。
有关具体示例,请参阅示例实现。
PyTorch/XLA 插件包¶
此时,您应该有一个功能正常的 PJRT 插件二进制文件,您可以使用占位符 LIBRARY
设备类型进行测试。例如
$ PJRT_DEVICE=LIBRARY PJRT_LIBRARY_PATH=/path/to/your/plugin.so python
>>> import torch_xla
>>> torch_xla.devices()
# Assuming there are 4 devices. Your hardware may differ.
[device(type='xla', index=0), device(type='xla', index=1), device(type='xla', index=2), device(type='xla', index=3)]
为了为用户自动注册您的设备类型,并处理例如多进程的额外设置,您可以实现 DevicePlugin
Python API。PyTorch/XLA 插件包包含两个关键组件
DevicePlugin
的实现,它(至少)提供您的插件二进制文件的路径。例如
class CpuPlugin(plugins.DevicePlugin):
def library_path(self) -> str:
return os.path.join(
os.path.dirname(__file__), 'lib', 'pjrt_c_api_cpu_plugin.so')
torch_xla.plugins
入口点,用于标识您的DevicePlugin
。例如,要在pyproject.toml
中注册EXAMPLE
设备类型
<!-- -->
[project.entry-points."torch_xla.plugins"]
example = "torch_xla_cpu_plugin:CpuPlugin"
安装您的软件包后,您可以直接使用您的 EXAMPLE
设备
$ PJRT_DEVICE=EXAMPLE python
>>> import torch_xla
>>> torch_xla.devices()
[device(type='xla', index=0), device(type='xla', index=1), device(type='xla', index=2), device(type='xla', index=3)]
DevicePlugin 为多进程初始化和客户端选项提供了额外的扩展点。该 API 目前处于实验状态,但预计在未来版本中将变得稳定。