快捷方式

torchrl.trainers 包

训练器包提供了编写可重用训练脚本的实用程序。核心思想是使用一个实现嵌套循环的训练器,其中外循环运行数据收集步骤,内循环运行优化步骤。我们认为这适合多种 RL 训练方案,例如在线策略、离线策略、基于模型和无模型的解决方案、离线 RL 等。更具体的案例,例如元 RL 算法,可能具有训练方案存在很大差异。

trainer.train() 方法可以概括如下

训练器循环
        >>> for batch in collector:
        ...     batch = self._process_batch_hook(batch)  # "batch_process"
        ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
        ...     self._pre_optim_hook()  # "pre_optim_steps"
        ...     for j in range(self.optim_steps_per_batch):
        ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
        ...         losses = self.loss_module(sub_batch)
        ...         self._post_loss_hook(sub_batch)  # "post_loss"
        ...         self.optimizer.step()
        ...         self.optimizer.zero_grad()
        ...         self._post_optim_hook()  # "post_optim"
        ...         self._post_optim_log(sub_batch)  # "post_optim_log"
        ...     self._post_steps_hook()  # "post_steps"
        ...     self._post_steps_log_hook(batch)  #  "post_steps_log"

There are 10 hooks that can be used in a trainer loop:

        >>> for batch in collector:
        ...     batch = self._process_batch_hook(batch)  # "batch_process"
        ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
        ...     self._pre_optim_hook()  # "pre_optim_steps"
        ...     for j in range(self.optim_steps_per_batch):
        ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
        ...         losses = self.loss_module(sub_batch)
        ...         self._post_loss_hook(sub_batch)  # "post_loss"
        ...         self.optimizer.step()
        ...         self.optimizer.zero_grad()
        ...         self._post_optim_hook()  # "post_optim"
        ...         self._post_optim_log(sub_batch)  # "post_optim_log"
        ...     self._post_steps_hook()  # "post_steps"
        ...     self._post_steps_log_hook(batch)  #  "post_steps_log"

There are 10 hooks that can be used in a trainer loop:

     >>> for batch in collector:
     ...     batch = self._process_batch_hook(batch)  # "batch_process"
     ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
     ...     self._pre_optim_hook()  # "pre_optim_steps"
     ...     for j in range(self.optim_steps_per_batch):
     ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
     ...         losses = self.loss_module(sub_batch)
     ...         self._post_loss_hook(sub_batch)  # "post_loss"
     ...         self.optimizer.step()
     ...         self.optimizer.zero_grad()
     ...         self._post_optim_hook()  # "post_optim"
     ...         self._post_optim_log(sub_batch)  # "post_optim_log"
     ...     self._post_steps_hook()  # "post_steps"
     ...     self._post_steps_log_hook(batch)  #  "post_steps_log"

在训练器循环中可以使用 10 个钩子: "batch_process""pre_optim_steps""process_optim_batch""post_loss""post_steps""post_optim""pre_steps_log""post_steps_log""post_optim_log""optimizer"。它们在应用的位置用注释表示。钩子可以分为 3 类:**数据处理**( "batch_process""process_optim_batch")、**日志记录**( "pre_steps_log""post_optim_log""post_steps_log")和**操作**钩子( "pre_optim_steps""post_loss""post_optim""post_steps")。

  • **数据处理**钩子更新一个数据张量字典。钩子的 __call__ 方法应接受一个 TensorDict 对象作为输入,并根据某种策略进行更新。此类钩子的示例包括回放缓冲区扩展( ReplayBufferTrainer.extend)、数据归一化(包括归一化常数更新)、数据子采样(:class:~torchrl.trainers.BatchSubSampler)等。

  • **日志记录**钩子获取以 TensorDict 表示的一批数据,并在日志记录器中写入从该数据中检索到的某些信息。示例包括 Recorder 钩子、奖励日志记录器( LogReward)等。钩子应返回一个包含要记录数据的字典(或 None 值)。键 "log_pbar" 保留给指示是否应在训练日志上打印的进度条上显示记录值的布尔值。

  • **操作**钩子是在模型、数据收集器、目标网络更新等上执行特定操作的钩子。例如,使用 UpdateWeights 同步收集器的权重或使用 ReplayBufferTrainer.update_priority 更新回放缓冲区的优先级是操作钩子的示例。它们与数据无关(它们不需要 TensorDict 输入),它们仅在每次迭代(或每 N 次迭代)时执行一次。

TorchRL 提供的钩子通常继承自一个通用的抽象类 TrainerHookBase,并且都实现了三个基本方法:用于检查点的 state_dictload_state_dict 方法以及在训练器中的默认值中注册钩子的 register 方法。此方法将训练器和模块名称作为输入。例如,以下日志记录钩子每调用 10 次 "post_optim_log" 就会执行

>>> class LoggingHook(TrainerHookBase):
...     def __init__(self):
...         self.counter = 0
...
...     def register(self, trainer, name):
...         trainer.register_module(self, "logging_hook")
...         trainer.register_op("post_optim_log", self)
...
...     def save_dict(self):
...         return {"counter": self.counter}
...
...     def load_state_dict(self, state_dict):
...         self.counter = state_dict["counter"]
...
...     def __call__(self, batch):
...         if self.counter % 10 == 0:
...             self.counter += 1
...             out = {"some_value": batch["some_value"].item(), "log_pbar": False}
...         else:
...             out = None
...         self.counter += 1
...         return out

