快捷方式

torchrl.trainers 包

trainer 包提供了编写可重用训练脚本的实用程序。核心思想是使用一个训练器,它实现一个嵌套循环,其中外循环运行数据收集步骤,内循环运行优化步骤。我们相信这适用于多种 RL 训练方案,例如在线策略、离线策略、基于模型和无模型的解决方案、离线 RL 等。更特殊的情况,例如元 RL 算法,可能具有实质上不同的训练方案。

trainer.train() 方法可以概括如下

训练器循环
        >>> for batch in collector:
        ...     batch = self._process_batch_hook(batch)  # "batch_process"
        ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
        ...     self._pre_optim_hook()  # "pre_optim_steps"
        ...     for j in range(self.optim_steps_per_batch):
        ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
        ...         losses = self.loss_module(sub_batch)
        ...         self._post_loss_hook(sub_batch)  # "post_loss"
        ...         self.optimizer.step()
        ...         self.optimizer.zero_grad()
        ...         self._post_optim_hook()  # "post_optim"
        ...         self._post_optim_log(sub_batch)  # "post_optim_log"
        ...     self._post_steps_hook()  # "post_steps"
        ...     self._post_steps_log_hook(batch)  #  "post_steps_log"

There are 10 hooks that can be used in a trainer loop:

        >>> for batch in collector:
        ...     batch = self._process_batch_hook(batch)  # "batch_process"
        ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
        ...     self._pre_optim_hook()  # "pre_optim_steps"
        ...     for j in range(self.optim_steps_per_batch):
        ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
        ...         losses = self.loss_module(sub_batch)
        ...         self._post_loss_hook(sub_batch)  # "post_loss"
        ...         self.optimizer.step()
        ...         self.optimizer.zero_grad()
        ...         self._post_optim_hook()  # "post_optim"
        ...         self._post_optim_log(sub_batch)  # "post_optim_log"
        ...     self._post_steps_hook()  # "post_steps"
        ...     self._post_steps_log_hook(batch)  #  "post_steps_log"

There are 10 hooks that can be used in a trainer loop:

     >>> for batch in collector:
     ...     batch = self._process_batch_hook(batch)  # "batch_process"
     ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
     ...     self._pre_optim_hook()  # "pre_optim_steps"
     ...     for j in range(self.optim_steps_per_batch):
     ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
     ...         losses = self.loss_module(sub_batch)
     ...         self._post_loss_hook(sub_batch)  # "post_loss"
     ...         self.optimizer.step()
     ...         self.optimizer.zero_grad()
     ...         self._post_optim_hook()  # "post_optim"
     ...         self._post_optim_log(sub_batch)  # "post_optim_log"
     ...     self._post_steps_hook()  # "post_steps"
     ...     self._post_steps_log_hook(batch)  #  "post_steps_log"

训练器循环中有 10 个钩子可以使用:"batch_process""pre_optim_steps""process_optim_batch""post_loss""post_steps""post_optim""pre_steps_log""post_steps_log""post_optim_log""optimizer"。它们在应用的地方在注释中指出。钩子可以分为 3 类:数据处理 ("batch_process""process_optim_batch")、日志记录 ("pre_steps_log""post_optim_log""post_steps_log") 和 操作 钩子 ("pre_optim_steps""post_loss""post_optim""post_steps")。

  • 数据处理 钩子更新数据的 tensordict。钩子 __call__ 方法应接受一个 TensorDict 对象作为输入,并根据某些策略对其进行更新。此类钩子的示例包括回放缓冲区扩展 (ReplayBufferTrainer.extend)、数据标准化(包括标准化常数更新)、数据子采样 (:~torchrl.trainers.BatchSubSampler) 等。

  • 日志记录 钩子获取一批数据,以 TensorDict 的形式呈现,并在记录器中写入从该数据检索到的一些信息。示例包括 Recorder 钩子、奖励记录器 (LogReward) 等。钩子应返回一个字典(或 None 值),其中包含要记录的数据。键 "log_pbar" 保留给布尔值,指示记录的值是否应显示在训练日志上打印的进度条上。

  • 操作 钩子是在模型、数据收集器、目标网络更新等上执行特定操作的钩子。例如,使用 UpdateWeights 同步收集器的权重或使用 ReplayBufferTrainer.update_priority 更新回放缓冲区的优先级是操作钩子的示例。它们与数据无关(它们不需要 TensorDict 输入),它们应该在每次迭代(或每 N 次迭代)执行一次。

TorchRL 提供的钩子通常继承自通用的抽象类 TrainerHookBase,并且都实现三个基本方法:用于检查点的 state_dictload_state_dict 方法,以及在训练器中以默认值注册钩子的 register 方法。此方法将训练器和模块名称作为输入。例如,以下日志记录钩子每 10 次调用 "post_optim_log" 执行一次

>>> class LoggingHook(TrainerHookBase):
...     def __init__(self):
...         self.counter = 0
...
...     def register(self, trainer, name):
...         trainer.register_module(self, "logging_hook")
...         trainer.register_op("post_optim_log", self)
...
...     def save_dict(self):
...         return {"counter": self.counter}
...
...     def load_state_dict(self, state_dict):
...         self.counter = state_dict["counter"]
...
...     def __call__(self, batch):
...         if self.counter % 10 == 0:
...             self.counter += 1
...             out = {"some_value": batch["some_value"].item(), "log_pbar": False}
...         else:
...             out = None
...         self.counter += 1
...         return out

