PyTorch/XLA 中的重新编译来源¶
我们首先从一些事实/约束开始:¶
XLA 中的图编译成本很高。
XLA 只处理静态形状。换句话说,即使是相同的 IR 图,当输入形状改变时,XLA 也会重新编译。
重新编译发生时会极大地损害 torch_xla 的性能,从普通 Python 用户的角度来看,这很难理解和调试。
重新编译发生时,我们常常说只需要支持动态形状,然后就可以放心地认为将来支持动态形状后,所有的重新编译都会神奇地消失。但这并非事实,XLA 现在已经有了相当不错的有限动态形状(bounded dynamic shapes)覆盖,但我们仍然会看到重新编译,而且这是预期中的。
本文档旨在详细解释几种常见的重新编译来源,以及我们需要做些什么来消除它们。本文将主要侧重于向没有任何背景的初学者解释问题。为了便于理解,这里提出的“解决方案”可能依赖于不切实际的假设。
#1. 来自输入数据集。¶
是的,输入数据集包含具有不同形状的示例(例如长度可变的句子或大小不同的图像)是很常见的。如果不进行归一化,对于每个新的输入形状都会导致重新编译。
TensorFlow 图模式用户更习惯于进行填充/分桶(padding/bucketization)(tf.pad
),将输入形状归一化到一两个桶中。但这对于 PyTorch eager 前端用户(lazy tensor 前端也旨在针对此类用户)来说,有点像反模式,因为不同的输入形状对 eager CPU/CUDA 后端来说无关紧要。
建议的权宜之计:好的,现在假设我们可以通过教用户进行填充/分桶来解决这个问题(实践中这很难 :P)。下一步是什么?
#2. 来自算子输出。¶
有些算子在语义上是数据依赖的,会产生动态形状输出:例如,torch.nonzero
返回输入张量中非零元素的索引。因此,即使你提供给这个算子的输入张量形状始终相同,它也可能产生不同形状的输出,从而导致重新编译。
2.1 有限动态形状可以解决将动态形状张量作为张量使用而不查询其实际维度的情况。¶
建议的权宜之计:现在假设 XLA 支持所有算子的有限动态形状,这就足够了吗?
有限动态形状意味着我们可以将张量填充到理论上的最大值,以增加内存使用量来换取更少的重新编译/更快的速度。
嗯,可以说是吧。让我们看看下面的例子
a = torch.tensor([1, 2, 0, 1, 3], device='xla')
b = torch.nonzero(a)
c = b * 2
d = c + 1
print(torch_xla._XLAC._get_xla_tensors_text([d]))
在上面的例子中,图中 b
下面的每个节点(即 c, d
以及所有依赖于它们的节点)都将具有动态形状,很明显 b
在维度 0 上具有动态形状,如下所示
%9 = (s64[<=5,1]{1,0}, s64[]) aten::nonzero(%8), num_outputs=2 # b
%10 = s64[5,1]{1,0} aten::mul(%9.0, %3) # c
%11 = s64[5,1]{1,0} aten::add(%10, %2), ROOT=0 # d
虽然没有直接在图中显示,但 c & d
确实也具有动态形状(换句话说,[5, 1] 只是填充后的形状,并且是被掩码的)。
print(torch_xla._XLAC._get_xla_tensor_dimension_size(d, 0)) # prints 4 instead of 5
你可以看到,在这种情况下,只要输入张量 a
的形状是 [5]
,我们就只需要编译一次图。有限动态形状支持起作用了!
2.2 如果在具有动态形状的张量上查询实际维度怎么办?¶
这实际上非常常用,因为并非所有 PyTorch 计算都以张量的形式完成。
例如,PyTorch 中的 tensor.size()
返回一个整数元组而不是一个 dtype=int 的张量。当 tensor
是一个动态形状张量时,这个操作基本上会强制 XLA 切断图并进行求值,以便我们可以返回正确的标量(否则它只会返回填充后的形状,这是错误的)。
更糟的是,许多 PyTorch 操作也接受标量输入。在你执行 s = tensor.size(0)
并将 s
用于其他算子后,它也变成了动态源。在这种情况下,我们可能知道如何填充它及其上限,但我们做不到,因为它甚至不是一个张量!
a = torch.tensor([1, 2, 0, 1, 3], device='xla')
b = torch.nonzero(a)
s = a.size(0) # evaluation happens! nit: we use size() for simplicity, the actual API is _get_xla_tensor_dimension_size.
c = torch.rand(s, device='xla') # c can be of any shape between [0, 5] which causes more recompilations!
d = c + 1
所以,如果没有 PyTorch 前端的帮助,这个问题实际上很难解决。我们需要什么?
简而言之,我们需要一个“张量世界”!
例如,
tensor.size()
应该返回一个张量,这样它就可以成为一个具有动态形状的张量并保留在图中,而无需提前求值。张量访问器,例如对于二维张量,
tensor[0][0]
现在返回一个值,但这也需要返回一个张量。隐含地,这意味着目前接受 int/float/double 作为输入的所有算子也需要一个张量重载。这是一个很大的要求,因为它很容易使我们的算子集急剧增加。
如果我们能让标量到张量的转换成本非常低,这样我们就只需要关心张量重载了,那会更容易。
在实践中,并非所有操作都接受来自之前计算的标量,因此我们一直在通过临时请求添加张量变体。
我认为这也是基于 tracing 方法的常见需求。
好的,现在我们假设 PyTorch 中的每个操作都有我们需要的张量版本,我们完成了吗?
#3. 来自控制流。¶
不!我们实际上只解决了没有数据依赖控制流的问题……
参见下面的例子
if x[0][0] == 3:
bla
else:
blabla
即使 x[0][0]
是一个张量,我们也需要执行/具体化其值以便 Python 解释器继续执行。而且多个控制流中不同的分支选择组合意味着我们也有很多图需要编译!
目前我们还没有办法解决这个问题。要解决它,我们需要将控制流从 Python 降低到图!不必深入考虑实现细节,我们可以通过两种方式做到这一点
要求用户明确使用控制流算子代替 Python 的 if/else/while/for 语句。这目前作为 torch_xla 中的自定义 API 得到支持,但在用户代码中并未广泛采用。(Python 用户习惯了 if/else/for,除非性能有巨大提升,否则很难让他们转向一个更丑陋的 API)。
自动解析 Python 源代码以获取控制流语句。这类似于 TorchScript,并以某种方式将 torchscripted 图正确地合并到惰性跟踪的图中(包括形状信息等)。我确实还没有完全想清楚如何实现这一步骤 :P
但这两种解决方案都需要相当大的努力,无论是用户端还是框架端。这就是为什么考虑到我们目前的可用资源,我们目前只是接受早期求值和多次编译的代价作为短期解决方案。
好的,现在我们假设控制流也能自动地降低到图中,我们就大功告成了吗?
是的!现在你有了以张量操作图表示的整个计算过程,包括控制流,这样编译器就可以消费并运用他们的聪明技巧了!但老实说,此时你的程序已经不再那么“PyTorch 式”了。
结论:¶
重新编译实际上有多种来源,而有限动态形状支持并不能解决所有问题。本文档中提出的权宜之计有时确实不切实际,而且可能有更好的方法来妥善解决每个来源的问题,只是我完全不了解。但我希望,在我们在这篇文档中不断朝着理想的惰性张量栈(lazy tensor stack)前进的过程中,现在你能更容易理解我们面前还存在哪些障碍。
附录:¶
NNC 使用符号形状(symbolic shapes),这有帮助吗?
有帮助,但只是部分。通过拥有符号形状,你的编译优化不再需要具体的形状值。换句话说,你生成的内核比 XLA 的静态形状内核更通用。
具体解决了哪个问题?
它有助于解决 #1 和 #2.1 等情况。
shape [3, 5] -> add -> transpose -> ... -> mul
shape [6, 2] -> add -> transpose -> ... -> mul
# with symbolic shape
shape [x, y] -> add -> transpose -> ... -> mul
有了符号形状,你生成的内核不会像 XLA 处理静态形状那样重新编译。
XLA 以另一种方式解决这个问题,通过使用填充/分桶(针对 #1)和有限动态形状(针对 #2.1)。
Brian Hirsh (@bdhirsh) 在评论中提出了一些非常好的问题,移到这里以便更显眼
对于产生数据依赖输出形状的算子的 XLA 内核,是否值得加入
TORCH_WARN
警告?
是的,torch_warn
对于告诉用户“嘿,你的程序不会飞速运行”很有用。但对于这些数据依赖的算子,除非用户改变模型中的逻辑,否则没有简单的重写方法。(另一个例子是 torch.unique()
)
像
nonzero()
这样的算子如何影响我们 devirtualizesizes()
的能力?如果我们想 devirtualizesizes()
,我们需要能够为每个算子急切地计算大小——那是否意味着我们每次遇到像nonzero()
这样的算子时,都会被迫对图进行求值?与现在的情况相比,听起来用户调用nonzero()
时我们实际上并没有强制求值?
是的,很好的问题!所以以目前的形式来看,它不是一个硬性阻碍,因为 XLA 张量上的 size()
不包含真实大小信息源。如示例所示,真实大小信息源存在于 IRValue
中,并且只能通过 _get_xla_tensor_dimension_size
获取。因此,如果我们决定 devirtualize 大小,只会加剧这种差异。
作为补充,如果我们让 size()
返回张量而不是值,就像上面提出的权宜之计中提到的那样。在这种情况下,size()
将无法 devirtualize,因为它变成了一个算子(接受张量输入并产生张量输出,对于不同的后端有不同的实现)。
如果我在循环中调用
torch.add(input, 1)
,其中 input 的大小从 1 变化到 1000,通常我们必须编译 1000 个不同的图——但有了动态形状,听起来 XLA 内部将能够生成一个单一的图,其中写着“如果输入大小 <=1000 则使用此图”。我的问题是:“动态形状”仅仅是图的一个属性吗?还是图和输入都具有该属性。也就是说,如果我的代码改为在循环中调用x = torch.add(input, 1); x.sizes()
,此时 x 是否具有动态形状,这意味着我们需要运行图才能获取大小?或者即使存在具有动态形状的图,我们是否也能使其成为一个急切计算的属性?
是的,在这种情况下你会编译 1000 个不同的图。动态形状意味着其输入具有动态维度。所以当你查询 x.sizes()
(目前需要使用 get_dimention_size
来获取正确的大小)时,它会触发*执行*(因为大小没有变化,所以不会触发重新编译)。如果没有访问大小的那一行,当输入具有动态维度时,它不会触发任何重新编译/执行。
将控制流集成到图中的另一种方法是,想办法确保 XLA 图不包含控制流?例如,如果我们有一个中间带有单个条件分支的模型,然后让 XLA 生成 3 个图:1 个用于条件分支之前的所有内容,1 个用于 if 分支,1 个用于 else 分支。这意味着你不会因为每种路径组合而导致新图的数量呈指数级增长,但 (a) 图会更小,提供的优化机会更少,并且 (b) 让 XLA 识别出条件路径在哪里被执行可能相当困难。
很好的观点!所以如果我们能将它们分解成更小的图,这确实是可行的。但在实践中,这种模式很麻烦
y = <some computation>
x = y + 2
if x[0] == 2 :
z = y +1
else:
z = y - 1
注意,当你遇到控制流时,会使用一个子图对 x 进行求值,但分支计算中也可能包含先前的变量(比如 y
只比 x 小一个节点,但在你对 x
求值时它并未具体化)。因此,对于这个例子,你实际上是在对 1 个小图和 2 个大图进行求值。并且随着涉及的控制流增多,y 可能在多个分支中更新,这仍然会产生不同组合的大图。