检查点

训练器类和钩子支持检查点,可以通过使用 torchsnapshot 后端或常规的 torch 后端来实现。这可以通过全局变量 CKPT_BACKEND 控制

$ CKPT_BACKEND=torchsnapshot python script.py

CKPT_BACKEND 默认为 torch。torchsnapshot 相对于 pytorch 的优势在于它是一个更灵活的 API,支持分布式检查点,还允许用户将磁盘上存储的文件中的张量加载到具有物理存储的张量中(pytorch 目前不支持)。例如,这允许将张量从回放缓冲区加载到回放缓冲区,否则这些张量将无法容纳在内存中。

构建训练器时,可以提供一个路径,将检查点写入该路径。使用 torchsnapshot 后端时,需要目录路径,而 torch 后端需要文件路径(通常是 .pt 文件)。

>>> filepath = "path/to/dir/or/file"
>>> trainer = Trainer(
...     collector=collector,
...     total_frames=total_frames,
...     frame_skip=frame_skip,
...     loss_module=loss_module,
...     optimizer=optimizer,
...     save_trainer_file=filepath,
... )
>>> select_keys = SelectKeys(["action", "observation"])
>>> select_keys.register(trainer)
>>> # to save to a path
>>> trainer.save_trainer(True)
>>> # to load from a path
>>> trainer.load_from_file(filepath)

Trainer.train() 方法可用于执行上述循环及其所有钩子,尽管仅使用 Trainer 类及其检查点功能也是一个完全有效的用法。

训练器和钩子

BatchSubSampler(batch_size[, sub_traj_len, ...])

用于在线 RL sota 实现的数据子采样器。

ClearCudaCache(interval)

以给定间隔清除 cuda 缓存。

CountFramesLog(*args, **kwargs)

帧计数钩子。

LogReward([logname, log_pbar, reward_key])

奖励日志记录器钩子。

OptimizerHook(optimizer[, loss_components])

为一个或多个损失组件添加优化器。

Recorder(*, record_interval, record_frames)

用于 Trainer 的记录器钩子。

ReplayBufferTrainer(replay_buffer[, ...])

回放缓冲区钩子提供程序。

RewardNormalizer([decay, scale, eps, ...])

奖励归一化钩子。

SelectKeys(keys)

在 TensorDict 批次中选择键。

Trainer(*args, **kwargs)

一个通用的 Trainer 类。

TrainerHookBase()

torchrl Trainer 类的抽象钩子类。

UpdateWeights(collector, update_weights_interval)

一个收集器权重更新钩子类。

构建器

make_collector_offpolicy(make_env, ...[, ...])

返回用于离策略 sota 实现的数据收集器。

make_collector_onpolicy(make_env, ...[, ...])

在策略内设置中创建收集器。

make_dqn_loss(model, cfg)

构建 DQN 损失模块。

make_replay_buffer(device, cfg)

使用从 ReplayArgsConfig 构建的配置构建回放缓冲区。

make_target_updater(cfg, loss_module)

构建目标网络权重更新对象。

make_trainer(collector, loss_module[, ...])

给定其组成部分创建 Trainer 实例。

parallel_env_constructor(cfg, **kwargs)

从使用适当的解析器构造函数构建的 argparse.Namespace 返回并行环境。

sync_async_collector(env_fns, env_kwargs[, ...])

运行异步收集器,每个收集器都运行同步环境。

sync_sync_collector(env_fns, env_kwargs[, ...])

运行同步收集器,每个收集器都运行同步环境。

transformed_env_constructor(cfg[, ...])

从使用适当的解析器构造函数构建的 argparse.Namespace 返回环境创建器。

实用程序

correct_for_frame_skip(cfg)

根据输入的 frame_skip 校正参数,将反映帧数的所有参数除以 frame_skip。

get_stats_random_rollout(cfg[, ...])

使用随机 rollout 从环境中收集统计数据(位置和尺度)。

日志记录器

Logger(exp_name, log_dir)

日志记录器的模板。

csv.CSVLogger(exp_name[, log_dir, ...])

一个最小依赖的 CSV 日志记录器。

mlflow.MLFlowLogger(exp_name, tracking_uri)

mlflow 日志记录器的包装器。

tensorboard.TensorboardLogger(exp_name[, ...])

Tensoarboard 日志记录器的包装器。

wandb.WandbLogger(*args, **kwargs)

wandb 日志记录器的包装器。

get_logger(logger_type, logger_name, ...)

获取提供的 logger_type 的日志记录器实例。

generate_exp_name(model_name, experiment_name)

使用 UUID 和当前日期为描述的实验生成一个 ID(字符串)。

录制实用程序

录制实用程序的详细内容在此

VideoRecorder(logger, tag[, in_keys, skip, ...])

视频录制转换。

TensorDictRecorder(out_file_base[, ...])

TensorDict 记录器。

PixelRenderTransform([out_keys, preproc, ...])

一个在父环境上调用渲染并在 tensordict 中注册像素观察的转换。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源