快捷方式

训练器

class torchrl.trainers.Trainer(*args, **kwargs)[source]

通用训练器类。

训练器负责收集数据和训练模型。为了使该类尽可能通用,训练器不会构建任何特定操作:所有操作都必须在训练循环中的特定点挂钩。

要构建训练器,需要一个可迭代数据源(一个 collector)、一个损失模块和一个优化器。

参数:
  • collector (Sequence[TensorDictBase]) – 一个可迭代对象,以形状为 [批次 x 时间步长] 的 TensorDict 形式返回数据批次。

  • total_frames (int) – 训练期间要收集的总帧数。

  • loss_module (LossModule) – 一个模块,读取 TensorDict 批次(可能从回放缓冲区采样),并返回一个损失 TensorDict,其中每个键指向不同的损失组件。

  • optimizer (optim.Optimizer) – 一个优化器,用于训练模型的参数。

  • logger (Logger, optional) – 处理日志记录的记录器。

  • optim_steps_per_batch (int) – 每收集一次数据的优化步骤数。训练器的工作原理如下:主循环收集数据批次(时期循环),子循环(训练循环)在两次数据收集之间执行模型更新。

  • clip_grad_norm (bool, optional) – 如果为 True,则根据模型参数的总范数剪裁梯度。如果为 False,则所有偏导数将被钳制到 (-clip_norm, clip_norm)。默认值为 True

  • clip_norm (Number, optional) – 用于剪裁梯度的值。默认值为 None(不剪裁范数)。

  • progress_bar (bool, optional) – 如果为 True,则使用 tqdm 显示进度条。如果未安装 tqdm,此选项将不起作用。默认值为 True

  • seed (int, optional) – 用于收集器、pytorch 和 numpy 的种子。默认值为 None

  • save_trainer_interval (int, optional) – 训练器应保存到磁盘的频率,以帧数为单位。默认值为 10000。

  • log_interval (int, optional) – 值应记录的频率,以帧数为单位。默认值为 10000。

  • save_trainer_file (path, optional) – 保存训练器的路径。默认值为 None(不保存)。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适合初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源