快捷方式

torchrl.trainers 包

trainer 包提供了编写可重用训练脚本的工具。核心思想是使用一个实现嵌套循环的训练器,其中外层循环运行数据收集步骤,内层循环运行优化步骤。我们认为这适用于多种强化学习训练方案,例如 on-policy、off-policy、基于模型和无模型的解决方案、离线 RL 等。更特殊的案例,例如 meta-RL 算法,其训练方案可能存在显著差异。

trainer.train() 方法的示意图如下

训练器循环
        >>> for batch in collector:
        ...     batch = self._process_batch_hook(batch)  # "batch_process"
        ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
        ...     self._pre_optim_hook()  # "pre_optim_steps"
        ...     for j in range(self.optim_steps_per_batch):
        ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
        ...         losses = self.loss_module(sub_batch)
        ...         self._post_loss_hook(sub_batch)  # "post_loss"
        ...         self.optimizer.step()
        ...         self.optimizer.zero_grad()
        ...         self._post_optim_hook()  # "post_optim"
        ...         self._post_optim_log(sub_batch)  # "post_optim_log"
        ...     self._post_steps_hook()  # "post_steps"
        ...     self._post_steps_log_hook(batch)  #  "post_steps_log"

There are 10 hooks that can be used in a trainer loop:

        >>> for batch in collector:
        ...     batch = self._process_batch_hook(batch)  # "batch_process"
        ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
        ...     self._pre_optim_hook()  # "pre_optim_steps"
        ...     for j in range(self.optim_steps_per_batch):
        ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
        ...         losses = self.loss_module(sub_batch)
        ...         self._post_loss_hook(sub_batch)  # "post_loss"
        ...         self.optimizer.step()
        ...         self.optimizer.zero_grad()
        ...         self._post_optim_hook()  # "post_optim"
        ...         self._post_optim_log(sub_batch)  # "post_optim_log"
        ...     self._post_steps_hook()  # "post_steps"
        ...     self._post_steps_log_hook(batch)  #  "post_steps_log"

There are 10 hooks that can be used in a trainer loop:

     >>> for batch in collector:
     ...     batch = self._process_batch_hook(batch)  # "batch_process"
     ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
     ...     self._pre_optim_hook()  # "pre_optim_steps"
     ...     for j in range(self.optim_steps_per_batch):
     ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
     ...         losses = self.loss_module(sub_batch)
     ...         self._post_loss_hook(sub_batch)  # "post_loss"
     ...         self.optimizer.step()
     ...         self.optimizer.zero_grad()
     ...         self._post_optim_hook()  # "post_optim"
     ...         self._post_optim_log(sub_batch)  # "post_optim_log"
     ...     self._post_steps_hook()  # "post_steps"
     ...     self._post_steps_log_hook(batch)  #  "post_steps_log"

训练器循环中可以使用 10 个钩子:"batch_process", "pre_optim_steps", "process_optim_batch", "post_loss", "post_steps", "post_optim", "pre_steps_log", "post_steps_log", "post_optim_log""optimizer"。它们在应用的地方已在注释中注明。钩子可分为 3 类:数据处理 ("batch_process""process_optim_batch"),日志记录 ("pre_steps_log", "post_optim_log""post_steps_log") 和 操作 钩子 ("pre_optim_steps", "post_loss", "post_optim""post_steps")。

  • 数据处理 钩子用于更新数据的 TensorDict。钩子的 __call__ 方法应接受一个 TensorDict 对象作为输入,并根据某种策略对其进行更新。这类钩子的示例包括回放缓冲区扩展 (ReplayBufferTrainer.extend)、数据归一化(包括归一化常数更新)、数据子采样(:class:~torchrl.trainers.BatchSubSampler)等。

  • 日志记录 钩子接收表示为 TensorDict 的数据批次,并在记录器中写入从该数据中检索到的信息。示例包括 LogValidationReward 钩子、奖励记录器 (LogScalar) 等。钩子应返回一个包含要记录的数据的字典(或 None 值)。键 "log_pbar" 保留用于布尔值,指示记录的值是否应显示在训练日志中打印的进度条上。

  • 操作 钩子是在模型、数据收集器、目标网络更新等方面执行特定操作的钩子。例如,使用 UpdateWeights 同步收集器的权重,或使用 ReplayBufferTrainer.update_priority 更新回放缓冲区的优先级,都是操作钩子的示例。它们是数据独立的(不需要 TensorDict 输入),只需在每次迭代(或每 N 次迭代)执行一次即可。

TorchRL 提供的钩子通常继承自一个共同的抽象基类 TrainerHookBase,并且都实现了三个基本方法:用于检查点保存的 state_dictload_state_dict 方法,以及在训练器中以默认值注册钩子的 register 方法。此方法接收训练器和模块名称作为输入。例如,以下日志记录钩子在每次调用 "post_optim_log" 10 次后执行

>>> class LoggingHook(TrainerHookBase):
...     def __init__(self):
...         self.counter = 0
...
...     def register(self, trainer, name):
...         trainer.register_module(self, "logging_hook")
...         trainer.register_op("post_optim_log", self)
...
...     def save_dict(self):
...         return {"counter": self.counter}
...
...     def load_state_dict(self, state_dict):
...         self.counter = state_dict["counter"]
...
...     def __call__(self, batch):
...         if self.counter % 10 == 0:
...             self.counter += 1
...             out = {"some_value": batch["some_value"].item(), "log_pbar": False}
...         else:
...             out = None
...         self.counter += 1
...         return out

