动态形状¶
另请参阅:动态形状手册
动机¶
深度学习编译器通常仅适用于静态形状,也就是说,它们生成的编译程序仅适用于输入形状的单个特定配置,如果任何输入形状发生变化,则必须重新编译。这种假设非常适合当今运行的大多数深度学习模型,但在某些情况下是不够的。
某些维度,例如批次大小或序列长度,可能会发生变化。例如,执行自适应批次处理的推理服务将根据其批次处理窗口内接收到的请求数量,以不同的批次大小执行推理请求。我们可能还希望考虑仅将可变大小的序列填充到批次内的最大序列长度,这在批次之间可能会有所不同。
某些模型表现出数据相关的输出形状,也就是说,它们的输出和中间结果的大小可能取决于实际输入数据,而输入数据在每次运行之间可能会有所不同。例如,检测模型可能首先生成可变数量的潜在边界框,然后再运行更昂贵的图像识别模型以识别主体是否在边界框中。边界框的数量取决于数据。
处理稀疏表示(例如稀疏张量、不规则张量和图神经网络)时,数据相关的形状特别重要。在所有这些情况下,要处理的数据量取决于问题的稀疏结构,这通常会以数据相关的方式发生变化。
在支持动态形状时,我们选择不支持动态秩程序,例如输入张量在维度上发生变化的程序,因为这种模式很少出现在现实世界的深度学习程序中,并且它避免了对形状的符号列表进行归纳推理的必要性。
简化的公共 API¶
PyTorch 2.1 中的默认动态行为是
PT2 默认情况下假设所有内容都是静态的。
如果我们由于大小变化而重新编译,我们将尝试将该大小重新编译为动态的(已更改的大小可能会在将来更改)。这种泛化可能会失败(例如,因为用户代码对相关大小执行条件分支或 PT2 中缺少动态形状支持)。如果您试图了解为什么 PT2 过度专门化了一些代码,请使用
TORCH_LOGS=dynamic
运行并查找说何时添加保护及其原因的“eval”条目。如果您提前知道某个东西是动态的,您可以使用
torch._dynamo.mark_dynamic(tensor, dim)
跳过第一次重新编译。如果您提前知道该维度可以取的min
和max
值,您可以指定torch._dynamo.mark_dynamic(tensor, dim, min=min, max=max)
如果您说
torch.compile(dynamic=False)
,我们将关闭重新编译时的自动动态形状,并始终为每个不同的尺寸重新编译。相反,如果您说torch.compile(dynamic=True)
,我们将尝试使所有内容尽可能动态。这主要对小型操作符有用;如果您在大型模型上尝试它,它将 (1) 可能使 PT2 崩溃,以及 (2) 毫无理由地运行缓慢。
保护模型¶
在考虑如何为 TorchDynamo 和 TorchInductor 添加对动态形状的支持时,我们做出了一个主要的设计决策:为了重用以 Python/C++ 编写的针对 PyTorch API 的分解和其他预先存在的代码,我们必须能够跟踪动态形状。与可能捕获条件两个分支的完全符号系统不同,我们始终选择一个分支并在假设将来我们会在该分支上做出相同选择的条件下专门化我们的跟踪。为此,我们为每个符号大小维护一个“提示”,指示其在编译时的具体值(因为 TorchDynamo 是一个即时编译器,它始终知道实际输入大小)。当我们在张量上执行条件时,我们只需查询提示以找出要采用哪个分支。
这极大地简化了我们生成的符号形状公式,但意味着我们拥有一个更复杂的系统来管理保护。例如,考虑以下程序
def f(x, y):
z = torch.cat([x, y])
if z.size(0) > 2:
return z.mul(2)
else:
return z.add(2)
我们将使用 TorchInductor 编译的最终 IR 将是 torch.cat([x, y]).add(2)
或 torch.cat([x, y]).mul(2)
(条件已扁平化),但为了确定我们处于哪个分支,我们需要知道 z
的大小,这是一个中间结果。因为 TorchDynamo 必须提前知道编译的跟踪是否有效(我们不支持类似某些 JIT 编译器的故障转移),所以我们必须能够将 z.size(0)
作为输入的表达式来减少,x.size(0) + y.size(0)
。这是通过为 PyTorch 中的所有操作符编写元函数来完成的,这些函数可以将大小信息传播到张量的输出,而无需实际对节点执行计算。
总体架构¶
符号形状工作流程
当我们在 Dynamo 中开始编译一个帧时,我们会分配一个 ShapeEnv(附加到 FakeTensorMode),它用于跟踪符号形状状态。
我们在进入时为张量分配符号大小(什么是静态或动态是一个策略决定,有一些旋钮)。
我们通过算子传播符号大小,同时维护 (1) FX IR,以便我们可以忠实地导出符号计算,以及 (2) 表示大小变量的 Sympy 表达式,以便我们可以对它们进行推理。
当我们在 Dynamo 跟踪或 Inductor 优化中根据符号大小进行条件判断时,我们会根据条件添加保护措施。这些保护措施可以从 Python 和 C++ 中推断出来。
这些保护措施可以对符号变量进行进一步简化。例如,如果你断言
s0 == 4
,我们现在可以将所有s0
的出现替换为4
。当我们完成跟踪和优化后,我们会将所有这些保护措施与编译后的代码一起安装;只有当所有保护措施都评估为真时,编译后的代码才能重复使用。
重要文件
C++ SymInt API:
c10/core/SymInt.h
,SymFloat.h
,SymBool.h
Python SymInt API:
torch/__init__.py
(查找SymInt/SymFloat/SymBool
)C++ 管道:
c10/core/SymNodeImpl.h
,torch/csrc/utils/python_symnode.h
,torch/csrc/jit/python/init.cpp
Python 基础设施:
torch/fx/experimental/symbolic_shapes.py
其他重要文件:
torch/_subclasses/fake_tensor.py
,torch/_meta_registrations.py
, 分解,PrimTorch 引用
简化内部 API¶
了解 Python 类层次结构
SymInt/SymFloat/SymBool:这些是用户可见的类,模拟它们的 int/float/bool 对应类。如果你添加两个 SymInt,我们会给你一个新的 SymInt,它会象征性地跟踪整数加法是否已经发生。
SymNode:这是内部结构(可以通过例如
symint.node
访问),它保存实际的符号跟踪信息。SymNode 是类型擦除的;这使得用它来表示混合类型操作更加方便。请注意,从技术上讲,你无需从 SymInt 调用 Python SymNode;例如,XLA 的 C++SymNodeImpl
将代替 SymNode。ShapeEnv:每个编译上下文状态,用于跟踪到目前为止积累的所有自由符号和保护措施。每个 SymNode 记录其 ShapeEnv(反之则不然;只有当 SymNode 参与保护措施时,它们才会被使用)。
C++ 相当类似
c10::SymInt/SymFloat/SymBool:模拟 int/float/bool 的用户可见类。
c10::SymNode/SymNodeImpl:类似于 SymNode
C++ 中没有 ShapeEnv;为了便于调试,整个符号推理机制都在 Python 中。
当你编写可以用 make_fx
跟踪的代码时,它必须能够处理 SymInt/SymFloat/SymBool 流经它。动态形状手册 提供了一些关于如何做到这一点的指导。
无支撑 SymInt¶
为了解决控制流问题,我们会检查符号整数的提示(也称为实际值),以确定要执行哪个分支。但是,在某些情况下,我们可能没有提示:当大小变量从像 .nonzero()
或 .item()
这样的数据相关操作中出现时,就会出现所谓的无支撑符号整数。对这些符号整数执行控制流是非法的,因此我们必须在这些操作上进行图形中断。
如果天真地实现,这将过于严格:如果你尝试使用无支撑符号整数执行任何操作,大多数 PyTorch 程序会立即失败。以下是使其真正起作用的最重要的增强功能
在张量创建时,PyTorch 会预先计算关于张量的大量数据;例如,如果你使用
empty_strided
创建张量,我们会急切地对步幅进行排序并确定张量是否是非重叠且密集的。排序会产生大量保护措施。但是,更常见的是使用像empty
这样的更高级别 API 直接生成张量,这保证会生成非重叠且密集的张量。我们修改了 PyTorch,避免不必要地重新计算这些属性。即使需要进行非平凡的计算,有时也根本不会查询某个属性。使这些预先计算的属性延迟化,使我们能够避免在无支撑符号整数上进行保护,除非它实际上需要。
整数张量中的数据通常不知道是非负的。但是,我们提供了
constrain_range
API,用户可以通过它指定大小的上限和下限。
在 PT2 的未来版本(超出 PT2.1)中,我们将扩展我们的推理系统,根据用法推断出无支撑符号整数是类似大小的。例如,如果你将 .item()
调用的结果传递给像 torch.empty
这样的工厂函数,我们会自动推断结果是一个大小(因为如果不是,它就会失败)。这个假设会在运行时得到验证,如果它没有得到满足,就会引发错误。