博客

通过 PyTorch/XLA 在 Cloud TPU 上了解 LazyTensor 系统性能

作者: 2022年3月2日2024年11月15日暂无评论

简介

易用性、表达能力和可调试性是 PyTorch 的核心原则。实现易用性的关键驱动因素之一是 PyTorch 默认采用“即时(eager)”执行模式,即逐个算子执行,从而保留了程序的命令式特性。然而,即时执行无法提供基于编译器的优化(例如,当计算可以表示为图时所进行的优化)。

LazyTensor [1] 最早随 PyTorch/XLA 引入,旨在结合这些看似截然不同的方法。虽然 PyTorch 的即时执行方式被广泛使用、直观且广为人知,但延迟执行(lazy execution)目前尚不普及。

在本文中,我们将探讨 LazyTensor 系统的一些基本概念,目标是将这些概念应用于理解和调试 PyTorch 中基于 LazyTensor 的实现性能。尽管我们将以 Cloud TPU 上的 PyTorch/XLA 为例来探讨这些概念,但我们希望这些思路对于理解其他基于 LazyTensor 构建的系统同样有所帮助。

LazyTensor

在 PyTorch 张量上执行的任何操作,默认情况下都会作为内核(kernel)或内核组合分发到底层硬件。这些内核在底层硬件上异步执行。除非获取张量的值,否则程序执行不会被阻塞。这种方法在诸如 GPU 等大规模并行编程硬件上表现出极佳的扩展性。

LazyTensor 系统的起点是一种自定义张量类型。在 PyTorch/XLA 中,这种类型称为 XLA 张量。与 PyTorch 的原生张量类型不同,对 XLA 张量执行的操作会被记录到一个 IR(中间表示)图中。让我们来看一个对两个张量积进行求和的例子。

import torch
import torch_xla
import torch_xla.core.xla_model as xm

dev = xm.xla_device()

x1 = torch.rand((3, 3)).to(dev)
x2 = torch.rand((3, 8)).to(dev)

y1 = torch.einsum('bs,st->bt', x1, x2)
print(torch_xla._XLAC._get_xla_tensors_text([y1]))

您可以执行此 Colab 笔记本来查看 y1 的生成图。请注意,此时尚未执行任何计算。

y1 = y1 + x2
print(torch_xla._XLAC._get_xla_tensors_text([y1]))

操作将持续记录,直到 PyTorch/XLA 遇到屏障(barrier)。此屏障可以是 mark_step() API 调用,或者任何其他强制执行已记录图的事件。

xm.mark_step()
print(torch_xla._XLAC._get_xla_tensors_text([y1]))

一旦调用 mark_step(),图就会被编译并在 TPU 上执行,即张量已具体化(materialized)。因此,该图现在被简化为单行的 y1 张量,其中保存了计算结果。

一次编译,多次执行

XLA 编译过程提供了多种优化(例如算子融合,即通过对多个算子使用暂存内存来降低 HBM 压力,参考),并利用底层 XLA 基础设施来最优地使用硬件。然而,有一个注意事项:编译过程开销很大,会增加训练步骤的时间。因此,这种方法只有在能够“一次编译,多次执行”(编译缓存有所帮助,确保同一个图不会被重复编译)时才能很好地扩展。

在接下来的例子中,我们创建一个小型计算图并测量执行时间。

y1 = torch.rand((3, 8)).to(dev)
def dummy_step() :
  y1 = torch.einsum('bs,st->bt', y1, x)
  xm.mark_step()
  return y1
%timeit dummy_step
The slowest run took 29.74 times longer than the fastest. This could mean that an intermediate result is being cached.
10000000 loops, best of 5: 34.2 ns per loop

您会注意到最慢的一步比最快的一步要长得多。这是因为图编译开销只在给定的图形状、输入形状和输出形状下发生一次。随后的步骤更快,因为不需要进行图编译。

这也意味着当“一次编译,多次执行”的假设被打破时,会出现性能瓶颈。理解这一假设何时被打破是理解和优化 LazyTensor 系统性能的关键。让我们检查一下触发编译的原因。

图编译、执行与 LazyTensor 屏障

我们已经看到,当遇到 LazyTensor 屏障时,计算图会被编译和执行。有三种情况下会自动或手动引入 LazyTensor 屏障。第一种是如前例所示显式调用 mark_step() API。当您使用 MpDeviceLoader 包装数据加载器时,mark_step() 也会在每一步隐式调用(强烈建议这样做,以便重叠计算和向 TPU 设备上传数据)。xla_model 的 Optimizer step 方法也允许隐式调用 mark_step(当设置 barrier=True 时)。

引入屏障的第二种情况是当 PyTorch/XLA 发现一个没有映射(降级/lowering)到等效 XLA HLO 算子的操作时。PyTorch 拥有 2000 多个操作。虽然其中大多数是复合操作(即可以用其他基本操作来表示),但有些操作在 XLA 中没有相应的降级实现。

