DreamerModelLoss¶
- class torchrl.objectives.DreamerModelLoss(*args, **kwargs)[source]¶
Dreamer 模型损失。
计算 dreamer 世界模型的损失。损失由 RSSM 先验和后验之间的 KL 散度、重建观测的重建损失以及预测奖励的奖励损失组成。
参考: https://arxiv.org/abs/1912.01603。
- 参数:
world_model (TensorDictModule) – 世界模型。
lambda_kl (float, optional) – KL 散度损失的权重。默认值:1.0。
lambda_reco (float, optional) – 重建损失的权重。默认值:1.0。
lambda_reward (float, optional) – 奖励损失的权重。默认值:1.0。
reco_loss (str, optional) – 重建损失。默认值:“l2”。
reward_loss (str, optional) – 奖励损失。默认值:“l2”。
free_nats (int, optional) – 自由纳特。默认值:3。
delayed_clamp (bool, optional) – 如果
True
,则在平均后进行 KL 钳制。如果为 False(默认),则先将 kl 散度钳制为 free nats 值,然后再求平均值。global_average (bool, optional) – 如果
True
,则损失将在所有维度上取平均值。否则,将对所有非批次/时间维度执行求和,并对批次和时间取平均值。默认值:False。