Fake tensor¶
代码: fake_tensor.py
动机¶
在进行 Dynamo 符号评估和编译器传递时,我们经常希望能够运行张量操作来了解输出大小/数据类型/设备,而无需实际运行这些操作(或破坏预先存在的张量),因为这样做会更慢(如果您进行大量计算)并占用大量内存(如果您的编译器在编译程序时需要使用 GPU 内存,那将是很糟糕的)。Fake tensor 就像一个真实张量一样,除了它实际上没有任何数据。例如,当我们进行 Dynamo 跟踪时,我们需要跟踪用户 Tensor 代码并回答有关中间值的问题(例如,如果用户对中间张量执行条件操作)。如果没有 fake tensor,我们将无法获得这些查询的准确信息。
同样,假设您要存储张量的元数据,例如,在 FX IR 节点上 (meta[‘val’])。您可以改为直接在节点上存储 fake tensor,这将为您提供张量所需的所有元数据,包括您可能没有处理的细微之处(例如,别名关系)。
总体架构¶
所有 fake tensor 都与 FakeTensorMode 相关联。由于 fake tensor 的主要用例是对真实张量进行分析,因此一般工作流程是您拥有一堆真实张量,您分配一个 FakeTensorMode,然后您使用 from_real_tensor 将所有这些真实张量转换为 fake tensor,然后您对 fake tensor 执行操作。特别是,FakeTensorMode 持久维护一个备忘录表,将张量(和存储)映射到相同的存储。如果您多次 fakeify 同一张量,您将获得相同的 fake tensor;如果您 fakeify 两个互相别名的张量,您将获得两个别名相同 fake 存储的 fake tensor。FakeTensor 是张量子类,因此如果您对它们执行操作,您将自动获得 fake tensor,但一般来说,您会希望在 FakeTensorMode 处于活动状态时对 fake tensor 执行操作(例如,如果您正在运行 FX 传递);张量操作将自动开启 fake tensor 模式并重试。
Fake tensor 表示为 meta tensor 的 __torch_dispatch__ 张量子类。这意味着在底层,fake tensor 是 meta 设备张量;然后它们使用额外的可扩展性钩子,特别是 dispatch_device,来谎报张量的实际设备。这是早期 fake tensor 中更易出错的部分之一:有时,fake tensor 太擅长谎报自己是 CPU/CUDA 等,您最终会得到一个 CPU 内核被调用,而 fake tensor 试图解引用数据指针,这显然是行不通的。如果您在 fake tensor 代码中遇到段错误,这是您应该首先检查的事情:C++ 回溯是在 CPU 内核中(意外!)还是在 meta 内核中(预期!)Meta 内核就像一个真实内核,但它所做的只是分配输出,它不进行任何数据计算。
张量子类必须定义如何实现各种操作。以下是通用的 fake tensor 配方
在输入 fake tensor 上运行 meta 内核,将它们重新解释为 meta tensor。这是通过一个神奇的上下文管理器 in_kernel_invocation_manager 完成的,它指示 PyTorch 的所有部分将 fake tensor 视为其底层的 meta tensor,而不是将 fake tensor “解包”为 meta tensor(fake tensor 是 meta tensor)。之所以以这种方式表示 fake tensor,是为了避免必须保持两组元数据同步(meta tensor 的元数据和 fake tensor 的元数据);“is a” 关系确保只有一个规范的元数据副本。
如果您是工厂函数,您将改为使用 device=’meta’ 调用底层的工厂函数。
将生成的 meta tensor 转换为 fake tensor,计算张量的输出设备应该是什么(这通常是微不足道的,但有时并非如此,例如,CPU 标量提升或设备转换操作)。
API:重要部分¶
非 PT2 用法(有关更多示例,请查看 test/test_fake_tensor.py)
# Create a fake mode
from torch._subclasses.fake_tensor import FakeTensorMode
fake_mode = FakeTensorMode()
converter = fake_mode.fake_tensor_converter
# Fakeify some real tensors
fake_x = converter.from_real_tensor(fake_mode, x)
with fake_mode:
# Do some operations on the fake tensors
fake_y = fake_x * 2
# Factory operations automatically get fakeified in the context manager
fake_z = torch.empty(20)
问:为什么您将真实张量作为输入?
答:在 PT2 上下文中,这是因为您通常是即时编译,因此对于您正在编译的图的所有输入,您已经有了“真实”输入,因为您是在执行程序时进行编译的。
PT2 pre-AOTAutograd 用法(这很不寻常,您可能不想这样做)
# Fake mode is not enabled!
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode(args)
# if fake_mode isn't None
converter = fake_mode.fake_tensor_converter
fake_args = [converter.from_real_tensor(fake_mode, arg) for arg in args]
with fake_mode:
... # do stuff with the fake args, if needed ...
detect_fake_mode 将搜索多个位置,尝试查找与生命周期关联的“the” fake tensor 模式。通常,它将从跟踪上下文中拉取。
PT2 post-AOTAutograd 用法
# Fake mode is enabled! example_inputs is typically fake already
# TODO: we probably want to change this
# Still do this to access fake mode
fake_mode = detect_fake_mode(example_inputs)
# But in general you don't have to turn it on
其他有用的东西
from torch._subclasses.fake_tensor import unset_fake_temporarily
with unset_fake_temporarily():
... # fake mode is disabled here, you can do real tensor compute
何时可能需要禁用 fake tensor 模式?通常您不想这样做。我们发现它有用的一个特殊情况是在 fake tensor 上实现常量传播:在这种情况下,即使我们处于 fake tensor 模式,我们也需要进行一些实际的张量计算。
import FakeTensorProp from torch.fx.passes.fake_tensor_prop
gm: GraphModule
real_inputs: List[Tensor]
FakeTensorProp(gm).propagate(*real_inputs)
# This will populate meta['val'] on all the FX nodes with a fake tensor
# or if you have a preexisting fake mode, you should use it
FakeTensorProp(gm, mode=fake_mode).propagate(*real_inputs)
# There is also propagate_dont_convert_inputs if your inputs are already fake
fake_inputs: List[FakeTensor]
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(*fake_inputs)
细节¶
自动转换还是不自动转换?最初,如果您尝试在 FakeTensorMode 区域内对真实张量执行计算,FakeTensorMode 不会自动 fakeify 真实张量。这背后的动机是为了防止以下 footgun
with FakeTensorMode():
real_tensor.t_()
此代码应该做什么?如果我们实际修改了真实张量上的元数据,那将是令人惊讶的。但与此同时,没有任何明显的创建 FakeTensor 的机会。因此,我们保守地决定引发一个错误:“尚不支持在 FakeTensorMode 中使用非 Fake Tensor 输入调用运算符。请先将所有 Tensor 转换为 FakeTensor。”
在实践中,此错误非常烦人。例如,假设您有一个真实的 nn.Module,并且您想通过它输入 fake tensor。您需要以某种方式 fakeify nn.Module。这促使了 FakeCopyMode 的出现。
最终,我们放弃了并添加了自动 fakeification。但是,在 FakeTensorMode 的许多用法中,默认情况下仍未启用此功能。
Fake tensor 上的元数据突变 如果您有一个 fake tensor,并且您对其执行 t_() 操作,则 fake tensor 上的元数据会发生更改。这表面上是合理的,但有时您也希望将 fake tensor 存储为 FX 节点上的元数据;突变 fake tensor 是不好的,因为这会使旧的元数据无效!
事实上,这里存在一个根本的张力,即 fake tensor 维护关于张量的极其准确的元数据,甚至包括对象标识。如果对象元数据随着时间的推移在 FX 图中发生更改,则实际上没有任何方法可以表示这种随时间的变化。在大多数情况下,我们对功能化图进行严肃的 FX 分析,这些图没有这种情况,但有时您需要在非功能化图上进行分析。将 fake tensor 放入 meta[‘val’] 可能是个错误
关于张量子类¶
Fake tensor 同时使用子类和模式张量子类模式,其中 FakeTensor.__torch_dispatch__ 启用与 fake tensor 关联的 FakeTensorMode,然后重新分派(依赖 FakeTensorMode 来完成繁重的工作)。如果 fake tensor 操作获得它无法识别的子类参数,它将返回 NotImplemented,让另一个子类有机会首先运行(希望反糖化为普通张量操作),然后再重试。这可能会导致无限循环。
每个算子是如何实现的?¶
不幸的是,任何给定的运算符都可能在相当复杂的多个位置实现。需要了解的一些重要情况
如果元素数量非常少,张量子类支持有限的常量传播(这有助于处理我们立即在这些张量上调用 item() 的某些情况)。
由于性能原因,我们为某些运算符提供了一些快速路径实现,这些实现完全在 fake tensor 中完成。
如果您使用 @custom_op 生成自定义张量,这些张量将直接向 fake tensor 注册 impl_abstract。
Fake tensor 本身对设备转换操作有一些硬编码的特殊情况。
如果没有元实现或任何分解,我们将生成真实的零填充张量并尝试直接运行运算符以找出结果。如果运算符尝试使用数据进行索引,这可能会导致段错误,因此我们默认情况下不会为自定义操作启用此功能。
转换器是如何工作的?¶
由于 fake tensor 用于对张量的确切属性非常敏感的情况,因此 fake tensor 非常仔细地进行转换,保留叶属性、requires_grad 属性、别名以及其他各种属性。大部分繁重的工作都在 MetaConverter 中完成。
性能特点¶
您可能认为 fake tensor 很快,因为它们不进行任何张量计算。但是在小张量大小下,我们实际上完全受开销限制,而且,嗯,fake tensor 是用 Python 编写的,我们经常需要做大量工作才能完成单个张量操作(因为它们被实现为分解)。因此,fake tensor 在实践中实际上相当慢,尤其是在涉及符号形状时。我们目前在 fake tensor 中有两个重要的快速路径,它们在实践中产生了很大的影响
逐点操作不会通过 PrimTorch 分解,相反,我们手动编码了它们的传播规则。
如果可能,我们应该这样做。
Fake tensor 的 Fake tensor?¶
有人有兴趣将 fake tensor 作为用户输入发送到 PT2 堆栈,这意味着我们需要能够创建一个 fake tensor 的 fake tensor。目前这还不太支持,但也许做起来不会太困难。
与动态形状的交互¶
每个 FakeTensorMode 都包含一个 ShapeEnv,它跟踪所有符号形状信息。它们的生命周期通常是绑定的:它们一起生,一起死。
由于 FakeTensorMode 具有 ShapeEnv(但 meta 实现没有),因此依赖于数据并且需要分配未备份 SymInt 的 meta 函数存在于 fake tensor 中。Fake tensor 还负责备忘录化未备份的 SymInt,以便例如,如果您在同一个 fake tensor 上调用 nonzero() 两次,您会得到相同的符号大小。