自訂¶
本節說明如何自訂 TorchElastic 以滿足您的需求。
啟動器¶
TorchElastic 附帶的啟動器程式應該足以應付大多數的使用案例(請參閱 torchrun(彈性啟動))。您可以透過以程式化方式建立代理程式並將其傳遞給您的工作者的規格來實作自訂啟動器,如下所示。
# my_launcher.py
if __name__ == "__main__":
  args = parse_args(sys.argv[1:])
  rdzv_handler = RendezvousHandler(...)
  spec = WorkerSpec(
      local_world_size=args.nproc_per_node,
      fn=trainer_entrypoint_fn,
      args=(trainer_entrypoint_fn args.fn_args,...),
      rdzv_handler=rdzv_handler,
      max_restarts=args.max_restarts,
      monitor_interval=args.monitor_interval,
  )
  agent = LocalElasticAgent(spec, start_method="spawn")
  try:
      run_result = agent.run()
      if run_result.is_failed():
          print(f"worker 0 failed with: run_result.failures[0]")
      else:
          print(f"worker 0 return value is: run_result.return_values[0]")
  except Exception ex:
      # handle exception
會合處理常式¶
若要實作您自己的會合,請擴充 torch.distributed.elastic.rendezvous.RendezvousHandler 並實作其方法。
警告
會合處理常式的實作很棘手。在您開始之前,請確定您完全瞭解會合的屬性。如需更多資訊,請參閱 會合。
實作後,您可以在建立代理程式時將自訂會合處理常式傳遞給工作者規格。
spec = WorkerSpec(
    rdzv_handler=MyRendezvousHandler(params),
    ...
)
elastic_agent = LocalElasticAgent(spec, start_method=start_method)
elastic_agent.run(spec.role)
指標處理常式¶
TorchElastic 會發出平台級別的指標(請參閱 指標)。根據預設,指標會發送到 /dev/null,因此您不會看到它們。若要將指標推送至基礎結構中的指標處理服務,請實作 torch.distributed.elastic.metrics.MetricHandler 並在自訂啟動器中 設定 它。
# my_launcher.py
import torch.distributed.elastic.metrics as metrics
class MyMetricHandler(metrics.MetricHandler):
    def emit(self, metric_data: metrics.MetricData):
        # push metric_data to your metric sink
def main():
  metrics.configure(MyMetricHandler())
  spec = WorkerSpec(...)
  agent = LocalElasticAgent(spec)
  agent.run()
事件處理常式¶
TorchElastic 支援事件記錄(請參閱 事件)。事件模組定義了允許您記錄事件和實作自訂 EventHandler 的 API。EventHandler 用於將 torchelastic 執行期間產生的事件發佈到不同的來源,例如 AWS CloudWatch。根據預設,它會使用忽略事件的 torch.distributed.elastic.events.NullEventHandler。若要設定自訂事件處理常式,您需要實作 torch.distributed.elastic.events.EventHandler 介面並在自訂啟動器中 設定 它。
# my_launcher.py
import torch.distributed.elastic.events as events
class MyEventHandler(events.EventHandler):
    def record(self, event: events.Event):
        # process event
def main():
  events.configure(MyEventHandler())
  spec = WorkerSpec(...)
  agent = LocalElasticAgent(spec)
  agent.run()