如果您希望在设备端实现强大的 AI 功能,且又不希望耗尽内存预算或让手机变得滚烫,那么您需要的工具应比“训练后量化”(Post-training Quantization)更为精细,因为在训练后量化中,模型无法恢复由此导致的精度损失。
我们制作了一系列实用的 Jupyter Notebook 教程,旨在向开发者和机器学习研究人员介绍各种先进的软硬件协同设计主题。我们展示了混合精度量化、量化感知技术以及专家混合模型(Mixture of Expert models)如何构建既小巧又强大,并可通过诸如 ExecuTorch 等边缘推理运行时在 Arm 设备上高效运行的模型。
量化与软硬件协同设计
其目标是调整精度,以在最大化模型压缩的同时最大限度地减少精度损失。传统的量化(从 FP32 到 INT8)是一种强大但过于粗糙的工具,因为神经网络中的每一层对精度损失的敏感度并不相同。
所需的精度水平取决于数据的分布。下图展示了在 4-bit 量化下,Transformer 中不同的前馈和注意力组件会产生差异巨大的量化误差。这说明要将精度损失降至最低,需要自适应地分配比特,以便为每个部分使用适当的精度。我们展示了如何通过 PyTorch 的 QConfig API 轻松实现混合精度级别。

此外,为了在混合精度模型中保持强劲的性能,Arm 的 KleidiAI 提供了高度优化的计算内核(目前支持低至 4-bit),确保低比特张量类型能高效映射到 Arm 硬件指令上。对于希望部署到智能手机和笔记本电脑等 Arm 设备上的开发者来说,在使用 PyTorch 和边缘推理运行时 ExecuTorch,并通过 KleidiAI 和 Arm VGF 后端时,这一切都是透明完成的。
在我们的 Notebook 中,我们还探讨了软硬件协同设计,其目标不仅是训练模型以最小化损失,还通过让模型学习如何量化每一层来最小化模型大小。这能引导开发者在精度和模型紧凑性之间取得平衡,并训练出能够可靠地适应特定内存占用空间要求的模型。

上述公式展示了我们如何构建一个损失函数,其中包含了模型精度的软件成本项以及静态模型大小的硬件成本项。我们在 PyTorch 中直接实现了该示例损失函数,并在我们的 软硬件协同设计教程中进行了深入探讨,在该教程中,我们将该函数用于在 Tiny Shakespeare 数据集上训练一个基于 Transformer 的网络。
极端量化
基于这种协同设计的视角,我们的教程进一步探讨了使激进的低比特部署变得可行的训练算法。量化感知训练(QAT)在训练过程中让模型接触模拟的低精度算术,从而使其权重和激活能够适应舍入噪声。
QAT 不再将量化视为部署的最后一步,而是让优化器在整个训练过程中“看到”量化器,这在压缩至 8-bit 以下或面临严苛内存预算时尤为重要。极端量化更进一步,探讨在保持有用精度的前提下,我们能多接近二进制表示。关于超低比特大语言模型的最新研究,例如 《1-bit LLMs 时代:所有大语言模型均为 1.58 bits》(Ma et al., 2024),表明精心的算法-硬件协同设计可以在保持现代架构功能的前提下实现极高程度的压缩。
在我们的 PyTorch Notebook 中,您可以端到端地实验这些想法:从基础模型开始,启用 QAT,探索更激进的量化方案,并观察精度、模型大小和性能在实践中的权衡。下图(在我们的教程中有详细讨论)显示,通过 QAT,1-2 bit 的二值/三值模型几乎可以媲美 8-bit 的基准模型,同时表现远超简单的低比特训练后量化。
使用专家混合架构实现高效模型推理
除了量化,我们的课程还涵盖了 实现专家混合(MoE)模型的入门内容。与稠密模型(每个参数都用于处理每个输入)不同,MoE 架构仅在给定 token 时激活网络中的特定部分,即“专家”。

为了熟悉这些主题,我们发布了一系列详尽的 Jupyter Notebook,作为实用的分步指南。我们的实验包含约 10 小时的实用内容,旨在让您能够在自己的硬件上运行和修改代码,从而将理论即时转化为实践。 点击此处查看。
该系列教程是由 Arm 的 Kieran Hejmadi 与多位领先学者共同合作完成的,其中包括南安普顿大学 AI 研究员 Oliver Grainge,以及 IEEE Fellow、德克萨斯大学奥斯汀分校电气与计算机工程系教授 Constantine Caramanis。
此外,感谢来自 IIT 德里和 IIT 海德拉巴的学术审稿人,他们确保了材料既处于技术前沿,又经过了严谨的验证。
如果您正在寻找更多入门内容,请尝试我们的课程: “在 Arm 处理器上优化生成式 AI:从边缘到云端”。
参考文献
Ma, S., et al. (2024). The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits. arXiv