我们激动地宣布,在 PyTorch 的原生低精度库 TorchAO 中,为 Arm CPU 增加了具有低比特位权重(1-8 比特)的嵌入算子,以及具有 8 比特动态量化激活和低比特位权重(1-8 比特)的线性算子。这些算子在所有 PyTorch 平台(包括 eager、torch.compile、AOTI 和 ExecuTorch)上都能无缝工作,并且可以在 torchchat 中使用。
在开发这些线性算子时,我们的重点是在 PyTorch 和 ExecuTorch 之间共享代码,并在高层算子和底层内核之间建立明确的界限。这种设计允许第三方供应商轻松替换成他们自己的内核。我们还着手创建了一个用于实验新 CPU 量化想法并在整个 PyTorch 生态系统中进行测试的场所和基础设施。
通用低比特位内核
目前硬件尚不支持低比特位算术。在我们称之为通用内核的方法中,我们以模块化的方式,明确地将解包低比特位值为 int8 值的逻辑与 int8 GEMV 内核逻辑分离开来。我们从一个 8 比特内核开始,例如,这个使用 Arm neondot 指令的1×8 8 比特 GEMV 内核。在 8 比特内核内部,我们调用一个内联的解包例程,将低比特位值转换为 int8 值。这个解包例程被强制内联,并基于某个低比特位值进行模板化。我们的实验表明,使用一个独立的强制内联解包例程与直接内联嵌入解包代码之间没有性能差异。
这种模块化设计的优势在于提高了开发速度和代码可维护性。在编写了一个 8 比特内核之后,我们通过编写简单的比特封装例程,迅速实现了对所有低比特位的覆盖。事实上,从事比特封装例程开发的开发人员无需是 GEMV/GEMM 内核编写专家。我们还在嵌入内核中复用了线性内核中的相同比特封装例程。未来,我们可以为通用的 GEMM 内核或基于 fma 或 i8mm 指令的内核复用这些相同的比特封装例程。
PyTorch 和 ExecuTorch 之间的共享代码
为了实现 PyTorch 和 ExecuTorch 之间的代码共享,我们使用原始指针而非 PyTorch 张量来编写内核。此外,我们在一个头文件中实现了线性算子,该头文件被包含在独立的 PyTorch 和 ExecuTorch 算子注册代码中。通过只使用 ATen 和 ExecuTorch 张量共有的特性,我们确保了两个框架之间的兼容性。对于多线程计算,我们引入了 torchao::parallel_1d,它会根据编译时标志编译为 at::parallel_for 或 ExecuTorch 的线程池。
可插拔的内核
我们为高层多线程线性算子所做的设计,与底层单线程内核无关,这使得第三方供应商可以换上他们自己的实现。算子和内核之间的接口由一个微内核配置(ukernel config)定义,该配置指定了用于准备激活数据、准备权重数据和运行内核的内核函数指针。负责分块和调度的算子完全通过此配置与内核交互。
性能
在下表中,我们展示了在 M1 Macbook Pro(32GB RAM)上使用 6 个 CPU 线程生成 Llama3.1 8B 令牌的性能。
比特宽度 x | torch.compile (解码令牌数/秒) | ExecuTorch (解码令牌数/秒) | ExecuTorch PTE 大小 (GiB) |
1 | 24.18 | 17.86 | 1.46 |
2 | 27.02 | 19.65 | 2.46 |
3 | 21.01 | 22.25 | 3.46 |
4 | 19.51 | 19.47 | 4.47 |
5 | 14.78 | 16.34 | 5.47 |
6 | 12.80 | 13.61 | 6.47 |
7 | 8.16 | 11.73 | 7.48 |
结果是在一台 M1 Macbook Pro(拥有 8 个性能核心和 2 个能效核心)上运行,配备 32GB RAM,并使用 torchchat 并行 6 个线程。在每次测试中,都生成了最大序列长度为 128 的令牌。对于每个比特宽度 x,嵌入层被分组量化到 x 比特,组大小为 32。在线性层中,激活按令牌动态量化为 8 比特,权重被分组量化到 x 比特,组大小为 256。我们在这里关注的是性能,不报告准确率或困惑度数字。根据模型的不同,较低的比特宽度可能需要量化感知训练、使用混合比特宽度量化模型,或调整组大小以获得可接受的准确率。

试用并贡献!
如果您想亲身体验新的低比特位内核,请通过设置 torchchat 并使用这些内核在本地量化并运行一个 LLM 来试用它们。
如果您想参与贡献,可以考虑为以下领域之一添加支持:
- 为 Arm CPU 添加通用的低比特位 GEMM 内核,复用通用 GEMV 内核中的相同比特封装例程。
- 根据 ISA、封装格式和激活形状改进微内核配置的运行时选择。
- 为其他 CPU ISA(如 x86)添加低比特位内核。
- 将第三方库(如 KleidiAI)与算子框架集成。