快捷方式

动态形状

代码: symbolic_shapes.py

另请参阅: 动态形状手册

动机

深度学习编译器通常只适用于静态形状,也就是说,它们生成的编译程序只针对输入形状的一种特定配置有效,如果任何输入形状发生变化,则必须重新编译。这个假设对于当今大多数常用的深度学习模型来说效果很好,但在某些情况下它是不够的

  • 某些维度,如批大小或序列长度,可能会变化。例如,执行自适应批处理的推理服务将根据其批处理窗口中接收到的请求数量,以变化的批大小执行推理请求。我们可能还会考虑将变长序列填充到批次内的最大序列长度,该长度可能因批次而异。

  • 一些模型表现出数据依赖的输出形状,也就是说,它们的输出和中间结果的大小可能取决于实际输入数据,而输入数据在不同运行中可能会变化。例如,检测模型可能首先生成可变数量的潜在边界框,然后运行更昂贵的图像识别模型来识别主体是否在边界框内。边界框的数量取决于数据。

  • 处理稀疏表示(如稀疏张量、不规则张量和图神经网络)时,会遇到一种特别重要的数据依赖形状情况。在所有这些情况下,需要处理的数据量取决于问题的稀疏结构,而稀疏结构通常会以数据依赖的方式变化。

在支持动态形状时,我们选择不支持动态秩(rank)的程序,例如输入张量维度变化的程序,因为这种模式在实际深度学习程序中很少出现,并且可以避免对符号形状列表进行归纳推理的需要。

公共 API 摘要

PyTorch 2.1 中的默认动态行为是

  • PT2 默认假定一切都是静态的

  • 如果我们因为大小改变而重新编译,我们将尝试将该大小重新编译为动态的(改变的大小在将来很可能会再次改变)。这种泛化可能会失败(例如,因为用户代码对相关大小进行了条件分支,或者 PT2 中缺少动态形状支持)。如果你试图理解为什么 PT2 对某些代码进行了过度特殊化,请使用 TORCH_LOGS=dynamic 运行,并查找指示何时添加守卫(guards)以及原因的“eval”条目。

  • 如果你提前知道某些东西将是动态的,你可以使用 torch._dynamo.mark_dynamic(tensor, dim) 跳过首次重新编译。如果你提前知道该维度可以接受的最小和最大值,你可以指定 torch._dynamo.mark_dynamic(tensor, dim, min=min, max=max)

  • 如果你指定 torch.compile(dynamic=False),我们将关闭重新编译时的自动动态形状,并始终为每个不同的尺寸重新编译。反之,如果你指定 torch.compile(dynamic=True),我们将尝试使一切尽可能地动态化。这对于小型算子(operators)非常有用;如果你在一个大型模型上尝试这样做,它很可能会(1)使 PT2 崩溃,并且(2)运行缓慢而没有任何好处。

守卫模型

在考虑如何向 TorchDynamo 和 TorchInductor 添加动态形状支持时,我们做了一个重大的设计决策:为了重用用 Python/C++ 编写并针对 PyTorch API 的分解(decompositions)和其他现有代码,我们必须能够追踪(trace)动态形状。与完全符号化系统(可能会捕获条件分支的两条路径)不同,我们总是选择一个分支并专门化我们的追踪,假设只有当未来做出相同分支选择时,我们才会使用这个追踪。为此,我们为每个符号大小维护一个“提示”,说明其在编译时(由于 TorchDynamo 是即时编译器,它总是知道实际输入大小)的具体值。当我们在张量上执行条件判断时,我们只需查询该提示即可知道应选择哪个分支。

这极大地简化了我们生成的符号形状公式,但意味着我们有一个更复杂的守卫管理系统。例如,考虑以下程序

def f(x, y):
    z = torch.cat([x, y])
    if z.size(0) > 2:
        return z.mul(2)
    else:
        return z.add(2)

我们将使用 TorchInductor 编译的最终 IR 将是 torch.cat([x, y]).add(2)torch.cat([x, y]).mul(2)(条件被消除),但要确定我们处于哪个分支,我们需要知道中间结果 z 的大小。因为 TorchDynamo 必须提前知道编译后的追踪是否有效(我们不支持像某些 JIT 编译器那样的退出),我们必须能够将 z.size(0) 作为一个表达式,用输入 x.size(0) + y.size(0) 来表示。这是通过为 PyTorch 中的所有算子编写元函数(meta functions)来实现的,这些元函数可以在不实际对节点进行计算的情况下将大小信息传播到张量的输出。

整体架构

