• 教程 >
  • 循环 DQN:训练循环策略
快捷方式

循环 DQN:训练循环策略

创建时间:2023 年 11 月 8 日 | 最后更新:2025 年 1 月 27 日 | 最后验证:未验证

作者Vincent Moens

你将学到什么
  • 如何在 TorchRL 中将 RNN 集成到 actor 中

  • 如何将基于内存的策略与回放缓冲区和损失模块一起使用

先决条件
  • PyTorch v2.0.0

  • gym[mujoco]

  • tqdm

概述

基于内存的策略不仅在观测部分可见时至关重要,而且在必须考虑时间维度以做出明智决策时也至关重要。

循环神经网络长期以来一直是基于内存策略的热门工具。其思想是在两个连续步骤之间将循环状态保存在内存中,并将其与当前观测一起用作策略的输入。

本教程展示了如何在 TorchRL 中使用 RNN 来实现策略。

关键学习要点

  • 在 TorchRL 中将 RNN 集成到 actor 中;

  • 将基于内存的策略与回放缓冲区和损失模块一起使用。

在 TorchRL 中使用 RNN 的核心思想是使用 TensorDict 作为数据载体,用于从一个步骤到下一个步骤的隐藏状态。我们将构建一个策略,该策略从当前 TensorDict 中读取先前的循环状态,并将当前的循环状态写入下一个状态的 TensorDict 中

Data collection with a recurrent policy

如图所示,我们的环境会用归零的循环状态填充 TensorDict,这些状态由策略与观测一起读取以产生动作,以及将用于下一步的循环状态。调用 step_mdp() 函数时,来自下一个状态的循环状态会被带到当前的 TensorDict。让我们看看如何在实践中实现这一点。

如果你在 Google Colab 中运行此程序,请确保安装以下依赖项

!pip3 install torchrl
!pip3 install gym[mujoco]
!pip3 install tqdm

设置

import torch
import tqdm
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.envs import (
    Compose,
    ExplorationType,
    GrayScale,
    InitTracker,
    ObservationNorm,
    Resize,
    RewardScaling,
    set_exploration_type,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import ConvNet, EGreedyModule, LSTMModule, MLP, QValueModule
from torchrl.objectives import DQNLoss, SoftUpdate

is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)

环境

像往常一样,第一步是构建我们的环境:它帮助我们定义问题并相应地构建策略网络。在本教程中,我们将运行 CartPole gym 环境的单个基于像素的实例,并带有一些自定义 Transforms:转换为灰度图、调整大小到 84x84、缩小奖励并标准化观测值。

注意

StepCounter transform 是辅助性的。由于 CartPole 任务的目标是使轨迹尽可能长,计算步数可以帮助我们跟踪策略的性能。

对于本教程的目的,有两个重要的 Transforms

  • InitTracker 将通过在 TensorDict 中添加一个 "is_init" 布尔掩码来标记对 reset() 的调用,该掩码将跟踪哪些步骤需要重置 RNN 隐藏状态。

  • TensorDictPrimer transform 技术性稍强。它不是使用 RNN 策略所必需的。然而,它指示环境(以及随后的 collector)预期一些额外的键。添加后,调用 env.reset() 将使用归零张量填充 Primer 中指示的条目。知道这些张量是策略所期望的,collector 在收集过程中会传递它们。最终,我们会将隐藏状态存储在回放缓冲区中,这将有助于我们在损失模块中引导 RNN 操作的计算(否则将以 0 初始化)。总之:不包含此 transform 不会对策略的训练产生巨大影响,但会导致循环键从收集到的数据和回放缓冲区中消失,这反过来会导致训练稍微欠佳。幸运的是,我们提出的 LSTMModule 配备了一个帮助方法来构建正是这个 transform,所以我们可以等到构建它之后再来!

env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True, device=device),
    Compose(
        ToTensorImage(),
        GrayScale(),
        Resize(84, 84),
        StepCounter(),
        InitTracker(),
        RewardScaling(loc=0.0, scale=0.1),
        ObservationNorm(standard_normal=True, in_keys=["pixels"]),
    ),
)

一如既往,我们需要手动初始化我们的标准化常数

env.transform[-1].init_stats(1000, reduce_dim=[0, 1, 2], cat_dim=0, keep_dims=[0])
td = env.reset()

策略

