PyTorch 团队

我们很高兴正式发布 torchao,这是一个 PyTorch 原生库,它通过利用低位数据类型 (dtypes)、量化和稀疏性来使模型更快、更小。torchao 是一个易于使用的技术工具集,其代码(大部分)由易于阅读的 PyTorch 代码编写,涵盖推理和训练。本文将帮助您选择对您的工作负载至关重要的技术。

我们在 LLama 3 和 Diffusion 模型等流行的生成式 AI 模型上对我们的技术进行了基准测试,发现准确性下降幅度很小。除非另有说明,基准测试均在 A100 80GB GPU 上以 bf16 运行。

我们在 Llama 3 上的主要指标如下:

  • 使用 int4 仅权重量化和 hqq 的 autoquant 对 Llama 3 8B 推理实现 97% 的加速
  • 使用量化的 KV 缓存,在 128K 上下文长度下对 Llama 3.1 8B 推理实现 73% 的峰值 VRAM 减少
  • 在 H100 上使用 float8 训练对 Llama 3 70B 预训练实现 50% 的加速
  • 使用 4 位量化优化器,对 Llama 3 8B 实现 30% 的峰值 VRAM 减少。

我们在扩散模型推理上的主要指标如下:

  • 在 H100 上使用 float8 动态量化推理和 float8 行级缩放对 flux1.dev 实现 53% 的加速
  • 使用 int8 动态量化,对 CogVideoX 模型 VRAM 减少 50%

下面我们将介绍 torchao 中可用的一些可应用于您的模型进行推理和训练的技术。

推理

我们的推理量化算法适用于包含 nn.Linear 层的任意 PyTorch 模型。可以使用我们的顶层 quantize_ API 选择针对各种数据类型和稀疏布局的仅权重和动态激活量化。

from torchao.quantization import (  
    quantize_,  
    int4_weight_only,  
)  
quantize_(model, int4_weight_only())

有时量化一个层可能会由于开销而使其变慢,因此如果您希望我们为您选择如何量化模型中的每个层,您可以转而运行

model = torchao.autoquant(torch.compile(model, mode='max-autotune'))

quantize_ API 有几种不同的选项,具体取决于您的模型是计算密集型还是内存密集型。

from torchao.quantization import (  
    # Memory bound models  
    int4_weight_only,  
    int8_weight_only,

    # Compute bound models  
    int8_dynamic_activation_int8_semi_sparse_weight,  
    int8_dynamic_activation_int8_weight,  
      
    # Device capability 8.9+  
    float8_weight_only,  
    float8_dynamic_activation_float8_weight,  
)

我们还与 HuggingFace diffusers 团队合作在 diffusers-torchao 中对扩散模型进行了广泛的基准测试,我们在 Flux.1-Dev 上展示了 53.88% 的加速,在 CogVideoX-5b 上展示了 27.33% 的加速

我们的 API 是可组合的,例如我们将稀疏性和量化结合起来,为 ViT-H 推理带来了 5% 的加速

此外,我们还可以将权重量化为 int4,将 kv 缓存量化为 int8,以支持 Llama 3.1 8B 在完整的 128K 上下文长度下以低于 18.9GB 的 VRAM 运行

QAT

训练后量化,特别是在小于 4 位时,可能会导致严重的准确性下降。使用量化感知训练 (QAT),我们在 hellaswag 数据集上成功恢复了高达 96% 的准确性下降。我们已将其集成到 torchtune 中作为一个端到端示例,并提供了一个简要教程

训练

低精度计算和通信

torchao 提供了易于使用的端到端工作流程,用于降低训练计算和分布式通信的精度,首先是 `torch.nn.Linear` 层的 float8。下面是将您的训练运行计算 GEMM 转换为 float8 的一行代码:

from torchao.float8 import convert_to_float8_training  
convert_to_float8_training(model)

有关如何使用 float8 将 LLaMa 3 70B 预训练加速高达 1.5 倍的端到端示例,请参阅我们的README 以及 torchtitan 的博客float8 示例

LLaMa 3 70B 的 float8 预训练性能和准确性,与 bfloat16 对比

(来源:https://dev-discuss.pytorch.org/t/enabling-float8-all-gather-in-fsdp2/2359

我们正在将我们的训练工作流程扩展到更多数据类型和布局

  1. torchtune 中的 NF4 QLoRA
  2. 原型 int8 训练支持
  3. 加速的稀疏 2:4 训练

低位优化器

受 Bits and Bytes 的启发,我们还添加了对 8 位和 4 位优化器的原型支持,可直接替代 AdamW。

from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit  
optim = AdamW8bit(model.parameters())

集成

我们一直在积极努力确保 torchao 在一些最重要的开源项目中良好运行。

  1. 将 Huggingface transformers 作为推理后端
  2. diffusers-torchao 中作为加速扩散模型的参考实现
  3. 在 HQQ 中实现快速 4 位推理
  4. torchtune 中提供 PyTorch 原生 QLoRA 和 QAT 示例
  5. torchchat 中进行训练后量化
  6. 在 SGLang 中进行 int4 和 int8 训练后量化

结论

如果您有兴趣让您的模型在训练或推理时更快、更小,我们希望您会发现 torchao 非常有用且易于集成。

pip install torchao

接下来我们有很多令人兴奋的事情,包括降到 4 位以下、用于高吞吐量推理的高性能内核、扩展到更多层、扩展类型或粒度、MX 硬件支持以及支持更多硬件后端。如果您对上述任何内容感兴趣,可以在此处关注我们的进展:https://github.com/pytorch/ao

如果您有兴趣参与 torchao 的工作,我们创建了贡献者指南;如果您有任何问题,我们会在 #torchao 频道交流,地址是 discord.gg/gpumode

致谢

我们很幸运能够站在巨人的肩膀上,并与开源领域的一些最优秀的人士合作。谢谢你们!

  1. Bits and Bytes,感谢他们在低位优化器和 QLoRA 方面的开创性工作
  2. Answer.ai,感谢他们在使 FSDP 和 QLoRA 可组合方面的工程工作
  3. Mobius Labs,感谢他们在量化算法和低位内核方面的愉快的交流
  4. HuggingFace transformers,感谢他们在实战测试和集成我们的工作方面的帮助
  5. HuggingFace diffusers,感谢我们在广泛的基准测试和最佳实践方面的合作
  6. torch.compile,使我们能够用纯 PyTorch 编写算法
  7. GPU MODE,感谢我们大多数早期贡献者