注意
转到末尾 下载完整的示例代码。
模型优化入门¶
作者: Vincent Moens
注意
要在 Notebook 中运行此教程,请在开头添加一个安装单元,其中包含
!pip install tensordict !pip install torchrl
在 TorchRL 中,我们尝试以 PyTorch 惯用的方式处理优化,使用专门的损失模块,这些模块的唯一目的就是优化模型。这种方法有效地将策略的执行与其训练解耦开来,并允许我们设计与传统监督学习示例中相似的训练循环。
因此,典型的训练循环如下所示
..code - block::Python
>>> for i in range(n_collections): ... data = get_next_batch(env, policy) ... for j in range(n_optim): ... loss = loss_fn(data) ... loss.backward() ... optim.step()
在这个简洁的教程中,您将简要了解损失模块。由于 API 在基本用法上通常很直接,本教程将保持简短。
RL 目标函数¶
在强化学习 (RL) 中,创新通常涉及探索优化策略的新方法(即新算法),而不是像在其他领域那样专注于新架构。在 TorchRL 中,这些算法被封装在损失模块中。一个损失模块协调您算法的各个组成部分,并产生一组损失值,这些值可以通过反向传播来训练相应的组成部分。
在本教程中,我们将以一种流行的离策略算法 DDPG 作为示例。
要构建损失模块,唯一需要的是一组定义为 :class:`~tensordict.nn.TensorDictModule` 的网络。大多数时候,其中一个模块将是策略。可能还需要其他辅助网络,例如 Q 值网络或某种评论家网络。让我们看看这在实践中是什么样子的:DDPG 需要一个从观测空间到动作空间的确定性映射,以及一个预测状态-动作对值的价值网络。DDPG 损失函数将尝试找到能够输出在给定状态下最大化值的动作的策略参数。
要构建损失函数,我们需要 Actor 网络和价值网络。如果它们是根据 DDPG 的期望构建的,那么这就是我们获得可训练损失模块所需的全部内容。
from torchrl.envs import GymEnv
env = GymEnv("Pendulum-v1")
from torchrl.modules import Actor, MLP, ValueOperator
from torchrl.objectives import DDPGLoss
n_obs = env.observation_spec["observation"].shape[-1]
n_act = env.action_spec.shape[-1]
actor = Actor(MLP(in_features=n_obs, out_features=n_act, num_cells=[32, 32]))
value_net = ValueOperator(
MLP(in_features=n_obs + n_act, out_features=1, num_cells=[32, 32]),
in_keys=["observation", "action"],
)
ddpg_loss = DDPGLoss(actor_network=actor, value_network=value_net)
就是这样!我们的损失模块现在可以使用来自环境的数据运行了(我们省略了探索、存储和其他功能,以专注于损失函数的功能)
rollout = env.rollout(max_steps=100, policy=actor)
loss_vals = ddpg_loss(rollout)
print(loss_vals)
LossModule 的输出¶
如您所见,我们从损失模块获得的值不是一个单一的标量,而是一个包含多个损失的字典。
原因很简单:因为可能同时训练多个网络,并且由于一些用户可能希望在不同步骤中分开优化每个模块,TorchRL 的目标函数将返回包含各种损失组成部分的字典。
这种格式还允许我们 همراه 损失值传递元数据。一般来说,我们确保只有损失值是可微分的,这样您就可以简单地对字典中的值求和以获得总损失。如果您想确保完全控制正在发生的事情,您可以仅对键以 "loss_"
前缀开头的条目求和。
total_loss = 0
for key, val in loss_vals.items():
if key.startswith("loss_"):
total_loss += val
训练 LossModule¶
鉴于这一切,训练模块与在任何其他训练循环中所做的没有太大区别。因为它封装了模块,获取可训练参数列表的最简单方法是调用 parameters()
方法。
我们将需要一个优化器(如果您的选择是每个模块一个优化器)。
以下项目通常会在您的训练循环中找到
optim.step()
optim.zero_grad()
进一步考虑:目标参数¶
另一个重要的方面需要考虑的是离策略算法(如 DDPG)中目标参数的存在。目标参数通常代表参数随时间的延迟或平滑版本,它们在策略训练期间的价值估计中起着至关重要的作用。与使用价值网络参数的当前配置相比,利用目标参数进行策略训练通常会显著提高效率。一般来说,目标参数的管理由损失模块处理,减轻了用户的直接顾虑。但是,根据具体要求更新这些值仍然是用户的责任。TorchRL 提供了一些更新器,即 HardUpdate
和 SoftUpdate
,它们可以轻松实例化,无需深入了解损失模块的底层机制。
from torchrl.objectives import SoftUpdate
updater = SoftUpdate(ddpg_loss, eps=0.99)
在您的训练循环中,您需要在每个优化步骤或每个收集步骤中更新目标参数
updater.step()
这就是关于损失模块您入门所需了解的全部内容!
要进一步探索该主题,请查看