我们的策略将有 3 个组件:一个 ConvNet 主干网络、一个 LSTMModule 内存层和一个浅层的 MLP 块,它将 LSTM 输出映射到动作值。

卷积网络

我们构建一个卷积网络,两侧搭配一个 torch.nn.AdaptiveAvgPool2d,它将输出压缩成一个大小为 64 的向量。ConvNet 可以协助我们完成这项工作

feature = Mod(
    ConvNet(
        num_cells=[32, 32, 64],
        squeeze_output=True,
        aggregator_class=nn.AdaptiveAvgPool2d,
        aggregator_kwargs={"output_size": (1, 1)},
        device=device,
    ),
    in_keys=["pixels"],
    out_keys=["embed"],
)

我们在批处理数据上执行第一个模块,以收集输出向量的大小

n_cells = feature(env.reset())["embed"].shape[-1]

LSTM Module

TorchRL 提供了一个专门的 LSTMModule 类,用于将 LSTM 集成到你的代码库中。它是 TensorDictModuleBase 的子类:因此,它有一组 in_keysout_keys,指示在模块执行期间应期望读取和写入/更新哪些值。该类带有可自定义的预定义值,以方便构建。

注意

使用限制:该类支持几乎所有 LSTM 功能,例如 dropout 或多层 LSTM。但是,为了遵守 TorchRL 的约定,此 LSTM 的 batch_first 属性必须设置为 True,这在 PyTorch 中不是默认值。但是,我们的 LSTMModule 更改了此默认行为,因此使用原生调用即可。

此外,LSTM 不能将 bidirectional 属性设置为 True,因为这在线设置中不可用。在这种情况下,默认值是正确的值。

lstm = LSTMModule(
    input_size=n_cells,
    hidden_size=128,
    device=device,
    in_key="embed",
    out_key="embed",
)

让我们看看 LSTM Module 类,特别是它的 in 和 out_keys

print("in_keys", lstm.in_keys)
print("out_keys", lstm.out_keys)

我们可以看到,这些值包含我们指定为 in_key(和 out_key)的键以及循环键名。out_keys 前面带有 “next” 前缀,表明它们需要写入“下一个” TensorDict 中。我们使用此约定(可以通过传递 in_keys/out_keys 参数来覆盖)来确保调用 step_mdp() 会将循环状态移动到根 TensorDict,使其在下一次调用时可用于 RNN(参见简介中的图)。

如前所述,我们还有一个可选的 transform 要添加到我们的环境中,以确保循环状态传递到缓冲区。make_tensordict_primer() 方法正是做了这件事

env.append_transform(lstm.make_tensordict_primer())

就这样!我们现在可以打印环境,检查添加 primer 后一切是否正常

print(env)

MLP

我们使用一个单层 MLP 来表示我们将用于策略的动作值。

mlp = MLP(
    out_features=2,
    num_cells=[
        64,
    ],
    device=device,
)

并将偏差填充为零

mlp[-1].bias.data.fill_(0.0)
mlp = Mod(mlp, in_keys=["embed"], out_keys=["action_value"])

使用 Q 值选择动作

策略的最后一部分是 Q 值模块。Q 值模块 QValueModule 将读取由我们的 MLP 产生的 "action_values" 键,并从中获取具有最大值的动作。我们唯一需要做的是指定动作空间,可以通过传递字符串或 action-spec 来完成。这使我们可以使用分类(有时称为“稀疏”)编码或其独热版本。

qval = QValueModule(spec=env.action_spec)

注意

TorchRL 还提供了一个包装类 torchrl.modules.QValueActor,它将模块与 QValueModule 一起包装到 Sequential 中,就像我们在这里明确做的那样。这样做的好处不大,而且过程不太透明,但最终结果与我们在此处所做类似。

我们现在可以将这些组件组合到一个 TensorDictSequential

stoch_policy = Seq(feature, lstm, mlp, qval)

DQN 是一种确定性算法,因此探索是其关键部分。我们将使用一个 \(\epsilon\)-greedy 策略,其中 epsilon 最初为 0.2,并逐步衰减到 0。这种衰减是通过调用 step() 实现的(见下面的训练循环)。

exploration_module = EGreedyModule(
    annealing_num_steps=1_000_000, spec=env.action_spec, eps_init=0.2
)
stoch_policy = Seq(
    stoch_policy,
    exploration_module,
)

使用模型计算损失

