快捷方式

torchtune.training

检查点

torchtune 提供检查点工具,以实现训练过程中的检查点格式无缝转换以及与生态系统中其他部分的互操作性。有关检查点的全面概述,请参阅检查点深入探讨

FullModelHFCheckpointer

读取和写入 HF 格式检查点的检查点工具。

FullModelMetaCheckpointer

读取和写入 Meta 格式检查点的检查点工具。

FullModelTorchTuneCheckpointer

读取和写入与 torchtune 兼容格式检查点的检查点工具。

ModelType

ModelType 用于检查点工具,以区分不同的模型架构。

FormattedCheckpointFiles

此类提供了一种更简洁的方式来表示格式为 file_{i}_of_{n_files}.pth 的文件名列表。

update_state_dict_for_classifier

验证分类器模型的检查点加载状态字典。

降低精度

在降低精度设置下工作的实用工具。

get_dtype

获取与给定精度字符串对应的 torch.dtype。

set_default_dtype

用于设置 torch 默认 dtype 的上下文管理器。

validate_expected_param_dtype

验证所有输入参数都具有预期的 dtype。

get_quantizer_mode

给定量化器对象,返回一个指定量化类型的字符串。

分布式

启用和使用分布式训练的实用工具。

init_distributed

初始化 torch.distributed 所需的进程组。

is_distributed

检查初始化 torch.distributed 所需的所有环境变量是否已设置以及分布式环境是否已正确安装。

gather_cpu_state_dict

将分片的状态字典转换为 CPU 上的完整状态字典。仅在 rank0 上返回非空结果,以避免 CPU 内存峰值。目前我们可以使用分布式状态字典 API 来处理不包含 NF4Tensor 的模型。

get_distributed_backend

根据设备类型获取 PyTorch 分布式后端。

内存管理

减少训练期间内存消耗的实用工具。

apply_selective_activation_checkpointing

设置激活检查点并包装模型以进行检查点的实用工具。

set_activation_checkpointing

将激活检查点应用于传入模型的实用工具。

OptimizerInBackwardWrapper

一个用于在反向传播中运行的优化器进行检查点保存和加载的基础类。

create_optim_in_bwd_wrapper

创建一个用于在反向传播中运行的优化器步骤的包装器。

register_optim_in_bwd_hooks

注册用于在反向传播中运行的优化器步骤的钩子。

调度器

控制训练过程中学习率的实用工具。

get_cosine_schedule_with_warmup

创建一个学习率调度器,它在 num_warmup_steps 期间将学习率从 0.0 线性增加到 lr,然后在剩余的 num_training_steps-num_warmup_steps 期间按照余弦调度器降低到 0.0(假设 num_cycles = 0.5)。

get_lr

Full_finetune_distributed 和 full_finetune_single_device 假定所有优化器具有相同的学习率 (LR),此处用于验证所有学习率是否相同,如果相同则返回。

指标日志记录

各种日志记录实用工具。

metric_logging.CometLogger

用于 Comet (https://www.comet.com/site/) 的日志记录器。

metric_logging.WandBLogger

用于 Weights and Biases 应用 (https://wandb.ai/) 的日志记录器。

metric_logging.TensorBoardLogger

用于 PyTorch 实现的 TensorBoard (https://pytorch.ac.cn/docs/stable/tensorboard.html) 的日志记录器。

metric_logging.StdoutLogger

记录到标准输出的日志记录器。

metric_logging.DiskLogger

记录到磁盘的日志记录器。

metric_logging.MLFlowLogger

用于 MLFlow (https://mlflow.org/) 的日志记录器。

性能与性能分析

torchtune 提供实用工具来分析和调试微调作业的内存和性能。

get_memory_stats

计算传入设备的内存摘要。

log_memory_stats

将包含内存统计信息的字典记录到日志记录器。

setup_torch_profiler

设置 profile 并返回带设置后更新的性能分析器配置。

其他

get_unmasked_sequence_lengths

返回每个批处理元素的序列长度(0-索引),不包括掩码 token。

disable_dropout

在给定模型中禁用 dropout 层。

set_seed

一个函数,用于为常用库设置伪随机数生成器的种子。

文档

访问 PyTorch 全面的开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得疑问解答

查看资源