符号形状工作流程

  1. 当我们在 Dynamo 中开始编译一个帧时,我们分配一个 ShapeEnv(附属于 FakeTensorMode),用于跟踪符号形状状态。

  2. 我们在入口处为张量分配符号大小(静态或动态是一个策略决定,有一些控制选项)。

  3. 我们通过算子传播符号大小,同时维护 (1) FX IR,以便忠实地导出符号计算,以及 (2) 代表尺寸变量的 Sympy 表达式,以便我们能够对其进行推理。

  4. 当我们在 Dynamo 追踪或 Inductor 优化中对符号大小设置条件时,我们会根据条件添加守卫。这些守卫可以由 Python 和 C++ 代码引起。

  5. 这些守卫可以进一步简化符号变量。例如,如果你断言 s0 == 4,我们现在可以将所有出现的 s0 替换为 4

  6. 完成追踪和优化后,我们将所有这些守卫与编译后的代码一起安装;只有当所有守卫都评估为真时,编译后的代码才能被重用。

重要文件

  • C++ SymInt API: c10/core/SymInt.h, SymFloat.h, SymBool.h

  • Python SymInt API: torch/__init__.py (查找 SymInt/SymFloat/SymBool)

  • C++ 底层实现: c10/core/SymNodeImpl.h, torch/csrc/utils/python_symnode.h, torch/csrc/jit/python/init.cpp

  • Python 基础设施: torch/fx/experimental/symbolic_shapes.py

  • 其他重要文件: torch/_subclasses/fake_tensor.py, torch/_meta_registrations.py, decomps, PrimTorch refs

内部 API 摘要

理解 Python 类层次结构

  • SymInt/SymFloat/SymBool: 这些是用户可见的类,模拟其对应的 int/float/bool 类型。如果你将两个 SymInt 相加,我们将给你一个新的 SymInt,以符号方式跟踪发生了整数加法。

  • SymNode: 这是内部结构(例如通过 symint.node 访问),用于保存实际的符号跟踪信息。SymNode 是类型擦除的;这使得表示混合类型操作更加方便。请注意,从技术上讲,你不必从 SymInt 调用 Python SymNode;例如,XLA 的 C++ SymNodeImpl 将取代 SymNode。

  • ShapeEnv: 每个编译上下文的状态,用于跟踪我们迄今为止积累的所有自由符号和守卫。每个 SymNode 都记录其 ShapeEnv(反之则不然;SymNode 只有在参与守卫时才会被使用)。

C++ 的情况也类似

  • c10::SymInt/SymFloat/SymBool: 用户可见的类,模拟 int/float/bool 类型。

  • c10::SymNode/SymNodeImpl: 类似于 SymNode

  • C++ 中没有 ShapeEnv;为了便于调试,整个符号推理机制都在 Python 中。

当你编写可以使用 make_fx 进行追踪的代码时,它必须能够处理流经其中的 SymInt/SymFloat/SymBool。动态形状手册提供了一些指导,说明如何做到这一点。

DimDynamic 策略

符号推理

  • 取值范围

  • Sympy 使用注意事项

  • 约束

  • DimDynamic/Constraint

无支持的 SymInt

为了解析控制流,我们检查符号整数的提示(即实际值)来确定选择哪个分支。然而,在某些情况下,我们可能没有提示:所谓的无支持符号整数出现在大小变量由数据依赖的操作(如 .nonzero().item())产生时。对这些符号整数执行控制流是非法的,因此我们必须在这些操作上进行图中断(graph break)。

如果天真地实现,这将过于严格:如果你尝试对无支持的符号整数进行任何操作,大多数 PyTorch 程序会立即失败。以下是使这实际工作起来的最重要的增强功能

  • 在创建张量时,PyTorch 会预先计算关于张量的很多数据;例如,如果你使用 empty_strided 创建张量,我们会急切地对步长(strides)进行排序,并确定张量是否是非重叠且稠密的。排序会产生很多守卫。然而,更常见的是直接使用像 empty 这样的高级 API 创建张量,该 API 保证生成非重叠且稠密的张量。我们修改了 PyTorch,以避免不必要地重新计算这些属性。

  • 即使需要进行非平凡计算,有时某个属性也从未被实际查询过。将这些预计算属性设为惰性(lazy)可以让我们避免对无支持的符号整数添加守卫,除非确实需要。

  • 整数张量中的数据通常不保证是非负的。然而,我们提供了一个 API constrain_range,用户可以通过它指定某个大小的上下界由已知限制确定。

在 PT2 的未来版本(PT2.1 之后),我们将扩展我们的推理系统,根据使用情况推断无支持符号整数是类似尺寸的。例如,如果你将 .item() 调用的结果传递给像 torch.empty 这样的工厂函数,我们将自动推断结果是一个尺寸(因为如果不是,它就会失败)。这个假设将在运行时得到验证,如果不满足则会引发错误。

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并解答疑问

查看资源