自定义硬件插件¶
PyTorch/XLA 通过 OpenXLA 的 PJRT C API 支持自定义硬件。PyTorch/XLA 团队直接支持 Cloud TPU (libtpu
) 和 GPU (OpenXLA) 的插件。这些插件也可以被 JAX 和 TF 使用。
实现 PJRT 插件¶
PJRT C API 插件可以是闭源或开源的。它们包含两个部分
提供 PJRT C API 实现的二进制文件。这部分可以与 JAX 和 TensorFlow 共享。
包含上述二进制文件以及我们
DevicePlugin
Python 接口实现的 Python 包,该接口处理额外的设置。
PJRT C API 实现¶
简而言之,您必须实现一个包含您的设备的 XLA 编译器和运行时的 PjRtClient。PJRT C++ 接口在 PJRT_Api 中以 C 语言形式镜像。最直接的选择是用 C++ 实现您的插件并将其包装为 C API 实现。OpenXLA 的文档中详细解释了此过程。
有关具体示例,请参阅示例实现。
PyTorch/XLA 插件包¶
此时,您应该拥有一个功能正常的 PJRT 插件二进制文件,您可以使用占位符设备类型 LIBRARY
进行测试。例如
$ PJRT_DEVICE=LIBRARY PJRT_LIBRARY_PATH=/path/to/your/plugin.so python
>>> import torch_xla
>>> torch_xla.devices()
# Assuming there are 4 devices. Your hardware may differ.
[device(type='xla', index=0), device(type='xla', index=1), device(type='xla', index=2), device(type='xla', index=3)]
为了自动为用户注册您的设备类型并处理额外的设置(例如多进程),您可以实现 DevicePlugin
Python API。PyTorch/XLA 插件包包含两个关键组件
DevicePlugin
的实现,它(至少)提供您的插件二进制文件的路径。例如
class CpuPlugin(plugins.DevicePlugin):
def library_path(self) -> str:
return os.path.join(
os.path.dirname(__file__), 'lib', 'pjrt_c_api_cpu_plugin.so')
一个
torch_xla.plugins
entry point,用于标识您的DevicePlugin
。例如,要在pyproject.toml
中注册EXAMPLE
设备类型,可以这样写
<!-- -->
[project.entry-points."torch_xla.plugins"]
example = "torch_xla_cpu_plugin:CpuPlugin"
安装您的包后,您可以直接使用您的 EXAMPLE
设备
$ PJRT_DEVICE=EXAMPLE python
>>> import torch_xla
>>> torch_xla.devices()
[device(type='xla', index=0), device(type='xla', index=1), device(type='xla', index=2), device(type='xla', index=3)]
DevicePlugin 为多进程初始化和客户端选项提供了额外的扩展点。该 API 目前处于实验阶段,但预计将在未来版本中稳定下来。