• 教程 >
  • (beta) 使用 torch.compile 利用 Torch Function 模式
快捷方式

(beta) 使用 torch.compile 利用 Torch Function 模式

作者: Michael Lazos

本秘籍介绍了如何使用一个关键的 PyTorch 可扩展点,

Torch Function 模式,它与 torch.compile 协同工作,可以在跟踪时覆盖 PyTorch 算子(也称为 ops)的行为,并且没有运行时开销。

注意

本秘籍需要 PyTorch 2.7.0 或更高版本。

重写 PyTorch 算子 (torch.add -> torch.mul)

对于本示例,我们将使用 Torch Function 模式将加法运算替换为乘法运算。如果某个后端具有应针对给定算子进行分发的自定义实现,这种类型的覆盖可能会很常见。

import torch

# exit cleanly if we are on a device that doesn't support ``torch.compile``
if torch.cuda.get_device_capability() < (7, 0):
    print("Exiting because torch.compile is not supported on this device.")
    import sys
    sys.exit(0)

from torch.overrides import BaseTorchFunctionMode

# Define our mode, Note: ``BaseTorchFunctionMode``
# implements the actual invocation of func(..)
class AddToMultiplyMode(BaseTorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if func == torch.Tensor.add:
            func = torch.mul

        return super().__torch_function__(func, types, args, kwargs)

@torch.compile()
def test_fn(x, y):
    return x + y * x # Note: infix operators map to torch.Tensor.* methods

x = torch.rand(2, 2)
y = torch.rand_like(x)

with AddToMultiplyMode():
    z = test_fn(x, y)

assert torch.allclose(z, x * y * x)

# The mode can also be used within the compiled region as well like this:

@torch.compile()
def test_fn(x, y):
    with AddToMultiplyMode():
        return x + y * x # Note: infix operators map to torch.Tensor.* methods

x = torch.rand(2, 2)
y = torch.rand_like(x)
z = test_fn(x, y)

assert torch.allclose(z, x * y * x)

结论

在本秘籍中,我们演示了如何使用 torch.compile 中的 Torch Function 模式来覆盖 torch.* 算子的行为。这使得用户可以利用 Torch Function 模式的可扩展性优势,而无需承担在每次调用算子时调用 Torch Function 的运行时开销。

脚本总运行时间: ( 0 分钟 5.925 秒)

由 Sphinx-Gallery 生成的图库


评价本教程

© 版权所有 2024, PyTorch。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源