跳转到主要内容
公告

PyTorch 原生架构优化:torchao

作者: 2024年9月26日2025年4月30日暂无评论

我们很高兴正式推出 torchao,这是一个 PyTorch 原生库,它通过利用低位数据类型、量化和稀疏性来使模型更快、更小。torchao 是一个易于访问的技术工具包,(大部分)用易于阅读的 PyTorch 代码编写,涵盖推理和训练。这篇博客将帮助您选择适合您工作负载的技术。

我们对 Llama 3 和扩散模型等流行的生成式人工智能模型进行了技术基准测试,并发现准确性下降极小。除非另有说明,基线是在 A100 80GB GPU 上运行的 bf16。

Llama 3 的核心指标如下:

  • 使用 int4 仅权重(weight only)量化和 HQQ 的 autoquant,Llama 3 8B 推理速度提升 97%
  • Llama 3.1 8B 在 128K 上下文长度推理时,通过量化 KV 缓存,峰值显存(VRAM)减少 73%
  • Llama 3 70B 在 H100 上使用 float8 训练,预训练速度提升 50%
  • Llama 3 8B 使用 4 位量化优化器,峰值显存(VRAM)减少 30%。

扩散模型推理的核心指标如下:

  • 在 H100 上的 flux1.dev 上,使用 float8 动态量化推理和 float8 行式(row-wise)缩放,速度提升 53%
  • CogVideoX 使用 int8 动态量化,模型显存(VRAM)减少 50%

下面我们将介绍 torchao 中可用于推理和训练模型的一些技术。

推理

我们的推理量化算法 适用于包含 nn.Linear 层的任意 PyTorch 模型。可以使用我们的顶级 quantize_ API 选择针对各种数据类型和稀疏布局的仅权重(weight only)和动态激活量化。

from torchao.quantization import (  
    quantize_,  
    int4_weight_only,  
)  
quantize_(model, int4_weight_only())

有时,由于开销,量化层可能会使其变慢,因此如果您希望我们为您选择模型中每个层的量化方式,您可以运行:

model = torchao.autoquant(torch.compile(model, mode='max-autotune'))

quantize_ API 有几种不同的选项,具体取决于您的模型是计算密集型(compute bound)还是内存密集型(memory bound)。

from torchao.quantization import (  
    # Memory bound models  
    int4_weight_only,  
    int8_weight_only,

    # Compute bound models  
    int8_dynamic_activation_int8_semi_sparse_weight,  
    int8_dynamic_activation_int8_weight,  
      
    # Device capability 8.9+  
    float8_weight_only,  
    float8_dynamic_activation_float8_weight,  
)

我们还与 HuggingFace diffusers 团队合作,在 diffusers-torchao 中对扩散模型进行了广泛的基准测试,其中我们展示了 Flux.1-Dev 上 53.88% 的速度提升和 CogVideoX-5b 上 27.33% 的速度提升。

bar chart

我们的 API 是可组合的,例如,我们已经组合了稀疏性和量化,为 ViT-H 推理带来了 5% 的速度提升

但也可以做一些事情,例如将权重量化为 int4,将 KV 缓存量化为 int8,以支持 Llama 3.1 8B 在不到 18.9GB 的 VRAM 中以完整的 128K 上下文长度运行

QAT

训练后量化,尤其是低于 4 位的量化,可能会导致严重的精度下降。通过使用 量化感知训练 (QAT),我们成功地在 hellaswag 上恢复了高达 96% 的精度下降。我们已将其作为 torchtune 中的端到端配方集成,并提供了最少的 教程

训练

低精度计算和通信

torchao 提供了易于使用的端到端工作流程,用于降低训练计算和分布式通信的精度,从 `torch.nn.Linear` 层的 float8 开始。这是一个将您的训练运行的计算 gemms 转换为 float8 的单行代码:

from torchao.float8 import convert_to_float8_training  
convert_to_float8_training(model)

有关如何使用 float8 将 LLaMa 3 70B 预训练加速高达 1.5 倍 的端到端示例,请参阅我们的 README,以及 torchtitan 的 博客float8 配方

LLaMa 3 70B float8 预训练与 bfloat16 的性能和准确性

(来源:https://dev-discuss.pytorch.org/t/enabling-float8-all-gather-in-fsdp2/2359

我们正在将训练工作流程扩展到更多数据类型和布局

  1. torchtune 中的 NF4 QLoRA
  2. 原型 int8 训练支持
  3. 加速稀疏 2:4 训练

低位优化器

受 Bits and Bytes 的启发,我们还添加了 8 位和 4 位优化器的原型支持,可作为 AdamW 的直接替代。

from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit  
optim = AdamW8bit(model.parameters())

集成

我们一直积极致力于确保 torchao 在一些最重要的开源项目中良好运行。

  1. Huggingface transformers 作为 推理后端
  2. 在 diffusers-torchao 中 作为加速扩散模型的参考实现
  3. 在 HQQ 中用于 快速 4 位推理
  4. torchtune 中 用于 PyTorch 原生 QLoRA 和 QAT 配方
  5. torchchat 中 用于训练后量化
  6. 在 SGLang 中用于 int4 和 int8 训练后量化

结论

如果您有兴趣让您的模型在训练或推理时更快、更小,我们希望您会发现 torchao 有用且易于集成。

pip install torchao

我们对接下来有很多兴奋的事情,包括低于 4 位、用于高吞吐量推理的高性能内核、扩展到更多层、缩放类型或粒度、MX 硬件支持以及支持更多硬件后端。如果以上任何一项听起来令人兴奋,您可以在这里关注我们的进展:https://github.com/pytorch/ao

如果您有兴趣参与 torchao 的工作,我们创建了 贡献者指南,如果您有任何问题,我们会在 discord.gg/gpumode#torchao 频道上与您交流。

致谢

我们很幸运能站在巨人的肩膀上,并与一些最优秀的开源人士合作。谢谢!

  1. Bits and Bytes 在低位优化器和 QLoRA 方面的开创性工作
  2. Answer.ai 在使 FSDP 和 QLoRA 兼容方面的工程工作
  3. Mobius Labs 在量化算法和低位内核方面的精彩交流
  4. HuggingFace transformers 在实战测试和集成我们工作方面的帮助
  5. HuggingFace diffusers 在广泛基准测试和最佳实践方面的合作
  6. torch.compile 使我们能够用纯 PyTorch 编写算法
  7. GPU MODE 的大部分早期贡献者