DreamerModelLoss¶
- class torchrl.objectives.DreamerModelLoss(*args, **kwargs)[source]¶
Dreamer 模型损失。
计算 dreamer 世界模型的损失。损失由 RSSM 的先验和后验之间的 KL 散度、重构观察结果的重构损失和预测奖励的奖励损失组成。
参考资料:https://arxiv.org/abs/1912.01603.
- 参数:
world_model (TensorDictModule) – 世界模型。
lambda_kl (float, 可选) – KL 散度损失的权重。默认值:1.0。
lambda_reco (float, 可选) – 重构损失的权重。默认值:1.0。
lambda_reward (float, 可选) – 奖励损失的权重。默认值:1.0。
reco_loss (str, 可选) – 重构损失。默认值:“l2”。
reward_loss (str, 可选) – 奖励损失。默认值:“l2”。
free_nats (int, 可选) – 免费纳特。默认值:3。
delayed_clamp (bool, 可选) – 如果
True
,则 KL 钳位在平均化后发生。如果为 False(默认),则 KL 散度首先被钳位到免费纳特值,然后进行平均。global_average (bool, 可选) – 如果
True
,则损失将在所有维度上取平均值。否则,将对所有非批次/时间维度进行求和,并在批次和时间上取平均值。默认值:False。
- forward(tensordict: TensorDict) Tensor [source]¶
它被设计为读取输入 TensorDict 并返回另一个包含名为“loss*”的损失键的 tensordict。
然后,将损失拆分为其组件可由训练器在整个训练过程中使用,以记录各种损失值。输出 tensordict 中存在的其他标量也将被记录。
- 参数:
tensordict – 一个包含计算损失所需值的输入 tensordict。
- 返回:
一个新的 tensordict,没有批次维度,包含各种损失标量,这些标量将被命名为“loss*”。这些损失必须以这种名称返回,因为它们将在反向传播之前由训练器读取。