当使用没有 XLA 降级实现的操作时会发生什么?PyTorch XLA 会停止记录操作,并切断导致未降级算子输入的所有图。切断后的图会被编译并分发执行。执行结果(具体化后的张量)从设备返回主机(Host),然后在主机(CPU)上执行该未降级算子,接着下游的 LazyTensor 操作开始创建新的图,直到再次遇到屏障。

导致 LazyTensor 屏障的第三种也是最后一种情况是,存在需要张量值的控制结构/语句或方法。该语句至少会导致导致该张量的计算图执行(如果图已被识别),或者导致两者的编译和执行。

此类方法的其他例子包括 .item()、isEqual()。通常,任何将 Tensor 映射为 Scalar 的操作都会导致这种行为。

动态图

如前所述,如果相同形状的图被多次执行,图编译成本会被摊薄。这是因为编译后的图是使用从图形状、输入形状和输出形状导出的哈希值进行缓存的。如果这些形状发生变化,就会触发编译,过于频繁的编译会导致训练时间变慢。

让我们考虑以下例子:

def dummy_step(x, y, loss, acc=False):
  z = torch.einsum('bs,st->bt', y, x)
  step_loss = z.sum().view(1,)
  if acc:
    loss = torch.cat((loss, step_loss))
  else:
    loss = step_loss
  xm.mark_step()
  return loss


import time
def measure_time(acc=False):
  exec_times = []
  iter_count = 100
  x = torch.rand((512, 8)).to(dev)
  y = torch.rand((512, 512)).to(dev)
  loss = torch.zeros(1).to(dev)
  for i in range(iter_count):
    tic = time.time()
    loss = dummy_step(x, y, loss, acc=acc)
    toc = time.time()
    exec_times.append(toc - tic)
  return exec_times

dyn = measure_time(acc=True) # acc= True Results in dynamic graph
st = measure_time(acc=False) # Static graph, computation shape, inputs and output shapes don't change

import matplotlib.pyplot as plt
plt.plot(st, label = 'static graph')
plt.plot(dyn, label = 'dynamic graph')
plt.legend()
plt.title('Execution time in seconds')

请注意,静态和动态案例执行相同的计算,但动态图每次都会编译,从而导致总运行时间增加。在实践中,伴随重编译的训练步骤有时可能会慢一个数量级。在下一节中,我们将讨论一些用于调试训练性能下降的 PyTorch/XLA 工具。

使用 PyTorch/XLA 分析训练性能

PyTorch/XLA 分析主要由两个部分组成。首先是客户端分析。只需设置环境变量 PT_XLA_DEBUG 为 1 即可启用此功能。客户端分析会指出源代码中未降级的算子或设备到主机的传输。客户端分析还会报告训练过程中是否存在过于频繁的编译。您可以结合分析器,在此笔记本中探索 PyTorch/XLA 提供的一些指标和计数器。

PyTorch/XLA 分析器提供的第二个组件是内联跟踪标注(inline trace annotation)。例如:

import torch_xla.debug.profiler as xp

def train_imagenet():
  print('==> Preparing data..')
  img_dim = get_model_property('img_dim')
  ....
  server = xp.start_server(3294)
  def train_loop_fn(loader, epoch):
    ....
    model.train()
    for step, (data, target) in enumerate(loader):
      with xp.StepTrace('Train_Step', step_num=step):
        ....
        if FLAGS.amp:
        ....
        else:
          with xp.Trace('build_graph'):
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
          xm.optimizer_step(optimizer)

请注意 start_server API 调用。您在此处使用的端口号与您将在 tensorboard 分析器中使用的端口号相同,以便查看类似的算子跟踪:

算子跟踪与客户端调试功能是一套强大的工具,可用于调试和优化您的 PyTorch/XLA 训练性能。有关分析器使用的更详细说明,读者可参阅 PyTorch/XLA 性能调试系列博客的第一部分第二部分第三部分

总结

在本文中,我们回顾了 LazyTensor 系统的基础知识。我们以这些基础知识为基础,结合 PyTorch/XLA 探讨了导致训练性能下降的潜在原因。我们讨论了为什么“一次编译,多次执行”有助于在 LazyTensor 系统上获得最佳性能,以及当该假设被打破时训练速度为何会变慢。

我们希望 PyTorch 用户会发现这些见解对他们使用 LazyTensor 系统进行创新工作有所帮助。

致谢

非常感谢我的杰出同事 Jack Cao、Milad Mohammedi、Karl Weinmeister、Rajesh Thallam、Jordan Tottan (Google) 和 Geeta Chauhan (Meta) 所做的细致审查和反馈。感谢来自 Google、Meta 和开源社区的广大 PyTorch/XLA 开发团队,让 PyTorch 在 TPU 上成为可能。最后,感谢 LazyTensor 论文的作者,不仅开发了 LazyTensor,还撰写了如此浅显易懂的论文。

参考文献

[1] LazyTensor: Combining Eager Execution with Domain-Specific Compilers (LazyTensor:将即时执行与领域特定编译器相结合)