Fake tensor¶
动机¶
在进行 Dynamo 符号求值和编译器遍时,我们通常希望能够运行张量操作,以了解输出的尺寸/数据类型/设备是什么,而无需实际运行这些操作(或破坏现有张量),因为这样做会更慢(如果计算量很大)且占用大量内存(如果编译器在编译程序时需要使用 GPU 内存,那就不好了)。一个 fake tensor 在所有方面都像一个真实的张量,不同之处在于它实际上没有任何数据。例如,当我们进行 Dynamo 追踪时,我们需要追踪用户的 Tensor 代码,并回答关于中间结果的问题(例如,如果用户对中间张量进行条件判断)。没有 fake tensor,我们将无法获得这些查询的准确信息。
同样,假设您想存储张量的元数据,例如在 FX IR 节点 (meta['val']) 上。您可以直接在节点上存储一个 fake tensor,它将为您提供张量所需的所有元数据,包括一些您可能无法处理的细微之处(例如,别名关系)。
整体架构¶
所有 fake tensor 都与 FakeTensorMode 相关联。由于 fake tensor 的主要用例是对真实张量进行分析,因此一般的工作流程是:您有一堆真实张量,分配一个 FakeTensorMode,然后使用 from_real_tensor 将所有真实张量转换为 fake tensor,接着对这些 fake tensor 进行操作。特别地,FakeTensorMode 维护一个持久的备忘录表,将张量(和存储)映射到相同的存储。如果您多次 fakeify 同一个张量,您将得到同一个 fake tensor;如果您 fakeify 两个相互别名的张量,您将得到两个别名同一 fake storage 的 fake tensor。Fake tensor 是张量子类,因此如果您对它们进行操作,您会自动获得一个 fake tensor,但通常您会希望在 FakeTensorMode 激活的情况下对 fake tensor 进行操作(例如,如果您正在运行 FX pass);张量操作会自动开启 fake tensor 模式并重试。
Fake tensor 被表示为 meta tensor 的一个 __torch_dispatch__ 张量子类。这意味着在底层,fake tensor 是 meta device 张量;然后它们利用额外的可扩展性钩子,特别是 dispatch_device,来伪称张量的实际设备。这是 fake tensor 早期阶段更容易出错的部分之一:有时,fake tensor 太擅长伪装成 CPU/CUDA 或其他设备,结果会导致 CPU kernel 被调用,而 fake tensor 试图解引用数据指针,这显然行不通。如果在 fake tensor 代码中发生段错误,这是您首先应该检查的事情:C++ 回溯是在 CPU kernel 中(意外!)还是在 meta kernel 中(预期!)。meta kernel 类似于真实 kernel,但它只负责分配输出,不进行任何数据计算。
张量子类必须定义如何实现各种操作。以下是一般的 fake tensor 实现方法
在输入 fake tensor 上运行 meta kernel,并将它们重新解释为 meta tensor。这是通过一个神奇的上下文管理器 `in_kernel_invocation_manager` 完成的,该管理器指示 PyTorch 将 fake tensor 视为其底层 meta tensor,而不是将 fake tensor “展开”为 meta tensor(因为 fake tensor 就是 meta tensor)。fake tensor 以这种方式表示,以避免必须同步两组元数据(meta tensor 的元数据和 fake tensor 的元数据);“is a” 关系确保只有一份规范的元数据副本。
如果您是一个工厂函数,您将转而调用底层工厂函数,设备设置为 `device='meta'`。
将生成的 meta tensor 转换为 fake tensor,计算张量的输出设备应该是什么(这通常很简单,但有时并非如此,例如 CPU 标量提升或设备转换操作)。
API:重要部分¶
非 PT2 用法(更多示例请查看 test/test_fake_tensor.py)
# Create a fake mode
from torch._subclasses.fake_tensor import FakeTensorMode
fake_mode = FakeTensorMode()
converter = fake_mode.fake_tensor_converter
# Fakeify some real tensors
fake_x = converter.from_real_tensor(fake_mode, x)
with fake_mode:
# Do some operations on the fake tensors
fake_y = fake_x * 2
# Factory operations automatically get fakeified in the context manager
fake_z = torch.empty(20)
问:为什么输入是真实张量?
答:在 PT2 环境中,这是因为您通常是即时编译,因此对于您正在编译的图的所有输入,您已经有了“真实”输入,因为您在执行程序时进行编译。
PT2 pre-AOTAutograd 用法(这不常见,您可能不希望这样做)
# Fake mode is not enabled!
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode(args)
# if fake_mode isn't None
converter = fake_mode.fake_tensor_converter
fake_args = [converter.from_real_tensor(fake_mode, arg) for arg in args]
with fake_mode:
... # do stuff with the fake args, if needed ...
`detect_fake_mode` 会搜索多个位置,尝试找到与生命周期相关的“那个”fake tensor 模式。通常它会从追踪上下文中获取。
PT2 post-AOTAutograd 用法
# Fake mode is enabled! example_inputs is typically fake already
# TODO: we probably want to change this
# Still do this to access fake mode
fake_mode = detect_fake_mode(example_inputs)
# But in general you don't have to turn it on
其他有用内容
from torch._subclasses.fake_tensor import unset_fake_temporarily
with unset_fake_temporarily():
... # fake mode is disabled here, you can do real tensor compute
您何时可能想要禁用 fake tensor 模式?通常您不希望这样做。一个我们发现有用的特殊情况是在 fake tensor 上实现常量传播:在这种情况下,即使我们处于 fake tensor 模式,也需要进行一些实际的张量计算。
import FakeTensorProp from torch.fx.passes.fake_tensor_prop
gm: GraphModule
real_inputs: List[Tensor]
FakeTensorProp(gm).propagate(*real_inputs)
# This will populate meta['val'] on all the FX nodes with a fake tensor
# or if you have a preexisting fake mode, you should use it
FakeTensorProp(gm, mode=fake_mode).propagate(*real_inputs)
# There is also propagate_dont_convert_inputs if your inputs are already fake
fake_inputs: List[FakeTensor]
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(*fake_inputs)
详细信息¶
是否自动转换?
with FakeTensorMode():
real_tensor.t_()
最初,如果您尝试在 FakeTensorMode 区域内对真实张量进行计算,FakeTensorMode 不会自动 fakeify 这些真实张量。这样做是为了防止以下陷阱
这段代码应该做什么?如果我们真的修改了真实张量的元数据,那会很令人惊讶。但与此同时,也没有任何明显的机会去创建一个 FakeTensor。因此,我们保守地决定让它抛出一个错误:“在 FakeTensorMode 中使用非 Fake Tensor 输入调用算子尚不受支持。请先将所有 Tensor 转换为 FakeTensor。”
这个错误在实践中相当烦人。例如,假设您有一个真实的 nn.Module,并且想让 fake tensor 通过它。您需要以某种方式 fakeify 这个 nn.Module。这促使了 FakeCopyMode 的出现。
最终,我们放弃了限制,并添加了自动 fakeification 功能。然而,在许多 FakeTensorMode 的使用场景中,此功能默认尚未启用。
fake tensor 上的元数据修改
如果您有一个 fake tensor 并对其调用 `t_()`,则该 fake tensor 上的元数据会改变。这表面上看来合理,但有时您也希望将 fake tensor 作为元数据存储在 FX 节点上;修改 fake tensor 是不好的,因为它会使旧的元数据失效!
事实上,这里存在一个根本性的矛盾,即 fake tensor 维护着极其准确的张量元数据,包括对象标识。如果 FX 图中的对象元数据随时间变化,实际上没有任何方法可以表示这种随时间的变化。大多数时候,我们的重要 FX 分析是在 函数化图 上进行的,这些图没有这个问题,但偶尔您需要在 非函数化图 上进行分析。也许将 fake tensor 放在 meta['val'] 中是一个错误。
关于张量子类¶
Fake tensor 同时使用了子类和 mode 张量子类模式,其中 FakeTensor.__torch_dispatch__ 启用与 fake tensor 关联的 FakeTensorMode,然后重新调度(依赖 FakeTensorMode 完成繁重工作)。如果 fake tensor 操作收到一个它不认识的子类参数,它将返回 NotImplemented,让其他子类有机会先运行(希望能去糖化为普通张量操作),然后再尝试。这可能导致无限循环。
每个算子是如何实现的?¶
不幸的是,任何给定算子的实现位置都相当复杂。一些需要了解的重要情况包括:
如果元素数量非常小,张量子类支持有限的常量传播(这有助于处理一些我们立即对此类张量调用 `item()` 的情况)。
出于性能考虑,我们对某些算子有一些快速路径实现,这些实现完全在 fake tensor 中完成。
如果您使用 `@custom_op` 生成自定义张量,这些将直接向 fake tensor 注册 `impl_abstract`。
Fake tensor 本身对设备转换操作有一些硬编码的特殊情况。
如果没有 meta 实现或任何分解,我们将生成真实的零填充张量,并尝试直接运行算子以确定结果。如果算子尝试使用数据进行索引,这可能导致段错误,因此我们默认不为自定义算子启用此功能。
转换器是如何工作的?¶
由于 fake tensor 用于对张量精确属性非常敏感的情况,因此 fake tensor 会非常仔细地进行转换,保留 leaf 属性、requires_grad 属性、别名关系以及许多其他属性。大部分繁重工作由 MetaConverter 完成。
性能特征¶
您可能会认为 fake tensor 速度很快,因为它不进行任何张量计算。但在张量尺寸较小时,我们实际上完全受开销限制,而且 fake tensor 是用 Python 实现的,我们通常需要做很多工作才能完成一个张量操作(因为它们被实现为分解)。因此,fake tensor 在实践中实际上相当慢,尤其是在涉及符号形状时。目前我们在 fake tensor 中有两个重要的快速路径,它们在实践中起着重要作用
逐点算子不经过 PrimTorch 分解,而是我们手动编码了它们的传播规则。
如果可能,我们应该这样做。
Fake tensor 的 fake tensor?¶
有人有兴趣将 fake tensor 作为用户输入发送到 PT2 堆栈中,这意味着我们需要能够创建 fake tensor 的 fake tensor。这目前尚不支持,但也许做起来不会太难。
与动态形状的交互¶