• 文档 >
  • 什么是训练方案?
快捷方式

什么是训练方案?

本次深入探讨将引导您了解 torchtune 中训练方案(training-recipes)的设计。

本次深入探讨将涵盖的内容
  • 什么是训练方案?

  • 构成训练方案的核心组件是什么?

  • 如何构建新的训练方案?

训练方案(Recipes)是 torchtune 用户的主要入口点。它们可以被视为针对训练大型语言模型(LLM)并可选地进行评估的“有针对性的”端到端管道。每个训练方案实现了一种训练方法(例如:全量微调),并在一组有意义的功能(例如:FSDP + 激活检查点 + 梯度累积 + 混合精度训练)下应用于特定的模型家族(例如:Llama2)。

随着模型训练变得越来越复杂,在权衡所有可能的折衷(例如:内存 vs 模型质量)的同时,预测新的模型架构和训练方法变得更加困难。我们认为 a) 用户最适合根据自己的用例做出特定的权衡,并且 b) 没有一概而论的解决方案。因此,训练方案旨在易于理解、扩展和调试,*而不是*适用于所有可能设置的通用入口点。

根据您的用例和专业知识水平,您会经常发现自己在修改现有训练方案(例如:添加新功能)或编写新的训练方案。torchtune 通过提供经过充分测试的模块化组件/构建块和通用工具(例如:WandB 日志记录检查点)来简化训练方案的编写。


训练方案设计

torchtune 中的训练方案设计原则如下:

  • 简单。完全使用原生的 PyTorch 编写。

  • 正确。对每个组件进行数值一致性验证,并与参考实现和基准进行广泛比较。

  • 易于理解。每个训练方案提供有限的一组有意义的功能,而不是将所有可能的功能隐藏在数百个标志后面。优先选择代码重复而不是不必要的抽象。

  • 易于扩展。不依赖于训练框架,没有实现继承。用户无需深入层层抽象来弄清楚如何扩展核心功能。

  • 适用于不同水平的用户。用户可以决定如何与 torchtune 训练方案进行交互:
    • 通过修改现有配置文件开始训练模型

    • 修改现有训练方案以应对自定义场景

    • 直接使用现有的构建块编写全新的训练方案/训练范式

每个训练方案包含三个组件:

  • 可配置参数,通过 yaml 配置文件和命令行覆盖指定

  • 训练方案脚本,将所有内容整合在一起的入口点,包括解析和验证配置文件、设置环境以及正确使用训练方案类

  • 训练方案类,训练所需的核心逻辑,通过一组 API 向用户公开

在接下来的部分,我们将更详细地介绍这些组件。有关完整的可运行示例,请参考 torchtune 中的 全量微调训练方案 以及相关的 配置文件


训练方案不是什么?

  • 整体式训练器。 训练方案**不是**旨在通过数百个标志支持所有可能功能的整体式训练器。

  • 通用入口点。 训练方案**不**旨在支持所有可能的模型架构或微调方法。

  • 外部框架的包装器。 训练方案**不**旨在成为外部框架的包装器。它们完全使用 torchtune 构建块用原生 PyTorch 编写。依赖项主要是额外工具或与周围生态系统的互操作性(例如:EleutherAI 的评估工具集)的形式。


训练方案脚本

这是每个训练方案的主要入口点,让用户控制如何设置训练方案、如何训练模型以及如何使用后续的检查点。这包括:

  • 环境设置

  • 解析和验证配置文件

  • 训练模型

  • 使用多个训练方案类设置多阶段训练(例如:知识蒸馏)

脚本通常应按以下顺序组织操作:

  • 初始化训练方案类,进而初始化训练方案状态

  • 加载并验证检查点,以便在恢复训练时更新训练方案状态

  • 从检查点初始化训练方案组件(模型、分词器、优化器、损失函数和数据加载器)(如果适用)

  • 训练模型

  • 训练完成后清理训练方案状态

示例脚本大致如下:

# Initialize the process group
init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")

# Setup the recipe and train the model
recipe = FullFinetuneRecipeDistributed(cfg=cfg)
recipe.setup(cfg=cfg)
recipe.train()
recipe.cleanup()

# Other stuff to do after training is complete
...

训练方案类

训练方案类包含训练模型的核心逻辑。每个类实现相关的接口并暴露一组 API。对于微调,此类的结构如下:

初始化训练方案状态,包括种子、设备、数据类型、指标记录器、相关标志等

def __init__(...):

    self._device = utils.get_device(device=params.device)
    self._dtype = training.get_dtype(dtype=params.dtype, device=self._device)
    ...

加载检查点,从检查点更新训练方案状态,初始化组件并从检查点加载状态字典

def setup(self, cfg: DictConfig):

    ckpt_dict = self.load_checkpoint(cfg.checkpointer)

    # Setup the model, including FSDP wrapping, setting up activation checkpointing and
    # loading the state dict
    self._model = self._setup_model(...)
    self._tokenizer = self._setup_tokenizer(...)

    # Setup Optimizer, including transforming for FSDP when resuming training
    self._optimizer = self._setup_optimizer(...)
    self._loss_fn = self._setup_loss(...)
    self._sampler, self._dataloader = self._setup_data(...)

在所有 epoch 中运行前向和后向传播,并在每个 epoch 结束时保存检查点

def train(...):

    self._optimizer.zero_grad()
    for curr_epoch in range(self.epochs_run, self.total_epochs):

        for idx, batch in enumerate(self._dataloader):
            ...

            with self._autocast:
                logits = self._model(...)
                ...
                loss = self._loss_fn(logits, labels)

            if self.global_step % self._log_every_n_steps == 0:
                self._metric_logger.log_dict(...)

            loss.backward()
            self._optimizer.step()
            self._optimizer.zero_grad()

            # Update the number of steps when the weights are updated
            self.global_step += 1

        self.save_checkpoint(epoch=curr_epoch)

清理训练方案状态

def cleanup(...)

    self.metric_loggers.close()
    ...

使用配置文件运行训练方案

要使用一组用户定义的参数运行训练方案,您需要编写一个配置文件。您可以在我们的配置文件深入探讨中了解所有关于配置文件的信息。

使用 parse 进行配置文件和 CLI 解析

我们提供一个方便的装饰器 parse(),它包装您的训练方案,使其能够使用 tune 命令从命令行运行,并解析配置文件和 CLI 覆盖项。

@config.parse
def recipe_main(cfg: DictConfig) -> None:
    recipe = FullFinetuneRecipe(cfg=cfg)
    recipe.setup(cfg=cfg)
    recipe.train()
    recipe.cleanup()

运行您的训练方案

您应该能够通过提供指向您的自定义训练方案和自定义配置文件的直接路径,并使用 tune 命令以及任何 CLI 覆盖项来运行您的训练方案

tune run <path/to/recipe> --config <path/to/config> k1=v1 k2=v2 ...

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源