快捷方式

GAILLoss

torchrl.objectives.GAILLoss(*args, **kwargs)[源]

TorchRL 实现的生成对抗模仿学习 (GAIL) 损失函数。

发表于 “Generative Adversarial Imitation Learning” <https://arxiv.org/pdf/1606.03476>

参数

discriminator_network (TensorDictModule) – 随机 actor

关键字参数
  • use_grad_penalty (bool, 可选) – 是否使用梯度惩罚。默认值:False

  • gp_lambda (float, 可选) – 梯度惩罚 lambda。默认值:10

  • reduction (str, 可选) – 指定要应用于输出的归约方式:"none" | "mean" | "sum""none": 不应用归约,"mean": 输出的总和将除以输出中的元素数量,"sum": 对输出求和。默认值:"mean"

default_keys

_AcceptedKeys 的别名

forward(tensordict: TensorDictBase = None) TensorDictBase[源]

forward 方法。

计算判别器损失,如果 use_grad_penalty 设置为 True,则计算梯度惩罚。如果 use_grad_penalty 设置为 True,还会返回分离的梯度惩罚损失,用于日志记录。要查看输入 tensordict 中期望的键和输出中期望的键,请检查类的 “in_keys”“out_keys” 属性。

文档

访问 PyTorch 的完整开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源