快捷方式

动态形状

代码: symbolic_shapes.py

另请参阅: 动态形状手册

动机

深度学习编译器通常仅适用于静态形状,也就是说,它们生成的编译程序仅适用于输入形状的单个特定配置,并且如果任何输入形状发生变化,则必须重新编译。这种假设对于当今大多数常用的深度学习模型来说都非常有效,但在某些情况下,它是不够的。

  • 某些维度,例如批大小或序列长度,可能会发生变化。例如,执行自适应批处理的推理服务将根据其批处理窗口内接收到的请求数量,以不同的批大小执行推理请求。我们可能还希望考虑仅将可变大小的序列填充到批处理中的最大序列长度,这在批处理之间可能会有所不同。

  • 某些模型表现出数据相关的输出形状,也就是说,它们的输出和中间结果的大小可能取决于实际的输入数据,这些数据在运行之间可能会有所不同。例如,检测模型可能会首先生成数量可变的潜在边界框,然后运行更昂贵的图像识别模型来识别目标是否在边界框内。边界框的数量是数据相关的。

  • 处理稀疏表示(例如稀疏张量、锯齿状张量和图神经网络)时,数据相关的形状的一个特别重要的案例出现了。在所有这些情况下,要处理的数据量取决于问题的稀疏结构,这通常会以数据相关的方式发生变化。

在支持动态形状时,我们选择不支持动态秩程序,例如输入张量在维度上发生变化的程序,因为这种模式在实际的深度学习程序中很少出现,并且它避免了需要对形状的符号列表进行归纳推理。

简化后的公共 API

PyTorch 2.1 中的默认动态行为是

  • PT2 默认假设所有内容都是静态的。

  • 如果我们因为大小发生变化而重新编译,我们将尝试将该大小重新编译为动态(发生变化的大小将来可能会发生变化)。这种泛化可能会失败(例如,因为用户代码对相关大小进行了条件分支,或者 PT2 中缺少动态形状支持)。如果您试图了解为什么 PT2 对某些代码进行了过度专门化,请使用 TORCH_LOGS=dynamic 运行并查找“eval”条目,这些条目会说明何时添加了保护以及原因。

  • 如果您提前知道某些内容将是动态的,您可以使用 torch._dynamo.mark_dynamic(tensor, dim) 跳过第一次重新编译。如果您提前知道此维度可以取的 minmax 值,您可以指定 torch._dynamo.mark_dynamic(tensor, dim, min=min, max=max)

  • 如果您说 torch.compile(dynamic=False),我们将关闭重新编译时的自动动态形状,并始终为每个不同的尺寸重新编译。相反,如果您说 torch.compile(dynamic=True),我们将尝试使所有内容尽可能动态。这主要对小型运算符有用;如果您在大型模型上尝试它,它将 (1) 可能导致 PT2 崩溃,以及 (2) 无缘无故地运行缓慢。

保护模型

在考虑如何向 TorchDynamo 和 TorchInductor 添加对动态形状的支持时,我们做了一个重大的设计决策:为了重用以 Python/C++ 编写的针对 PyTorch API 的分解和其他预先存在的代码,我们必须能够跟踪动态形状。与可能捕获条件两个分支的完全符号系统不同,我们始终选择一个分支,并在假设我们将来在该分支中做出相同选择的条件下专门化我们的跟踪。为此,我们为每个符号大小维护一个“提示”,说明其在编译时的具体值是什么(因为 TorchDynamo 是一个即时编译器,它始终知道实际的输入大小是什么)。当我们对张量执行条件时,我们只需查阅提示以找出要采用哪个分支。

这极大地简化了我们生成的符号形状公式,但意味着我们有一个更复杂的系统来管理保护。例如,考虑以下程序

def f(x, y):
    z = torch.cat([x, y])
    if z.size(0) > 2:
        return z.mul(2)
    else:
        return z.add(2)

我们将使用 TorchInductor 编译的最终 IR 将是 torch.cat([x, y]).add(2)torch.cat([x, y]).mul(2)(条件已展开),但要确定我们所处的哪个分支,我们需要知道中间结果 z 的大小,即 z.size(0),因为 TorchDynamo 必须预先知道编译的跟踪是否有效(我们不支持像某些 JIT 编译器那样的回退),我们必须能够将 z.size(0) 作为输入的表达式来减少,即 x.size(0) + y.size(0)。这是通过为 PyTorch 中的所有运算符编写元函数来完成的,这些函数可以在不实际对节点执行计算的情况下将大小信息传播到张量的输出。

总体架构

