快捷方式

伪张量

代码: fake_tensor.py

动机

在进行 Dynamo 符号评估和编译器传递时,我们经常希望能够运行张量运算来了解输出大小/数据类型/设备,而无需实际运行这些运算(或破坏现有的张量),这会更慢(如果您正在进行大量的计算)并且会占用大量内存(如果您的编译器需要在编译程序时使用 GPU 内存,这很糟糕)。伪张量在所有方面都类似于真实张量,只是它实际上没有任何数据。例如,当我们进行 Dynamo 跟踪时,我们需要跟踪用户张量代码并回答有关中间值的问题(例如,如果用户对中间张量进行条件判断)。没有伪张量,我们就无法获得这些查询的准确信息。

类似地,假设您想为张量存储元数据,例如在 FX IR 节点上(meta[‘val’])。您可以直接在节点上存储一个假张量,这将为您提供张量所需的所有元数据,包括您可能没有处理的细微内容(例如,别名关系)。

总体架构

所有虚假张量都与一个 FakeTensorMode 相关联。由于虚假张量的主要用例是对真实张量进行分析,因此一般的流程是:你有一堆真实张量,你分配一个 FakeTensorMode,然后你使用 from_real_tensor 将所有这些真实张量转换为虚假张量,然后你对虚假张量进行操作。特别是,FakeTensorMode 持久地维护一个备忘表,将张量(和存储)映射到相同的存储。如果你多次虚假化同一个张量,你将得到同一个虚假张量;如果你虚假化两个相互别名的张量,你将得到两个别名指向同一个虚假存储的虚假张量。FakeTensor 是张量子类,所以如果你对它们进行操作,你将自动获得一个虚假张量,但通常情况下,你希望在 FakeTensorMode 激活的情况下对虚假张量进行操作(例如,如果你正在运行 FX 传递);张量操作将做的是自动开启虚假张量模式并尝试再次执行。

虚假张量表示为元张量的 __torch_dispatch__ 张量子类。这意味着在幕后,虚假张量是元设备张量;然后它们使用额外的扩展性钩子,特别是 dispatch_device,来撒谎关于张量的实际设备是什么。这是早期虚假张量中最容易出错的部分之一:有时,虚假张量在撒谎自己是 CPU/CUDA 时过于出色,最终会导致 CPU 内核被调用,而虚假张量试图取消引用数据指针,这显然行不通。如果你在虚假张量代码中遇到段错误,这是你应该检查的第一件事:C++ 回溯是在 CPU 内核(意外!)还是元内核(预期!)中?元内核就像一个真实内核,但它只做分配输出,它不进行任何数据计算。

张量子类必须定义如何实现各种操作。以下是通用的虚假张量配方

  • 在输入的假张量上运行元内核,将它们重新解释为元张量。这是通过一个神奇的上下文管理器 in_kernel_invocation_manager 完成的,它指示所有 PyTorch 将假张量视为其底层的元张量,而不是将假张量“解包”为元张量(假张量是一个元张量)。假张量以这种方式表示是为了避免必须同步两组元数据(元张量的元数据和假张量的元数据);“是”关系确保只有一个规范的元数据副本。

  • 如果您是工厂函数,则改为使用 device='meta' 调用底层工厂函数。

  • 将生成的元张量转换为假张量,计算张量的输出设备应该是什么(这通常很简单,但有时并非如此,例如,CPU 标量提升或设备转换操作)。

API:重要部分

非 PT2 使用(查看 test/test_fake_tensor.py 获取更多示例)

# Create a fake mode
from torch._subclasses.fake_tensor import FakeTensorMode
fake_mode = FakeTensorMode()
# Fakeify some real tensors
fake_x = fake_mode.from_real_tensor(x)
with fake_mode:
    # Do some operations on the fake tensors
    fake_y = fake_x * 2
    # Factory operations automatically get fakeified in the context manager
    fake_z = torch.empty(20)

问:为什么您将真实张量作为输入?

答:在 PT2 上下文中,这是因为您通常是进行即时编译,因此对于您正在编译的图的所有输入,您已经拥有“真实”输入,因为您是在执行程序时进行编译的。

PT2 预 AOTAutograd 使用(这很不寻常,您可能不想这样做)

# Fake mode is not enabled!
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode(args)
fake_args = [fake_mode.from_real_tensor(arg) for arg in args]
with fake_mode:
... do stuff with the fake args, if needed ...

detect_fake_mode 将搜索多个位置以尝试找到与生命周期相关的“假张量模式”。通常它将从跟踪上下文中提取。

PT2 后 AOTAutograd 使用

# 假模式已启用!example_inputs 通常已经是假的 # TODO:我们可能希望更改此项 # 仍然这样做以访问假模式 fake_mode = detect_fake_mode(example_inputs) # 但一般来说,您不必打开它

其他有用信息

from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
with maybe_disable_fake_tensor_mode():
    # fake mode is disabled here, you can do real tensor compute

您什么时候可能想要禁用假张量模式?通常您不想这样做。我们发现它有用的一个利基案例是在假张量上实现常量传播:在这种情况下,我们需要进行一些实际的张量计算,即使我们处于假张量模式。

