作者:Evgeni Burovski、Ralf Gommers 和 Mario Lezcano

Quansight 工程师已在 PyTorch 2.1 中实现了通过 torch.compile 跟踪 NumPy 代码的支持。此功能利用 PyTorch 的编译器生成高效的融合向量化代码,而无需修改原始 NumPy 代码。更重要的是,它还允许仅通过在 torch.device("cuda") 下运行 torch.compile,在 CUDA 上执行 NumPy 代码!

在这篇文章中,我们将介绍如何使用此功能,并提供一些技巧和窍门,以充分利用它。

将 NumPy 代码编译为并行 C++

我们以 K-Means 算法中的一步作为运行示例。这段代码借用自这本 NumPy 书籍

import numpy as np

def kmeans(X, means):
    return np.argmin(np.linalg.norm(X - means[:, None], axis=2), axis=0)

我们创建了一个包含 2000 万个随机 2D 点的合成数据集。我们可以看到,在适当选择均值的情况下,该函数为所有点返回正确的聚类

npts = 10_000_000
X = np.repeat([[5, 5], [10, 10]], [npts, npts], axis=0)
X = X + np.random.randn(*X.shape)  # 2 distinct "blobs"
means = np.array([[5, 5], [10, 10]])
np_pred = kmeans(X, means)

在此函数上进行基准测试,我们在 AMD 3970X CPU 上得到 1.26 秒 的基线。

现在编译此函数就像用 torch.compile 包装它并使用示例输入执行它一样简单

import torch

compiled_fn = torch.compile(kmeans)
compiled_pred = compiled_fn(X, means)
assert np.allclose(np_pred, compiled_pred)

编译后的函数在 1 个核心上运行时产生 9 倍的加速。更棒的是,与 NumPy 相比,我们生成的代码确实利用了处理器中的所有核心。因此,当我们在 32 个核心上运行时,我们获得了 57 倍的加速。请注意,除非明确限制,否则 PyTorch 始终使用所有可用核心,因此这是使用 torch.compile 时的默认行为。

我们可以通过使用环境变量 TORCH_LOGS=output_code 运行脚本来检查生成的 C++ 代码。这样做时,我们可以看到 torch.compile 能够将广播和两个归约编译成一个 for 循环,并使用 OpenMP 对其进行并行化

