快捷方式

伪张量

代码:fake_tensor.py

动机

在进行 Dynamo 符号评估和编译器传递时,我们经常希望能够运行张量运算以了解输出大小/数据类型/设备,而无需实际运行这些运算(或破坏现有张量),这会更慢(如果进行大量的计算)并且会占用大量内存(如果您的编译器需要在编译程序时使用 GPU 内存,这很糟糕)。伪张量在所有方面都像真实张量,只是它实际上没有任何数据。例如,当我们进行 Dynamo 跟踪时,我们需要跟踪用户张量代码并回答关于中间值的问题(例如,如果用户在中间张量上进行条件操作)。如果没有伪张量,我们将无法获得这些查询的准确信息。

类似地,假设您想为张量存储元数据,例如在 FX IR 节点上(meta[‘val’])。您可以改为直接在节点上存储伪张量,它将为您提供张量所需的所有元数据,包括您可能没有处理过的微妙内容(例如,别名关系)。

总体架构

所有伪张量都与 FakeTensorMode 相关联。由于伪张量的主要用例是对真实张量进行分析,因此一般工作流程是:您有一组真实张量,您分配一个 FakeTensorMode,然后使用 from_real_tensor 将所有这些真实张量转换为伪张量,然后您对伪张量执行操作。特别是,FakeTensorMode 持久地维护一个备忘表,将张量(和存储)映射到相同的存储。如果您多次伪造相同的张量,您将获得相同的伪张量;如果您伪造两个相互别名的张量,您将获得两个别名指向相同伪存储的伪张量。FakeTensor 是张量子类,因此如果您对它们进行操作,您将自动获得一个伪张量,但一般来说您将希望在 FakeTensorMode 处于活动状态时对伪张量进行操作(例如,如果您正在运行 FX 传递);张量操作将做的是自动打开伪张量模式并重试。

伪张量表示为元张量的 __torch_dispatch__ 张量子类。这意味着在幕后,伪张量是元设备张量;然后它们使用额外的可扩展性钩子(特别是 dispatch_device)来隐瞒张量的实际设备是什么。这是早期伪张量中最容易出错的部分之一:有时,伪张量在隐瞒自己是 CPU/CUDA 方面做得太好了,以至于您最终会调用一个 CPU 内核,该内核会使用伪张量来取消引用数据指针,这显然不会起作用。如果您在伪张量代码中遇到段错误,这应该是您首先要检查的内容:C++ 回溯是在 CPU 内核中(意外!)还是在元内核中(预期!)元内核就像一个真实的内核,但它只做的是分配输出,它不进行任何数据计算。

张量子类必须定义如何实现各种操作。以下是伪张量的一般配方

  • 在输入的假张量上运行元内核,将它们重新解释为元张量。这是通过一个神奇的上下文管理器 in_kernel_invocation_manager 完成的,该管理器指示所有 PyTorch 将假张量视为其底层的元张量,而不是将假张量“解包”为元张量(假张量实际上就是元张量)。假张量以这种方式表示是为了避免必须保持两组元数据的同步(元张量的元数据和假张量的元数据);“是”关系确保只存在元数据的唯一规范副本。

  • 如果你是一个工厂函数,你将改为使用 device='meta' 调用底层的工厂函数。

  • 将得到的元张量转换为假张量,计算张量的输出设备应该是什么(这通常很简单,但有时并非如此,例如,CPU 标量提升或设备转换操作)。

API:重要的部分

非 PT2 使用(查看 test/test_fake_tensor.py 获取更多示例)

# Create a fake mode
from torch._subclasses.fake_tensor import FakeTensorMode
fake_mode = FakeTensorMode()
# Fakeify some real tensors
fake_x = fake_mode.from_real_tensor(x)
with fake_mode:
    # Do some operations on the fake tensors
    fake_y = fake_x * 2
    # Factory operations automatically get fakeified in the context manager
    fake_z = torch.empty(20)

问:为什么你使用真实张量作为输入?

答:在 PT2 上下文中,这是因为你通常是进行即时编译,因此对于你正在编译的图的所有输入,你已经拥有“真实”输入,因为你是在执行程序时进行编译的。

PT2 预 AOTAutograd 使用(这很不寻常,你可能不想这样做)

# Fake mode is not enabled!
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode(args)
fake_args = [fake_mode.from_real_tensor(arg) for arg in args]
with fake_mode:
... do stuff with the fake args, if needed ...

detect_fake_mode 将搜索多个位置,试图找到与生命周期关联的“假张量模式”。通常它将从跟踪上下文中获取。

PT2 后 AOTAutograd 使用

# 假模式已启用!example_inputs 通常已经是假的 # TODO:我们可能希望改变这一点 # 仍然这样做以访问假模式 fake_mode = detect_fake_mode(example_inputs) # 但通常你不需要打开它

其他有用内容

from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
with maybe_disable_fake_tensor_mode():
    # fake mode is disabled here, you can do real tensor compute

你什么时候可能想要禁用假张量模式?通常你不想这样做。我们发现它有用的一个利基案例是在假张量上实现常量传播:在这种情况下,我们需要进行一些实际的张量计算,即使我们在假张量模式下。

