伪张量¶
代码: fake_tensor.py
动机¶
在进行 Dynamo 符号评估和编译器传递时,我们通常希望能够运行张量操作来了解输出大小/数据类型/设备是什么,而无需实际运行这些操作(或破坏现有的张量),这将更慢(如果你正在执行大量计算)并占用大量内存(如果你的编译器需要在编译程序时使用 GPU 内存,这很糟糕)。伪张量就像一个真正的张量,除了它实际上没有任何数据。例如,当我们进行 Dynamo 跟踪时,我们需要跟踪用户张量代码并回答有关中间结果的问题(例如,如果用户对中间张量进行条件判断)。如果没有伪张量,我们不会对这些查询有准确的信息。
类似地,假设你想要为张量存储元数据,例如,在 FX IR 节点上(meta[‘val’])。你可以直接在节点上存储一个伪张量,这将为你提供张量所需的所有元数据,包括你可能没有处理过的一些细微内容(例如,别名关系)。
总体架构¶
所有伪张量都与 FakeTensorMode 相关联。因为伪张量的主要用例是对真实张量进行分析,所以一般工作流程是:你有一堆真实的张量,你分配一个 FakeTensorMode,然后你使用 from_real_tensor 将所有这些真实的张量转换为伪张量,然后你对伪张量进行操作。特别是,FakeTensorMode 持久地维护一个记忆表,将张量(和存储)映射到相同的存储。如果你多次伪造同一个张量,你会得到同一个伪张量;如果你伪造两个相互别名的张量,你会得到两个别名相同伪存储的伪张量。FakeTensors 是张量子类,因此如果你对其执行操作,你将自动获得一个伪张量,但通常你希望对伪张量执行操作(例如,如果你正在运行一个 FX 传递);一个张量操作会做的是自动打开伪张量模式并重试。
伪张量表示为元张量的 __torch_dispatch__ 张量子类。这意味着在幕后,伪张量是元设备张量;然后它们使用额外的可扩展性钩子,特别是 dispatch_device,来撒谎关于张量的实际设备是什么。这是早期伪张量中更易出错的部分之一:有时,伪张量在撒谎自己是 CPU/CUDA 时太过出色,最终导致 CPU 内核被调用,其中一个伪张量试图取消引用数据指针,这显然不会工作。如果你在伪张量代码中遇到段错误,这是你应该检查的第一件事:C++ 回溯是在 CPU 内核中(意外!)还是在元内核中(预期!)元内核就像一个真正的内核,但它只做分配输出,不做任何数据计算。
张量子类必须定义如何实现各种操作。以下是伪张量的通用配方
在输入伪张量上运行元内核,将它们重新解释为元张量。这是通过一个名为 in_kernel_invocation_manager 的魔法上下文管理器来完成的,该管理器指示所有 PyTorch 将伪张量视为其底层的元张量,而不是将伪张量“解包”成元张量(伪张量是元张量)。伪张量以这种方式表示是为了避免必须同步维护两组元数据(元张量的元数据和伪张量的元数据);“是”关系确保元数据只有一份规范副本。
如果你是一个工厂函数,你将改为调用具有 device='meta' 的底层工厂函数。
将生成的元张量转换为伪张量,计算张量的输出设备应该是哪个(这通常很简单,但有时却不是,例如,CPU 标量提升,或设备转换操作)。
API:重要部分¶
非 PT2 用法(查看 test/test_fake_tensor.py 获取更多示例)
# Create a fake mode
from torch._subclasses.fake_tensor import FakeTensorMode
fake_mode = FakeTensorMode()
converter = fake_mode.fake_tensor_converter
# Fakeify some real tensors
fake_x = converter.from_real_tensor(fake_mode, x)
with fake_mode:
# Do some operations on the fake tensors
fake_y = fake_x * 2
# Factory operations automatically get fakeified in the context manager
fake_z = torch.empty(20)
问:为什么你的输入是真实的张量?
A: 在 PT2 上下文中,这是因为你通常是进行即时编译的,所以对于你正在编译的图的所有输入,你已经拥有了“真实”的输入,因为你在执行程序时进行编译。
PT2 预 AOTAutograd 使用(这种情况很少见,你可能不想这样做)
# Fake mode is not enabled!
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode(args)
# if fake_mode isn't None
converter = fake_mode.fake_tensor_converter
fake_args = [converter.from_real_tensor(fake_mode, arg) for arg in args]
with fake_mode:
... do stuff with the fake args, if needed ...
detect_fake_mode 将搜索多个位置以尝试找到与生命周期相关的“假张量模式”。通常,它将从跟踪上下文中提取。
PT2 后 AOTAutograd 使用
# 假模式已启用!example_inputs 通常已经是假的 # TODO:我们可能想改变这一点 # 仍然这样做以访问假模式 fake_mode = detect_fake_mode(example_inputs) # 但一般情况下你不需要打开它
其他有用信息
from torch._subclasses.fake_tensor import unset_fake_temporarily
with unset_fake_temporarily():
# fake mode is disabled here, you can do real tensor compute
什么时候可能需要禁用假张量模式?通常你不想这样做。我们发现它有用的一个利基案例是在假张量上实现常量传播:在这种情况下,我们需要进行一些实际的张量计算,即使我们处于假张量模式。
FakeTensorProp
from torch.fx.passes.fake_tensor_prop
gm: GraphModule
real_inputs: List[Tensor]
FakeTensorProp(gm).propagate(*real_inputs)
# This will populate meta['val'] on all the FX nodes with a fake tensor
# or if you have a preexisting fake mode, you should use it
FakeTensorProp(gm, mode=fake_mode).propagate(*real_inputs)
# There is also propagate_dont_convert_inputs if your inputs are already fake
fake_inputs: List[FakeTensor]
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(*fake_inputs)
详细信息¶
自动转换还是不转换?最初,FakeTensorMode 不会自动将真实张量假化为假张量,如果你尝试在 FakeTensorMode 区域内对它们进行计算。这背后的动机是为了防止出现以下错误
with FakeTensorMode():
real_tensor.t_()
这段代码应该做什么?如果我们实际修改了真实张量上的元数据,这将令人惊讶。但同时,没有明显的机会创建 FakeTensor。因此,我们保守地决定让它引发错误:“在 FakeTensorMode 中调用具有非假张量输入的操作尚不支持。请先将所有张量转换为假张量。”
此错误在实践中非常烦人。例如,假设你有一个真实的 nn.Module,并且想要将假张量馈送到它。你需要以某种方式将 nn.Module 假化为假张量。这促使 FakeCopyMode 的出现。
最终,我们放弃了,并添加了自动假化。但是,这在 FakeTensorMode 的许多使用中仍然没有默认启用。
假张量上的元数据变异 如果有一个假张量,并且对其调用 t_(),则假张量上的元数据会发生变化。从表面上看这是合理的,但有时也希望将假张量存储为 FX 节点上的元数据;变异假张量很糟糕,因为这将使旧元数据失效!
事实上,这里存在一个根本性的矛盾,那就是假张量保持了关于张量的极其准确的元数据,包括对象标识。如果对象元数据在 FX 图中随着时间推移而发生变化,实际上没有办法表示这种随时间推移的变化。大多数时候,我们对功能化图进行严肃的 FX 分析,这些图没有这种问题,但偶尔需要对非功能化图进行分析。也许将假张量放入 meta['val'] 中是一个错误
关于张量子类¶
FakeTensor 使用子类和模式张量子类模式,其中 FakeTensor.__torch_dispatch__ 启用与假张量关联的 FakeTensorMode,然后重新分派(依赖 FakeTensorMode 进行繁重的工作)。如果假张量操作得到它不认识的子类参数,它将返回 NotImplemented,让另一个子类有机会首先运行(希望反糖为普通张量操作),然后它会再次尝试。这会导致无限循环。
每个单独的操作是如何实现的?¶
不幸的是,任何给定操作可能在相当复杂的一组地方实现。一些重要的案例需要知道
张量子类如果元素数量非常少,则支持有限的常量传播(这有助于处理一些我们立即对这些张量调用 item() 的情况。)
我们为某些操作提供了一些快速路径实现,这些实现完全在假张量中完成,出于性能原因。
如果使用 @custom_op 生成自定义张量,这些将直接向假张量注册 impl_abstract。
假张量本身对设备转换操作有一些硬编码的特例。
如果没有元实现或任何分解,我们将生成真实零填充张量,并尝试直接运行操作以找出结果将是什么。如果操作尝试使用数据进行索引,这会导致段错误,因此我们不会为自定义操作默认打开它。
转换器是如何工作的?¶
因为假张量用于对张量的精确属性非常敏感的情况,所以假张量会非常小心地进行转换,保留叶子状态、requires_grad 状态、别名以及一大堆其他属性。繁重的工作大部分都在 MetaConverter 中完成。
性能特征¶
你会认为假张量很快,因为它们不执行任何张量计算。但在小张量尺寸下,我们实际上完全是开销绑定的,而且,假张量是在 Python 中的,我们经常做很多工作来进行单个张量操作(因为它们被实现为分解)。因此,假张量在实践中实际上相当慢,尤其是在涉及符号形状时。我们目前在假张量中有两个重要的快速路径,在实践中产生了很大差异
逐点操作不会经过 PrimTorch 分解,而是我们手写了它们的传播规则。
如果可能的话,我们应该。
假张量的假张量?¶
有人对将假张量作为用户输入发送到 PT2 堆栈感兴趣,这意味着我们需要能够创建一个假张量的假张量。现在还不支持,但也许做起来并不太难。
与动态形状的交互¶
每个 FakeTensorMode 都包含一个 ShapeEnv,它跟踪所有符号形状信息。它们的生命周期通常是绑定的:它们一起存在和消亡。
因为 FakeTensorMode 有一个 ShapeEnv(但元实现没有),所以依赖于数据的元函数需要分配一个无后备的 SymInt,它们存在于假张量中。假张量还负责对无后备的 SymInt 进行记忆,因此,例如,如果对同一个假张量调用 nonzero() 两次,你将获得相同的符号大小。