extern "C" void kernel(const double* in_ptr0, const long* in_ptr1, long* out_ptr0) {
    #pragma omp parallel num_threads(32)
    #pragma omp for
    for(long i0=0L; i0<20000000L; i0+=1L) {
        auto tmp0 = in_ptr0[2L*i0];
        auto tmp1 = in_ptr1[0L];
        auto tmp5 = in_ptr0[1L + (2L*i0)];
        auto tmp6 = in_ptr1[1L];
        // Rest of the kernel omitted for brevity

将 NumPy 代码编译为 CUDA

编译我们的代码使其在 CUDA 上运行就像将默认设备设置为 CUDA 一样简单

with torch.device("cuda"):
    cuda_pred = compiled_fn(X, means)
assert np.allclose(np_pred, cuda_pred)

通过 TORCH_LOGS=output_code 检查生成的代码,我们看到 torch.compile 没有直接生成 CUDA 代码,而是生成了相当可读的 triton 代码

def triton_(in_ptr0, in_ptr1, out_ptr0, XBLOCK : tl.constexpr):
    xnumel = 20000000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (2*x0), xmask)
    tmp1 = tl.load(in_ptr1 + (0))
    // Rest of the kernel omitted for brevity

在 RTX 2060 上运行这个小代码片段,与原始 NumPy 代码相比,获得了 8 倍的加速。这还不错,但考虑到我们在 CPU 上看到的加速,这并不是特别令人印象深刻。让我们看看如何通过一些小的更改来最大限度地发挥 GPU 的性能。

float64 vs float32。许多 GPU,尤其是消费级 GPU,在 float64 上运行操作时相当迟缓。因此,将数据生成更改为 float32,原始 NumPy 代码只会稍微快一点,大约 9%,但我们的 CUDA 代码会快 40%,与纯 NumPy 代码相比,加速 11 倍

torch.compile 默认情况下尊重 NumPy 语义,因此,它使用 np.float64 作为其所有创建操作的默认 dtype。正如讨论的那样,这可能会阻碍性能,因此可以通过设置来更改此默认值

from torch._dynamo import config
config.numpy_default_float = "float32"

CPU <> CUDA 副本。11 倍的加速很好,但与 CPU 数字相差甚远。这是由 torch.compile 在幕后进行的一个小转换引起的。上面的代码接受 NumPy 数组并返回 NumPy 数组。所有这些数组都在 CPU 上,但计算在 GPU 上执行。这意味着每次调用该函数时,torch.compile 都必须将所有这些数组从 CPU 复制到 GPU,然后将结果复制回 CPU 以保留原始语义。NumPy 中没有针对此问题的原生解决方案,因为 NumPy 没有 device 的概念。话虽如此,我们可以通过为此函数创建一个包装器来解决这个问题,以便它接受 PyTorch 张量并返回 PyTorch 张量。

@torch.compile
def tensor_fn(X, means):
    X, means = X.numpy(), means.numpy()
    ret = kmeans(X, means)
    return torch.from_numpy(ret)

def cuda_fn(X, means):
    with torch.device("cuda"):
        return tensor_fn(X, means)

此函数现在接受 CUDA 内存中的张量并返回 CUDA 内存中的张量,但函数本身是用 NumPy 编写的!torch.compile 使用 numpy()from_numpy() 调用作为提示,并优化掉它们,并且在内部它只是使用 PyTorch 张量而根本不移动内存。当我们将张量保留在 CUDA 中并在 float32 中执行计算时,我们看到比最初的 float32 数组上的 NumPy 实现加速 200 倍

混合 NumPy 和 PyTorch。在本例中,我们必须编写一个小适配器来将张量转换为 ndarray,然后再转换回张量。在混合 PyTorch 和 NumPy 的程序中,将张量转换为 ndarray 通常实现为 x.detach().cpu().numpy(),或者简单地 x.numpy(force=True)。由于在 torch.compile 下运行时,我们可以在 CUDA 中运行 NumPy 代码,因此我们可以将此转换模式实现为调用 x.numpy(),就像我们上面所做的那样。这样做并在 device("cuda") 下运行结果代码将从原始 NumPy 调用生成高效的 CUDA 代码,而根本不会将数据从 CUDA 复制到 CPU。请注意,结果代码在没有 torch.compile 的情况下无法运行。为了使其在 eager 模式下运行,需要回滚到 x.numpy(force=True)

进一步加速技巧

一般建议。我们展示的 CUDA 代码已经相当高效,但确实运行示例相当短。在处理更大的程序时,我们可能需要调整其中的一部分以使其更高效。一个好的起点是 torch.compile 的多个教程和常见问题解答。这展示了检查跟踪过程的多种方法,以及如何识别可能导致速度减慢的问题代码。

编译 NumPy 代码时的建议。NumPy 即使与 PyTorch 非常相似,但通常使用方式却大相径庭。在 NumPy 中执行计算,然后根据数组中的值执行 if/else,或者就地执行操作(可能通过布尔掩码)是很常见的。这些构造虽然受 torch.compile 支持,但会影响其性能。像以无分支方式编写代码以避免图形中断,或避免就地操作这样的更改可能会大有帮助。

要编写快速的 NumPy 代码,最好避免循环,但有时它们是不可避免的。当跟踪循环时,torch.compile 将尝试完全展开它。这有时是可取的,但有时甚至可能是不可能的,例如当我们有动态停止条件(如 while 循环)时。在这些情况下,最好只编译循环体,可能一次编译几个迭代(循环展开)。

调试 NumPy 代码。当涉及到编译器时,调试非常棘手。要弄清楚您遇到的错误是 torch.compile 错误还是程序错误,您可以通过将 NumPy 导入替换为 import torch._numpy as np 来在没有 torch.compile 的情况下执行 NumPy 程序。这应该仅用于 调试目的,绝不能替代 PyTorch API,因为它 速度慢得多,并且作为私有 API,可能会在不另行通知的情况下更改。另请参阅 此常见问题解答 以了解其他技巧。

NumPy 和 torch.compile NumPy 之间的差异

NumPy 标量。在几乎所有 PyTorch 将返回 0-D 张量的情况下(例如,来自 np.sum),NumPy 都会返回 NumPy 标量。在 torch.compile 下,NumPy 标量被视为 0-D 数组。这在大多数情况下都很好。它们的行为不同的唯一情况是 NumPy 标量隐式用作 Python 标量时。例如,

>>> np.asarray(2) * [1, 2, 3]  # 0-D array is an array-like
array([2, 4, 6])
>>> u = np.int32(2)
>>> u * [1, 2, 3]              # scalar decays into a Python int
[1, 2, 3, 1, 2, 3]
>>> torch.compile(lambda: u * [1, 2, 3])()
array([2, 4, 6])               # acts as a 0-D array, not as a scalar ?!?!

如果我们编译前两行,我们看到 torch.compileu 视为 0-D 数组。要恢复 eager 语义,我们只需要显式进行强制转换

>>> torch.compile(lambda: int(u) * [1, 2, 3])()
[1, 2, 3, 1, 2, 3]

类型提升和版本控制。NumPy 的类型提升规则有时可能有点令人惊讶

>>> np.zeros(1, dtype=np.int8) + 127
array([127], dtype=int8)
>>> np.zeros(1, dtype=np.int8) + 128
array([128], dtype=int16)

NumPy 2.0 正在更改这些规则,以遵循更接近 PyTorch 的规则。相关的技术文档是 NEP 50torch.compile 率先实施了 NEP 50,而不是即将弃用的规则。

总的来说,torch.compile 中的 NumPy 遵循 NumPy 2.0 预发布版。

超越 NumPy:SciPy 和 scikit-learn

与使 torch.compile 理解 NumPy 代码的努力并行,其他 Quansight 工程师设计并提出了一种在 scikit-learn 和 SciPy 中支持 PyTorch 张量的方法。这受到了这些库的其他维护者的热烈欢迎,因为事实证明,使用 PyTorch 作为后端通常会产生可观的加速。这两个项目现在都合并了对跨多个 API 和子模块的 PyTorch 张量的初始支持。

这为走向未来奠定了基础,在未来,PyTorch 张量可以在 Python 数据生态系统中的其他库中使用。更重要的是,这将使在 GPU 上运行这些其他库,甚至编译混合这些库和 PyTorch 的代码成为可能,类似于我们在本文中讨论的内容。

如果您想了解有关此工作的更多信息、如何使用它或如何帮助推进它,请参阅 这篇其他博文

结论

PyTorch 自成立以来一直致力于成为与 Python 生态系统其余部分兼容的框架。启用编译 NumPy 程序,并建立必要的工具来为其他重要库执行相同的操作,是朝着这个方向迈出的又两步。Quansight 和 Meta 继续携手合作,提高 PyTorch 与生态系统其余部分之间的兼容性。

Quansight 感谢 Mengwei、Voz 和 Ed 在将我们的工作与 torch.compile 集成方面提供的宝贵帮助。我们还要感谢 Meta 资助该项目以及之前在提高 PyTorch 中 NumPy 兼容性的工作,以及导致在 scikit-learn 和 SciPy 中支持 PyTorch 的项目。这些是朝着巩固 PyTorch 作为开源 Python 数据生态系统中的首选框架迈出的巨大一步。