FakeTensorProp
from torch.fx.passes.fake_tensor_prop
gm: GraphModule
real_inputs: List[Tensor]
FakeTensorProp(gm).propagate(*real_inputs)
# This will populate meta['val'] on all the FX nodes with a fake tensor
# or if you have a preexisting fake mode, you should use it
FakeTensorProp(gm, mode=fake_mode).propagate(*real_inputs)
# There is also propagate_dont_convert_inputs if your inputs are already fake
fake_inputs: List[FakeTensor]
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(*fake_inputs)

详细信息

自动转换还是不转换?最初,FakeTensorMode 不会在 FakeTensorMode 区域内自动将真实张量转换为假张量。这样做的目的是为了防止以下陷阱

with FakeTensorMode():
real_tensor.t_()

这段代码应该做什么?如果我们真的修改了真实张量的元数据,那会很奇怪。但同时,也没有明显的机会创建 FakeTensor。因此,我们保守地决定让它抛出错误:“在 FakeTensorMode 中调用非 FakeTensor 输入的操作尚不支持。请先将所有张量转换为 FakeTensor。”

这个错误在实践中非常烦人。例如,假设你有一个真实的 nn.Module,你想用假张量来喂它。你需要以某种方式将 nn.Module 转换为假张量。这促使了 FakeCopyMode 的诞生。

最终,我们放弃了,并添加了自动转换功能。但是,在许多使用 FakeTensorMode 的情况下,它仍然默认未启用。

假张量的元数据变动 如果有一个假张量,并且对其执行 t_() 操作,则假张量的元数据会发生变化。从表面上看,这是合理的,但有时你可能还想将假张量存储为 FX 节点的元数据;修改假张量很糟糕,因为这会使旧的元数据失效!

事实上,这里存在一个根本性的矛盾,即假张量维护着关于张量的极其精确的元数据,包括对象标识。如果对象元数据在 FX 图中随时间变化,实际上没有办法表示这种随时间变化。大多数情况下,我们对功能化图进行严肃的 FX 分析,这些图没有这种问题,但偶尔你需要对非功能化图进行分析。也许将假张量放在 meta['val'] 中是一个错误

关于张量子类

假张量同时使用子类和模式张量子类模式,其中 FakeTensor.__torch_dispatch__ 启用与假张量关联的 FakeTensorMode,然后重新调度(依赖 FakeTensorMode 完成繁重的工作)。如果假张量操作得到一个它不识别的子类参数,它将返回 NotImplemented,让另一个子类有机会先运行(希望反糖化为普通张量操作),然后再尝试。这会导致无限循环。

每个单独的操作是如何实现的?

不幸的是,任何给定运算符的实现位置都有一套相当复杂的规则。以下是一些需要了解的重要情况

  • 如果张量子类的元素数量非常少,则它们支持有限的常量传播(这有助于处理一些我们立即对这些张量调用 item() 的情况。)

  • 我们为某些运算符提供了一些快速路径实现,这些实现完全在假张量中完成,以提高性能。

  • 如果您使用 @custom_op 生成自定义张量,这些将直接将 impl_abstract 注册到假张量。

  • 假张量本身对设备转换操作有一些硬编码的特殊情况。

  • 如果没有元实现或任何分解,我们将生成真实的零填充张量,并尝试直接运行运算符以找出结果。如果运算符尝试对数据进行索引,这可能会导致段错误,因此我们默认情况下不会为自定义运算符启用此功能。

转换器如何工作?

由于假张量用于对张量的精确属性非常敏感的情况,因此假张量非常小心地进行转换,保留叶子性、requires_grad 性、别名以及许多其他属性。大部分繁重的工作都在 MetaConverter 中完成。

性能特征

您可能会认为假张量很快,因为它们不执行任何张量计算。但在小型张量尺寸下,我们实际上完全受到开销的限制,而且,假张量是用 Python 编写的,我们经常会做很多工作来执行单个张量操作(因为它们是作为分解实现的)。因此,假张量在实践中实际上相当慢,尤其是在涉及符号形状时。目前我们在假张量中拥有两个重要的快速路径,它们在实践中产生了很大影响

  • 逐点运算不会经过 PrimTorch 分解,而是我们手动编码了它们的传播规则。

  • 如果可能,我们应该。

假张量的假张量?

人们对将伪张量作为用户输入发送到 PT2 堆栈感兴趣,这意味着我们需要能够创建伪张量的伪张量。目前这并不支持,但也许做起来并不难。

与动态形状的交互

每个 FakeTensorMode 都包含一个 ShapeEnv,它跟踪所有符号形状信息。它们的生存期通常是绑定的:它们一起存活和死亡。

因为 FakeTensorMode 有一个 ShapeEnv(但元实现没有),所以依赖于数据的元函数需要分配一个未支持的 SymInt,它们存在于伪张量中。伪张量还负责记忆未支持的 SymInt,因此,例如,如果您在同一个伪张量上调用 nonzero() 两次,您将获得相同的符号大小。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源