注意
点击此处下载完整的示例代码
(beta)使用 TORCH_LOGS python API 与 torch.compile¶
创建于:2024 年 1 月 24 日 | 最后更新:2024 年 1 月 31 日 | 最后验证:2024 年 11 月 05 日
作者: Michael Lazos
import logging
本教程介绍了 TORCH_LOGS
环境变量以及 Python API,并演示了如何应用它来观察 torch.compile
的各个阶段。
注意
本教程需要 PyTorch 2.2.0 或更高版本。
设置¶
在本示例中,我们将设置一个执行逐元素加法的简单 Python 函数,并使用 TORCH_LOGS
Python API 观察编译过程。
注意
还有一个环境变量 TORCH_LOGS
,可用于在命令行更改日志设置。每个示例都显示了等效的环境变量设置。
import torch
# exit cleanly if we are on a device that doesn't support torch.compile
if torch.cuda.get_device_capability() < (7, 0):
print("Skipping because torch.compile is not supported on this device.")
else:
@torch.compile()
def fn(x, y):
z = x + y
return z + 2
inputs = (torch.ones(2, 2, device="cuda"), torch.zeros(2, 2, device="cuda"))
# print separator and reset dynamo
# between each example
def separator(name):
print(f"==================={name}=========================")
torch._dynamo.reset()
separator("Dynamo Tracing")
# View dynamo tracing
# TORCH_LOGS="+dynamo"
torch._logging.set_logs(dynamo=logging.DEBUG)
fn(*inputs)
separator("Traced Graph")
# View traced graph
# TORCH_LOGS="graph"
torch._logging.set_logs(graph=True)
fn(*inputs)
separator("Fusion Decisions")
# View fusion decisions
# TORCH_LOGS="fusion"
torch._logging.set_logs(fusion=True)
fn(*inputs)
separator("Output Code")
# View output code generated by inductor
# TORCH_LOGS="output_code"
torch._logging.set_logs(output_code=True)
fn(*inputs)
separator("")
Skipping because torch.compile is not supported on this device.
结论¶
在本教程中,我们通过试验少量可用的日志选项,介绍了 TORCH_LOGS 环境变量和 python API。要查看所有可用选项的描述,请运行任何导入 torch 的 python 脚本并将 TORCH_LOGS 设置为“help”。
或者,您可以查看 torch._logging 文档,以查看所有可用日志选项的描述。
有关 torch.compile 的更多信息,请参阅 torch.compile 教程。
脚本总运行时间: ( 0 分钟 0.002 秒)