快捷方式

自定义

本节介绍如何自定义 TorchElastic 以满足您的需求。

启动器

对于大多数用例,TorchElastic 附带的启动器程序应该足够了(请参阅 torchrun (弹性启动))。您可以通过编程方式创建一个代理并为其工作进程传递规范来实现自定义启动器,如下所示。

# my_launcher.py

if __name__ == "__main__":
  args = parse_args(sys.argv[1:])
  rdzv_handler = RendezvousHandler(...)
  spec = WorkerSpec(
      local_world_size=args.nproc_per_node,
      fn=trainer_entrypoint_fn,
      args=(trainer_entrypoint_fn args.fn_args,...),
      rdzv_handler=rdzv_handler,
      max_restarts=args.max_restarts,
      monitor_interval=args.monitor_interval,
  )

  agent = LocalElasticAgent(spec, start_method="spawn")
  try:
      run_result = agent.run()
      if run_result.is_failed():
          print(f"worker 0 failed with: run_result.failures[0]")
      else:
          print(f"worker 0 return value is: run_result.return_values[0]")
  except Exception ex:
      # handle exception

Rendezvous 处理程序

要实现您自己的 rendezvous,请扩展 torch.distributed.elastic.rendezvous.RendezvousHandler 并实现其方法。

警告

Rendezvous 处理程序很难实现。在开始之前,请确保您完全理解 rendezvous 的属性。请参阅 Rendezvous 以获取更多信息。

一旦实现,您可以在创建代理时将您的自定义 rendezvous 处理程序传递给工作进程规范。

spec = WorkerSpec(
    rdzv_handler=MyRendezvousHandler(params),
    ...
)
elastic_agent = LocalElasticAgent(spec, start_method=start_method)
elastic_agent.run(spec.role)

指标处理程序

TorchElastic 发出平台级指标(请参阅 指标)。默认情况下,指标会发送到 /dev/null,因此您不会看到它们。要将指标推送到您基础设施中的指标处理服务,请实现一个 torch.distributed.elastic.metrics.MetricHandler 并在您的自定义启动器中 配置 它。

# my_launcher.py

import torch.distributed.elastic.metrics as metrics

class MyMetricHandler(metrics.MetricHandler):
    def emit(self, metric_data: metrics.MetricData):
        # push metric_data to your metric sink

def main():
  metrics.configure(MyMetricHandler())

  spec = WorkerSpec(...)
  agent = LocalElasticAgent(spec)
  agent.run()

事件处理程序

TorchElastic 支持事件记录(请参阅 事件)。事件模块定义了 API,允许您记录事件并实现自定义 EventHandler。EventHandler 用于将 torchelastic 执行期间生成的事件发布到不同的来源,例如 AWS CloudWatch。默认情况下,它使用 torch.distributed.elastic.events.NullEventHandler 来忽略事件。要配置自定义事件处理程序,您需要实现 torch.distributed.elastic.events.EventHandler 接口并在您的自定义启动器中 配置 它。

# my_launcher.py

import torch.distributed.elastic.events as events

class MyEventHandler(events.EventHandler):
    def record(self, event: events.Event):
        # process event

def main():
  events.configure(MyEventHandler())

  spec = WorkerSpec(...)
  agent = LocalElasticAgent(spec)
  agent.run()

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的问题解答

查看资源