快捷方式

torchtune.training

检查点

torchtune 提供检查点程序,允许在训练的检查点格式之间无缝转换,并与生态系统的其余部分互操作。有关检查点的全面概述,请参阅检查点深入探讨

FullModelHFCheckpointer

检查点程序,用于读取和写入 HF 格式的检查点。

FullModelMetaCheckpointer

检查点程序,用于读取和写入 Meta 格式的检查点。

FullModelTorchTuneCheckpointer

检查点程序,用于读取和写入与 torchtune 兼容的格式的检查点。

ModelType

ModelType 被检查点程序用来区分不同的模型架构。

FormattedCheckpointFiles

此类提供了一种更简洁的方式来表示格式为 file_{i}_of_{n_files}.pth 的文件名列表。

update_state_dict_for_classifier

验证分类器模型检查点加载的状态字典。

降低精度

用于在降低精度设置下工作的实用程序。

get_dtype

获取与给定精度字符串对应的 torch.dtype。

set_default_dtype

上下文管理器,用于设置 torch 的默认 dtype。

validate_expected_param_dtype

验证所有输入参数是否具有预期的 dtype。

get_quantizer_mode

给定一个量化器对象,返回一个指定量化类型的字符串。

分布式

用于启用和使用分布式训练的实用程序。

init_distributed

初始化 torch.distributed 所需的进程组。

is_distributed

检查是否设置了初始化 torch.distributed 所需的所有环境变量,以及是否正确安装了 distributed。

get_world_size_and_rank

函数,用于获取默认进程组中当前进程的世界大小(又名总 rank 数)和 rank 号。

gather_cpu_state_dict

将分片状态字典转换为 CPU 上的完整状态字典。仅在 rank0 上返回非空结果以避免 CPU 内存峰值

内存管理

用于减少训练期间内存消耗的实用程序。

apply_selective_activation_checkpointing

用于设置激活检查点并包装模型以进行检查点的实用程序。

set_activation_checkpointing

用于将激活检查点应用于传入模型的实用程序。

OptimizerInBackwardWrapper

一个简易类,旨在为向后运行的优化器保存和加载检查点。

create_optim_in_bwd_wrapper

为向后运行的优化器步骤创建包装器。

register_optim_in_bwd_hooks

为向后运行的优化器步骤注册钩子。

调度器

用于在训练过程中控制 lr 的实用程序。

get_cosine_schedule_with_warmup

创建一个学习率调度器,该调度器在 num_warmup_steps 步内将学习率从 0.0 线性增加到 lr,然后在剩余的 num_training_steps-num_warmup_steps 步内以余弦调度器的方式降低到 0.0(假设 num_cycles = 0.5)。

get_lr

Full_finetune_distributed 和 full_finetune_single_device 假设所有优化器都具有相同的 LR,这里用于验证所有 LR 是否相同,并在为 True 时返回。

指标日志记录

各种日志记录实用程序。

metric_logging.CometLogger

用于 Comet 的 Logger (https://www.comet.com/site/)。

metric_logging.WandBLogger

用于 Weights and Biases 应用程序的 Logger (https://wandb.ai/)。

metric_logging.TensorBoardLogger

用于 PyTorch 的 TensorBoard 实现的 Logger (https://pytorch.ac.cn/docs/stable/tensorboard.html)。

metric_logging.StdoutLogger

输出到标准输出的 Logger。

metric_logging.DiskLogger

记录到磁盘的 Logger。

性能和分析

torchtune 提供了实用程序来分析和调试微调作业的内存和性能。

get_memory_stats

计算传入设备的内存摘要。

log_memory_stats

将包含内存统计信息的字典记录到 logger。

setup_torch_profiler

设置 profile 并返回带有设置后更新的分析器配置。

杂项

get_unmasked_sequence_lengths

返回每个批次元素的序列长度,排除掩码标记。

set_seed

函数,用于为常用库中的伪随机数生成器设置种子。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源