FakeTensorProp
from torch.fx.passes.fake_tensor_prop
gm: GraphModule
real_inputs: List[Tensor]
FakeTensorProp(gm).propagate(*real_inputs)
# This will populate meta['val'] on all the FX nodes with a fake tensor
# or if you have a preexisting fake mode, you should use it
FakeTensorProp(gm, mode=fake_mode).propagate(*real_inputs)
# There is also propagate_dont_convert_inputs if your inputs are already fake
fake_inputs: List[FakeTensor]
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(*fake_inputs)

详细信息

自动转换还是不转换?最初,如果你试图在 FakeTensorMode 区域内对真实张量进行计算,FakeTensorMode 不会自动将真实张量转换为假张量。这背后的动机是为了防止以下踩坑

with FakeTensorMode():
real_tensor.t_()

这段代码应该做什么?如果我们实际上修改了真实张量的元数据,这将令人惊讶。但同时,也没有明显的创建 FakeTensor 的机会。因此,我们保守地决定使这引发错误:“在 FakeTensorMode 中使用非假张量输入调用运算符尚不受支持。请先将所有张量转换为假张量。”

这个错误在实践中非常烦人。例如,假设你有一个真实的 nn.Module,你想通过它传递假张量。你需要以某种方式将 nn.Module 转换为假张量。这促使了 FakeCopyMode 的出现。

最终,我们放弃了,添加了自动转换功能。然而,这在许多 FakeTensorMode 的使用中仍然不是默认启用的。

假张量上的元数据变异 如果你有一个假张量,并且你对其调用 t_(),假张量上的元数据将发生改变。从表面上看,这是合理的,但有时你还想将假张量存储为 FX 节点上的元数据;变异假张量是不好的,因为这将使旧元数据失效!

事实上,这里存在一个根本的冲突,即假张量维护着关于张量的极其精确的元数据,包括对象标识。如果对象元数据在 FX 图中随时间变化,实际上无法表示这种变化。大多数时候,我们进行的严肃 FX 分析是在函数化图上进行的,这些图没有这个问题,但偶尔你需要对非函数化图进行分析。也许将假张量放在 meta['val'] 中是一个错误。

关于张量子类

假张量同时使用子类和模式张量子类模式,其中 FakeTensor.__torch_dispatch__ 启用与假张量关联的 FakeTensorMode,然后重新调度(依靠 FakeTensorMode 来完成繁重的工作)。如果假张量操作获得一个它不识别的子类参数,它将返回 NotImplemented,让另一个子类有机会先运行(希望将糖衣剥离成普通张量操作),然后它再尝试。这会导致无限循环。

每个单独的运算符是如何实现的?

不幸的是,存在一组相当复杂的位置,任何给定运算符都可能在那里实现。以下是一些需要注意的重要情况

  • 如果元素数量非常少,张量子类支持有限的常量传播(这有助于处理一些我们立即对这些张量调用 item() 的情况)。

  • 出于性能原因,我们为某些运算符提供了一些快速路径实现,这些实现完全在假张量中完成。

  • 如果你使用 @custom_op 生成自定义张量,这些张量将直接将 impl_abstract 注册到假张量。

  • 假张量本身对设备转换操作有一些硬编码的特例。

  • 如果没有元实现或任何分解,我们将生成真实的填充为零的张量,并尝试直接运行运算符以找出结果是什么。如果运算符试图使用数据进行索引,这会导致段错误,因此我们不会默认情况下为自定义运算符打开此功能。

转换器是如何工作的?

由于假张量用于对张量的精确属性非常敏感的情况,因此假张量非常小心地进行转换,保留叶子属性、requires_grad 属性、别名和许多其他属性。大部分繁重的工作都在 MetaConverter 中完成。

性能特征

你可能会认为假张量很快,因为它们不进行任何张量计算。但在小型张量尺寸下,我们实际上完全受开销限制,而且,假张量是在 Python 中实现的,我们经常做很多工作来执行单个张量操作(因为它们是用分解来实现的)。因此,假张量在实践中实际上相当慢,尤其是在涉及符号形状时。我们目前在假张量中拥有两个重要的快速路径,它们在实践中产生了很大的影响

  • 逐点操作不通过 PrimTorch 分解,而是我们手动编码了它们的传播规则。

  • 如果可能的话,我们应该。

假张量的假张量?

人们对将假张量作为用户输入发送到 PT2 堆栈感兴趣,这意味着我们需要能够创建假张量的假张量。目前还没有真正支持,但这可能并不难做到。

与动态形状的交互

每个 FakeTensorMode 都包含一个 ShapeEnv,它跟踪所有符号形状信息。它们的生存期通常是绑定的:它们一起存活和死亡。

由于 FakeTensorMode 具有 ShapeEnv(但元实现没有),因此数据相关且需要分配未支持的 SymInt 的元函数存在于假张量中。假张量还负责对未支持的 SymInt 进行记忆化,因此,例如,如果你对同一个假张量两次调用 nonzero(),你将获得相同的符号大小。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源