快捷方式

跟踪

概述和用法

注意

实验性,请自行承担风险,API 可能随时更改

在 TorchX 中,应用程序是二进制文件 (可执行文件),因此没有内置的“返回”应用程序结果的方法。该 torchx.runtime.tracking 模块允许应用程序返回简单结果 (请注意“简单”一词)。跟踪器模块支持的返回类型有意受到限制。例如,尝试返回训练后的模型权重是不允许的,因为这些权重的大小可能达到数百 GB。此模块不适合传递大量数据或二进制块。

当应用程序作为更高层协调工作的一部分启动时 (例如,工作流、管道、超参数优化),应用程序的结果通常需要可供协调器或工作流中的其他应用程序访问。

假设 App1 和 App2 按顺序启动,App1 的输出作为 App2 的输入。由于这些是二进制文件,因此应用程序之间链接输入/输出的典型方法是将 App1 的输出文件路径作为 App2 的输入文件路径传递。

$ app1 --output-file s3://foo/out/app1.out
$ app2 --input-file s3://foo/out/app1.out

虽然这似乎很简单,但有一些问题需要注意。

  1. 文件 app1.out 的格式 (App1 需要以 App2 理解的格式写入它)。

  2. 实际解析 URL 以及写入/读取输出文件。

因此,应用程序的主程序最终将如下所示 (为演示目的的伪代码)。

# in app1.py
if __name__ == "__main__":
   accuracy = do_something()
   s3client = ...
   out = {"accuracy": accuracy}

   with open("/tmp/out", "w") as f:
       f = json.dumps(out).encode("utf-8")

   s3client.put(args.output_file, f)

# in app2.py
if __name__ == "__main__":
   s3client = ...
   with open("/tmp/out", "w") as f:
       s3client.get(args.input_file, f)

   with open("/tmp/out", "r") as f:
       in = json.loads(f.read().decode("utf-8"))

   do_something_else(in["accuracy"])

相反,使用跟踪器,可以使用具有相同 tracker_base 的跟踪器跨应用程序,使一个应用程序的返回值可供另一个应用程序访问,而无需链接一个应用程序的输出文件路径和另一个应用程序的输入文件路径,也不必处理自定义序列化和文件写入。

# in app1.py
if __name__ == "__main__":
   accuracy = do_something()
   tracker = FsspecResultTracker(args.tracker_base)
   tracker["app1_out"] = {"accuracy": accuracy}

# in app2.py
if __name__ == "__main__":
   tracker = FsspecResultTracker(args.tracker_base)
   app1_accuracy = tracker["app1_out"]
   do_something_else(app1_accuracy)

ResultTracker

基础

class torchx.runtime.tracking.ResultTracker[source]

基础结果跟踪器,应对其进行子类化以实现跟踪器。通常,每个后备存储都有一个跟踪器实现。

用法

# get and put APIs can be used directly or in map-like API
# the following are equivalent
tracker.put("foo", l2norm=1.2)
tracker["foo"] = {"l2norm": 1.2}

# so are these
tracker.get("foo")["l2norm"] == 1.2
tracker["foo"]["l2norm"] == 1.2

有效的 result 类型为

  1. 数字:int、float

  2. 文字:str (使用 utf-8 编码时大小限制为 1kb)

有效的 key 类型为

  1. int

  2. str

作为约定,可以在键中使用“斜杠”来存储统计结果。例如,要存储 l2norm 的均值和标准误。

tracker[key] = {"l2norm/mean" : 1.2, "l2norm/sem": 3.4}
tracker[key]["l2norm/mean"] # returns 1.2
tracker[key]["l2norm/sem"] # returns 3.4

假设键在跟踪器后备存储的范围内是唯一的。例如,如果跟踪器以本地目录为后备存储,并且 key 是目录中保存结果的文件,那么

# same key, different backing directory -> results are not overwritten
FsspecResultTracker("/tmp/foo")["1"] = {"l2norm":1.2}
FsspecResultTracker("/tmp/bar")["1"] = {"l2norm":3.4}

跟踪器不是一个中心实体,因此不会对 putget 对同一键的操作进行强一致性保证 (超出后备存储提供的一致性)。同样,对同一键的两次连续 putget 操作也不会进行强一致性保证。

例如

tracker[1] = {"l2norm":1.2}
tracker[1] = {"l2norm":3.4}
tracker[1] # NOT GUARANTEED TO BE 3.4!

sleep(1*MIN)
tracker[1] # more likely to be 3.4 but still not guaranteed!

强烈建议使用唯一的 ID 作为键。对于简单作业,此 ID 通常是作业 ID,或者对于超参数优化之类的迭代应用程序,可以是 (实验 ID、试验编号) 或 (作业 ID、副本/工作程序排名) 的串联。

Fsspec

class torchx.runtime.tracking.FsspecResultTracker(tracker_base: str)[source]

使用 fsspec 作为后备存储的跟踪器。

用法

from torchx.runtime.tracking import FsspecResultTracker

# PUT: in trainer.py
tracker_base = "/tmp/foobar" # also supports URIs (e.g. "s3://bucket/trainer/123")
tracker = FsspecResultTracker(tracker_base)
tracker["attempt_1/out"] = {"accuracy": 0.233}

# GET: anywhere outside trainer.py
tracker = FsspecResultTracker(tracker_base)
print(tracker["attempt_1/out"]["accuracy"])
0.233

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源