快捷方式

Trainer

class torchrl.trainers.Trainer(*args, **kwargs)[]

一个通用的 Trainer 类。

一个 trainer 负责收集数据和训练模型。为了使此类尽可能通用,Trainer 不构建任何其特定的操作:所有操作都必须在训练循环中的特定点进行钩入 (hooked)。

要构建一个 Trainer,需要一个可迭代的数据源(一个 collector)、一个损失模块和一个优化器。

参数:
  • collector (Sequence[TensorDictBase]) – 一个可迭代对象,返回形状为 [batch x time steps] 的 TensorDict 形式的数据批次。

  • total_frames (int) – 训练期间收集的总帧数。

  • loss_module (LossModule) – 一个模块,读取 TensorDict 批次(可能从回放缓冲区中采样),并返回一个损失 TensorDict,其中每个键指向不同的损失分量。

  • optimizer (optim.Optimizer) – 一个训练模型参数的优化器。

  • logger (Logger, optional) – 一个将处理日志记录的 Logger。

  • optim_steps_per_batch (int) – 每次数据收集的优化步数。一个 trainer 工作方式如下:一个主循环收集数据批次(epoch loop),一个子循环(training loop)在两次数据收集之间执行模型更新。

  • clip_grad_norm (bool, optional) – 如果为 True,梯度将根据模型参数的总范数进行裁剪。如果为 False,所有偏导数将被限制在 (-clip_norm, clip_norm) 之间。默认为 True

  • clip_norm (Number, optional) – 用于裁剪梯度的值。默认为 None(不进行范数裁剪)。

  • progress_bar (bool, optional) – 如果为 True,将使用 tqdm 显示进度条。如果未安装 tqdm,此选项将不起作用。默认为 True

  • seed (int, optional) – 用于 collector、pytorch 和 numpy 的种子。默认为 None

  • save_trainer_interval (int, optional) – trainer 应多久保存到磁盘一次,以帧数计。默认为 10000。

  • log_interval (int, optional) – 应多久记录一次值,以帧数计。默认为 10000。

  • save_trainer_file (path, optional) – 保存 trainer 的路径。默认为 None(不保存)。

load_from_file(file: Union[str, Path], **kwargs) Trainer[]

加载文件及其 state-dict 到 trainer 中。

关键字参数传递给 load() 函数。


© 版权所有 2022,Meta。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源