• 教程 >
  • 使用 torch.compiler.set_stance 进行动态编译控制
快捷方式

使用 torch.compiler.set_stance 进行动态编译控制

作者: William Wen

torch.compiler.set_stance 是一个 torch.compiler API,它使你能够在对模型进行不同调用时更改 torch.compile 的行为,而无需重新对模型应用 torch.compile

本代码片段提供了一些关于如何使用 torch.compiler.set_stance 的示例。

前提条件

  • torch >= 2.6

描述

torch.compile.set_stance 可以用作装饰器、上下文管理器或原始函数,以在对模型进行不同调用时更改 torch.compile 的行为。

在下面的示例中,"force_eager" 立场会忽略所有 torch.compile 指令。

import torch


@torch.compile
def foo(x):
    if torch.compiler.is_compiling():
        # torch.compile is active
        return x + 1
    else:
        # torch.compile is not active
        return x - 1


inp = torch.zeros(3)

print(foo(inp))  # compiled, prints 1
tensor([1., 1., 1.])

装饰器用法示例

@torch.compiler.set_stance("force_eager")
def bar(x):
    # force disable the compiler
    return foo(x)


print(bar(inp))  # not compiled, prints -1
tensor([-1., -1., -1.])

上下文管理器用法示例

with torch.compiler.set_stance("force_eager"):
    print(foo(inp))  # not compiled, prints -1
tensor([-1., -1., -1.])

原始函数用法示例

torch.compiler.set_stance("force_eager")
print(foo(inp))  # not compiled, prints -1
torch.compiler.set_stance("default")

print(foo(inp))  # compiled, prints 1
tensor([-1., -1., -1.])
tensor([1., 1., 1.])

torch.compile 立场只能在任何 torch.compile 区域的外部更改。否则将导致错误。

@torch.compile
def baz(x):
    # error!
    with torch.compiler.set_stance("force_eager"):
        return x + 1


try:
    baz(inp)
except Exception as e:
    print(e)


@torch.compiler.set_stance("force_eager")
def inner(x):
    return x + 1


@torch.compile
def outer(x):
    # error!
    return inner(x)


try:
    outer(inp)
except Exception as e:
    print(e)
Attempt to trace forbidden callable <function set_stance at 0x7f12f655a7a0>

from user code:
   File "/var/lib/workspace/recipes_source/torch_compiler_set_stance_tutorial.py", line 85, in baz
    with torch.compiler.set_stance("force_eager"):

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Attempt to trace forbidden callable <function inner at 0x7f1263feba30>

from user code:
   File "/var/lib/workspace/recipes_source/torch_compiler_set_stance_tutorial.py", line 103, in outer
    return inner(x)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
其他立场包括
  • "default":默认立场,用于正常编译。

  • "eager_on_recompile":当需要重新编译时,以 eager 模式运行代码。如果存在对输入有效的缓存编译代码,它仍然会被使用。

  • "fail_on_recompile":在重新编译函数时引发错误。

请参阅 torch.compiler.set_stance 文档页面了解更多立场和选项。未来可能还会添加更多立场/选项。

示例

防止重新编译

有些模型不期望任何重新编译 - 例如,你的输入形状可能总是相同的。由于重新编译可能代价高昂,我们可能希望在尝试重新编译时报错,以便检测并修复重新编译的情况。"fail_on_recompilation" 立场可用于此目的。

@torch.compile
def my_big_model(x):
    return torch.relu(x)


# first compilation
my_big_model(torch.randn(3))

with torch.compiler.set_stance("fail_on_recompile"):
    my_big_model(torch.randn(3))  # no recompilation - OK
    try:
        my_big_model(torch.randn(4))  # recompilation - error
    except Exception as e:
        print(e)
Detected recompile when torch.compile stance is 'fail_on_recompile'

如果报错过于中断,我们可以改用 "eager_on_recompile",它将导致 torch.compile 回退到 eager 模式而不是报错。如果重新编译不常发生,但在需要时,我们宁愿承担 eager 模式运行的代价而不是重新编译的代价,这可能会很有用。

@torch.compile
def my_huge_model(x):
    if torch.compiler.is_compiling():
        return x + 1
    else:
        return x - 1


# first compilation
print(my_huge_model(torch.zeros(3)))  # 1

with torch.compiler.set_stance("eager_on_recompile"):
    print(my_huge_model(torch.zeros(3)))  # 1
    print(my_huge_model(torch.zeros(4)))  # -1
    print(my_huge_model(torch.zeros(3)))  # 1
tensor([1., 1., 1.])
tensor([1., 1., 1.])
tensor([-1., -1., -1., -1.])
tensor([1., 1., 1.])

衡量性能提升

torch.compiler.set_stance 可用于比较 eager 模式与编译模式的性能,而无需定义单独的 eager 模型。

# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000


@torch.compile
def my_gigantic_model(x, y):
    x = x @ y
    x = x @ y
    x = x @ y
    return x


inps = torch.randn(5, 5), torch.randn(5, 5)

with torch.compiler.set_stance("force_eager"):
    print("eager:", timed(lambda: my_gigantic_model(*inps))[1])

# warmups
for _ in range(3):
    my_gigantic_model(*inps)

print("compiled:", timed(lambda: my_gigantic_model(*inps))[1])
eager: 6.774400174617767e-05
compiled: 8.297599852085114e-05

更快地发现错误

在使用 "force_eager" 立场进行编译迭代之前,先运行一个 eager 迭代,可以帮助我们在尝试长时间编译之前捕获与 torch.compile 无关的错误。

@torch.compile
def my_humongous_model(x):
    return torch.sin(x, x)


try:
    with torch.compiler.set_stance("force_eager"):
        print(my_humongous_model(torch.randn(3)))
    # this call to the compiled model won't run
    print(my_humongous_model(torch.randn(3)))
except Exception as e:
    print(e)
sin() takes 1 positional argument but 2 were given

结论

在本代码片段中,我们学习了如何使用 torch.compiler.set_stance API 在对模型进行不同调用时修改 torch.compile 的行为,而无需重新应用它。本代码片段演示了如何使用 torch.compiler.set_stance 作为装饰器、上下文管理器或原始函数来控制编译立场,例如 force_eagerdefaulteager_on_recompile 和 “fail_on_recompile”。

更多信息请参阅:torch.compiler.set_stance API 文档

脚本总运行时间: ( 0 分钟 8.268 秒)

由 Sphinx-Gallery 生成的图库


评价本教程

© Copyright 2024, PyTorch.

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源