我们构建的模型非常适合在顺序设置中使用。但是,类 torch.nn.LSTM 可以使用 cuDNN 优化的后端在 GPU 设备上更快地运行 RNN 序列。我们当然不想错过加速训练循环的机会!要使用它,我们只需告诉 LSTM 模块在损失计算时以“循环模式”运行。由于我们通常希望有两个 LSTM 模块的副本,我们通过调用 set_recurrent_mode() 方法来做到这一点,该方法将返回一个新的 LSTM 实例(共享权重),该实例将假定输入数据具有顺序性。

policy = Seq(feature, lstm.set_recurrent_mode(True), mlp, qval)

由于我们还有一些未初始化的参数,在创建优化器等之前,我们应该先初始化它们。

policy(env.reset())

DQN 损失

我们的 DQN 损失需要我们传入策略,并且再次需要动作空间。虽然这看起来是重复的,但很重要,因为我们要确保 DQNLossQValueModule 类是兼容的,但彼此之间没有强依赖关系。

为了使用 Double-DQN,我们需要一个 delay_value 参数,它将创建一个网络参数的非可微副本作为目标网络。

loss_fn = DQNLoss(policy, action_space=env.action_spec, delay_value=True)

由于我们使用的是双 DQN,我们需要更新目标参数。我们将使用 SoftUpdate 实例来执行这项工作。

updater = SoftUpdate(loss_fn, eps=0.95)

optim = torch.optim.Adam(policy.parameters(), lr=3e-4)

Collector 和回放缓冲区

我们构建了最简单的数据 collector。我们将尝试用一百万帧训练算法,每次扩展缓冲区 50 帧。缓冲区将设计用于存储 2 万个轨迹,每个轨迹 50 步。在每个优化步骤(每次数据收集 16 次)中,我们将从缓冲区中收集 4 个条目,总共 200 个转换。我们将使用 LazyMemmapStorage 存储将数据保存在磁盘上。

注意

为了提高效率,我们在这里只运行几千次迭代。在实际设置中,总帧数应设置为 100 万。

collector = SyncDataCollector(env, stoch_policy, frames_per_batch=50, total_frames=200, device=device)
rb = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(20_000), batch_size=4, prefetch=10
)

训练循环

为了跟踪进度,我们每进行 50 次数据收集就在环境中运行一次策略,并在训练后绘制结果。

utd = 16
pbar = tqdm.tqdm(total=1_000_000)
longest = 0

traj_lens = []
for i, data in enumerate(collector):
    if i == 0:
        print(
            "Let us print the first batch of data.\nPay attention to the key names "
            "which will reflect what can be found in this data structure, in particular: "
            "the output of the QValueModule (action_values, action and chosen_action_value),"
            "the 'is_init' key that will tell us if a step is initial or not, and the "
            "recurrent_state keys.\n",
            data,
        )
    pbar.update(data.numel())
    # it is important to pass data that is not flattened
    rb.extend(data.unsqueeze(0).to_tensordict().cpu())
    for _ in range(utd):
        s = rb.sample().to(device, non_blocking=True)
        loss_vals = loss_fn(s)
        loss_vals["loss"].backward()
        optim.step()
        optim.zero_grad()
    longest = max(longest, data["step_count"].max().item())
    pbar.set_description(
        f"steps: {longest}, loss_val: {loss_vals['loss'].item(): 4.4f}, action_spread: {data['action'].sum(0)}"
    )
    exploration_module.step(data.numel())
    updater.step()

    with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
        rollout = env.rollout(10000, stoch_policy)
        traj_lens.append(rollout.get(("next", "step_count")).max().item())

让我们绘制结果

if traj_lens:
    from matplotlib import pyplot as plt

    plt.plot(traj_lens)
    plt.xlabel("Test collection")
    plt.title("Test trajectory lengths")

结论

我们已经看到了如何在 TorchRL 中将 RNN 集成到策略中。你现在应该能够

  • 创建一个作为 TensorDictModule 的 LSTM 模块

  • 通过 InitTracker transform 指示 LSTM 模块需要重置

  • 将此模块集成到策略和损失模块中

  • 确保 collector 了解循环状态条目,以便它们可以与其余数据一起存储在回放缓冲区中


评价本教程

© 版权所有 2024, PyTorch。

使用Sphinx 构建,主题由Read the Docs 提供。

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源