快捷方式

DreamerModelLoss

class torchrl.objectives.DreamerModelLoss(*args, **kwargs)[源代码]

Dreamer 模型损失。

计算 dreamer 世界模型的损失。该损失由 RSSM 先验分布与后验分布之间的 KL 散度、重建观测的重建损失以及预测奖励的奖励损失组成。

参考文献:https://arxiv.org/abs/1912.01603

参数:
  • world_model (TensorDictModule) – 世界模型。

  • lambda_kl (float, optional) – KL 散度损失的权重。默认值:1.0。

  • lambda_reco (float, optional) – 重建损失的权重。默认值:1.0。

  • lambda_reward (float, optional) – 奖励损失的权重。默认值:1.0。

  • reco_loss (str, optional) – 重建损失类型。默认值:“l2”。

  • reward_loss (str, optional) – 奖励损失类型。默认值:“l2”。

  • free_nats (int, optional) – 自由纳特。默认值:3。

  • delayed_clamp (bool, optional) – 如果为 True,则 KL 截断在平均后进行。如果为 False(默认值),则 KL 散度首先截断到 free nats 值,然后进行平均。

  • global_average (bool, optional) – 如果为 True,则损失将在所有维度上进行平均。否则,将对所有非批次/时间维度进行求和,并在批次和时间维度上进行平均。默认值:False。

default_keys

_AcceptedKeys 的别名

forward(tensordict: TensorDict) Tensor[源代码]

它被设计用于读取输入的 TensorDict 并返回另一个包含以“loss*”命名的损失键的 tensordict。

将损失分解为各个组成部分后,训练器可以在训练过程中记录各种损失值。输出 tensordict 中存在的其他标量也将被记录。

参数:

tensordict – 包含计算损失所需值的输入 tensordict。

返回:

一个没有批次维度的新 tensordict,包含各种将被命名为“loss*”的损失标量。损失必须以此名称返回,因为它们将在反向传播前被训练器读取,这一点至关重要。

文档

访问 PyTorch 的完整开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源