符号形状工作流程

  1. 当我们开始在 Dynamo 中编译一个帧时,我们分配一个 ShapeEnv(附加到 FakeTensorMode),它跟踪符号形状状态。

  2. 我们在输入时为张量分配符号大小(静态或动态是一个策略决策,有一些旋钮可以调整)。

  3. 我们将符号大小通过算子传播,同时维护 (1) FX IR 以便我们可以忠实地导出符号计算,以及 (2) 表示大小变量的 Sympy 表达式,以便我们可以对其进行推理。

  4. 当我们在 Dynamo 追踪或 Inductor 优化中根据符号大小进行条件判断时,我们会根据条件添加保护语句。这些保护语句可以从 Python 和 C++ 中生成。

  5. 这些保护语句可以对符号变量进行进一步的简化。例如,如果你断言 s0 == 4,我们现在就可以用 4 替换所有 s0 的出现。

  6. 当我们完成追踪和优化后,我们会将所有这些保护语句安装到编译后的代码中;只有当所有保护语句都计算为真时,编译后的代码才能重复使用。

重要文件

  • C++ SymInt API:c10/core/SymInt.hSymFloat.hSymBool.h

  • Python SymInt API:torch/__init__.py(查找 SymInt/SymFloat/SymBool

  • C++ 管道:c10/core/SymNodeImpl.htorch/csrc/utils/python_symnode.htorch/csrc/jit/python/init.cpp

  • Python 基础设施:torch/fx/experimental/symbolic_shapes.py

  • 其他重要文件:torch/_subclasses/fake_tensor.pytorch/_meta_registrations.py、decomps、PrimTorch refs

简化内部 API

理解 Python 类层次结构

  • SymInt/SymFloat/SymBool:这些是用户可见的类,模拟它们的 int/float/bool 对应类。如果你添加两个 SymInt,我们会给你一个新的 SymInt,它会象征性地跟踪整数加法操作的发生。

  • SymNode:这是内部结构(可以通过例如 symint.node 访问),它保存实际的符号跟踪信息。SymNode 是类型擦除的;这使得表示混合类型操作更加方便。请注意,从技术上讲,你不必从 SymInt 调用 Python SymNode;例如,XLA 的 C++ SymNodeImpl 将代替 SymNode。

  • ShapeEnv:每个编译上下文状态,跟踪我们迄今为止积累的所有自由符号和保护语句。每个 SymNode 都记录其 ShapeEnv(但反之则不然;只有当 SymNode 参与保护语句时才会使用)。

C++ 非常相似

  • c10::SymInt/SymFloat/SymBool:模拟 int/float/bool 的用户可见类。

  • c10::SymNode/SymNodeImpl:类似于 SymNode

  • C++ 中没有 ShapeEnv;为了方便调试,整个符号推理机制都在 Python 中。

当你编写可以用 make_fx 追踪的代码时,它必须能够处理流经它的 SymInt/SymFloat/SymBool。动态形状手册 提供了一些关于如何执行此操作的指导。

DimDynamic 策略

符号推理

  • 值范围

  • Sympy 使用说明

  • 约束

  • DimDynamic/约束

无后备 SymInt

为了解决控制流问题,我们检查符号整数的提示(即实际值),以确定要进入哪个分支。但是,在某些情况下,我们可能没有提示:当大小变量从数据相关的操作(如 .nonzero().item())出现时,就会出现所谓的无后备符号整数。在这些符号整数上执行控制流是非法的,因此我们必须在这些操作上进行图中断。

如果简单地实现,这限制性太强了:如果你尝试对无后备符号整数执行任何操作,大多数 PyTorch 程序都会立即失败。以下是如何使其真正起作用的最重要的增强功能

  • 在张量创建时,PyTorch 会预先计算大量关于张量的数据;例如,如果你使用 empty_strided 创建张量,我们会急切地对步长进行排序,并确定张量是否是非重叠且密集的。排序会产生很多保护语句。但是,更常见的是使用更高级别的 API(如 empty)直接生成张量,这保证会生成非重叠且密集的张量。我们修改了 PyTorch 以避免不必要地重新计算这些属性。

  • 即使需要非平凡的计算,有时也根本不会查询某个属性。使这些预先计算的属性变为惰性,使我们能够避免在无后备符号整数上进行保护,除非它实际上是必需的。

  • 整数张量中的数据通常不知道是非负的。但是,我们提供了一个 API constrain_range,用户可以通过它指定大小受已知限制的上限和下限。

在 PT2 的未来版本(PT2.1 之后),我们将扩展我们的推理系统,以根据使用情况推断无后备符号整数是否类似于大小。例如,如果你将 .item() 调用的结果传递给工厂函数(如 torch.empty),我们将自动推断结果是一个大小(因为如果不是,它将失败)。此假设将在运行时得到验证,如果未满足,则会引发错误。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适合初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源