检查点

训练器类和钩子支持检查点,可以使用 torchsnapshot 后端或常规 torch 后端来实现。这可以通过全局变量 CKPT_BACKEND 来控制

$ CKPT_BACKEND=torchsnapshot python script.py

CKPT_BACKEND 默认为 torch。torchsnapshot 相对于 pytorch 的优势在于它是一个更灵活的 API,它支持分布式检查点,并且还允许用户将磁盘上存储的文件中的张量加载到具有物理存储的张量(pytorch 目前不支持)。例如,这允许从和向可能无法容纳在内存中的回放缓冲区加载张量。

构建训练器时,可以提供要写入检查点的路径。使用 torchsnapshot 后端,需要目录路径,而 torch 后端需要文件路径(通常是 .pt 文件)。

>>> filepath = "path/to/dir/or/file"
>>> trainer = Trainer(
...     collector=collector,
...     total_frames=total_frames,
...     frame_skip=frame_skip,
...     loss_module=loss_module,
...     optimizer=optimizer,
...     save_trainer_file=filepath,
... )
>>> select_keys = SelectKeys(["action", "observation"])
>>> select_keys.register(trainer)
>>> # to save to a path
>>> trainer.save_trainer(True)
>>> # to load from a path
>>> trainer.load_from_file(filepath)

Trainer.train() 方法可用于执行上述循环及其所有钩子,尽管仅使用 Trainer 类来实现其检查点功能也是完全有效的用法。

训练器和钩子

BatchSubSampler(batch_size[, sub_traj_len, ...])

用于在线 RL sota 实现的数据子采样器。

ClearCudaCache(interval)

在给定间隔清除 cuda 缓存。

CountFramesLog(*args, **kwargs)

帧计数器钩子。

LogReward([logname, log_pbar, reward_key])

奖励记录器钩子。

OptimizerHook(optimizer[, loss_components])

为一个或多个损失分量添加优化器。

Recorder(*, record_interval, record_frames)

Trainer 的记录器钩子。

ReplayBufferTrainer(replay_buffer[, ...])

回放缓冲区钩子提供程序。

RewardNormalizer([decay, scale, eps, ...])

奖励归一化器钩子。

SelectKeys(keys)

在 TensorDict 批次中选择键。

Trainer(*args, **kwargs)

通用 Trainer 类。

TrainerHookBase()

torchrl Trainer 类的抽象钩子类。

UpdateWeights(collector, update_weights_interval)

收集器权重更新钩子类。

构建器

make_collector_offpolicy(make_env, ...[, ...])

返回用于离线策略 sota 实现的数据收集器。

make_collector_onpolicy(make_env, ...[, ...])

在在线策略设置中创建一个收集器。

make_dqn_loss(model, cfg)

构建 DQN 损失模块。

make_replay_buffer(device, cfg)

使用从 ReplayArgsConfig 构建的配置构建回放缓冲区。

make_target_updater(cfg, loss_module)

构建目标网络权重更新对象。

make_trainer(collector, loss_module[, ...])

根据其组成部分创建 Trainer 实例。

parallel_env_constructor(cfg, **kwargs)

从使用适当的解析器构造函数构建的 argparse.Namespace 返回并行环境。

sync_async_collector(env_fns, env_kwargs[, ...])

运行异步收集器,每个收集器运行同步环境。

sync_sync_collector(env_fns, env_kwargs[, ...])

运行同步收集器,每个收集器运行同步环境。

transformed_env_constructor(cfg[, ...])

从使用适当的解析器构造函数构建的 argparse.Namespace 返回环境创建器。

实用工具

correct_for_frame_skip(cfg)

通过将所有反映帧计数的参数除以 frame_skip,来校正输入 frame_skip 的参数。

get_stats_random_rollout(cfg[, ...])

使用随机 rollout 从环境中收集统计信息(位置和比例)。

记录器

Logger(exp_name, log_dir)

记录器的模板。

csv.CSVLogger(exp_name[, log_dir, ...])

最小依赖项 CSV 记录器。

mlflow.MLFlowLogger(exp_name, tracking_uri)

mlflow 记录器的包装器。

tensorboard.TensorboardLogger(exp_name[, ...])

Tensoarboard 记录器的包装器。

wandb.WandbLogger(*args, **kwargs)

wandb 记录器的包装器。

get_logger(logger_type, logger_name, ...)

获取提供的 logger_type 的记录器实例。

generate_exp_name(model_name, experiment_name)

使用 UUID 和当前日期为描述的实验生成 ID (str)。

记录实用工具

记录实用工具在此处详细介绍:此处

VideoRecorder(logger, tag[, in_keys, skip, ...])

视频记录器转换。

TensorDictRecorder(out_file_base[, ...])

TensorDict 记录器。

PixelRenderTransform([out_keys, preproc, ...])

一种在父环境中调用 render 并在 tensordict 中注册像素观测的转换。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源