检查点保存

训练器类和钩子支持检查点保存,可以通过使用 torchsnapshot 后端或常规的 torch 后端来实现。这可以通过全局变量 CKPT_BACKEND 控制

$ CKPT_BACKEND=torchsnapshot python script.py

CKPT_BACKEND 默认为 torch。torchsnapshot 相对于 pytorch 的优势在于它是一个更灵活的 API,支持分布式检查点保存,并且允许用户将存储在磁盘文件中的张量加载到具有物理存储的张量中(pytorch 目前不支持此功能)。例如,这使得可以将张量从和加载到原本无法容纳在内存中的回放缓冲区中。

构建训练器时,可以提供检查点保存路径。对于 torchsnapshot 后端,期望的是目录路径,而 torch 后端期望的是文件路径(通常是 .pt 文件)。

>>> filepath = "path/to/dir/or/file"
>>> trainer = Trainer(
...     collector=collector,
...     total_frames=total_frames,
...     frame_skip=frame_skip,
...     loss_module=loss_module,
...     optimizer=optimizer,
...     save_trainer_file=filepath,
... )
>>> select_keys = SelectKeys(["action", "observation"])
>>> select_keys.register(trainer)
>>> # to save to a path
>>> trainer.save_trainer(True)
>>> # to load from a path
>>> trainer.load_from_file(filepath)

Trainer.train() 方法可用于执行上述包含所有钩子的循环,尽管仅将 Trainer 类用于其检查点保存功能也是完全有效的用法。

训练器和钩子

BatchSubSampler(batch_size[, sub_traj_len, ...])

用于在线 RL 最新实现的数据子采样器。

ClearCudaCache(interval)

按给定间隔清除 cuda 缓存。

CountFramesLog(*args, **kwargs)

一个帧计数钩子。

LogScalar([logname, log_pbar, reward_key])

奖励记录器钩子。

OptimizerHook(optimizer[, loss_components])

为一个或多个损失组件添加优化器。

LogValidationReward(*, record_interval, ...)

Trainer 的记录器钩子。

ReplayBufferTrainer(replay_buffer[, ...])

回放缓冲区钩子提供者。

RewardNormalizer([decay, scale, eps, ...])

奖励归一化器钩子。

SelectKeys(keys)

选择 TensorDict 批次中的键。

Trainer(*args, **kwargs)

一个通用 Trainer 类。

TrainerHookBase()

用于 torchrl Trainer 类的抽象钩子类。

UpdateWeights(collector, update_weights_interval)

一个收集器权重更新钩子类。

构建器

make_collector_offpolicy(make_env, ...[, ...])

返回用于 off-policy 最新实现的数据收集器。

make_collector_onpolicy(make_env, ...[, ...])

在 on-policy 设置中创建一个收集器。

make_dqn_loss(model, cfg)

构建 DQN 损失模块。

make_replay_buffer(device, cfg)

使用从 ReplayArgsConfig 构建的配置构建回放缓冲区。

make_target_updater(cfg, loss_module)

构建一个目标网络权重更新对象。

make_trainer(collector, loss_module[, ...])

根据其组成部分创建一个 Trainer 实例。

parallel_env_constructor(cfg, **kwargs)

从使用适当的解析器构建器构建的 argparse.Namespace 返回一个并行环境。

sync_async_collector(env_fns, env_kwargs[, ...])

运行异步收集器,每个收集器运行同步环境。

sync_sync_collector(env_fns, env_kwargs[, ...])

运行同步收集器,每个收集器运行同步环境。

transformed_env_constructor(cfg[, ...])

从使用适当的解析器构建器构建的 argparse.Namespace 返回一个环境创建器。

工具函数

correct_for_frame_skip(cfg)

根据输入的 frame_skip 调整参数,将所有反映帧计数目的的参数除以 frame_skip。

get_stats_random_rollout(cfg[, ...])

使用随机 rollout 从环境中收集统计数据(loc 和 scale)。

记录器

Logger(exp_name, log_dir)

记录器的模板。

csv.CSVLogger(exp_name[, log_dir, ...])

一个最小依赖的 CSV 记录器。

mlflow.MLFlowLogger(exp_name, tracking_uri)

mlflow 记录器的包装器。

tensorboard.TensorboardLogger(exp_name[, ...])

Tensorboard 记录器的包装器。

wandb.WandbLogger(*args, **kwargs)

wandb 记录器的包装器。

get_logger(logger_type, logger_name, ...)

获取指定 logger_type 的记录器实例。

generate_exp_name(model_name, experiment_name)

使用 UUID 和当前日期为描述的实验生成一个 ID (str)。

录制工具函数

录制工具函数在此处详细介绍。

VideoRecorder(logger, tag[, in_keys, skip, ...])

视频录制器 transform。

TensorDictRecorder(out_file_base[, ...])

TensorDict 记录器。

PixelRenderTransform([out_keys, preproc, ...])

一个 transform,用于在父环境中调用 render,并在 tensordict 中注册像素观察结果。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获取问题解答

查看资源