快捷方式

GAILLoss

class torchrl.objectives.GAILLoss(*args, **kwargs)[source]

TorchRL 实现的生成对抗模仿学习 (GAIL) 损失函数。

“Generative Adversarial Imitation Learning” <https://arxiv.org/pdf/1606.03476> 中提出

参数:

discriminator_network (TensorDictModule) – 随机策略网络

关键字参数:
  • use_grad_penalty (bool, optional) – 是否使用梯度惩罚。默认值: False

  • gp_lambda (float, optional) – 梯度惩罚 lambda 值。默认值: 10

  • reduction (str, optional) – 指定应用于输出的 reduction 类型: "none" | "mean" | "sum". "none": 不应用 reduction, "mean": 输出的总和将除以输出元素的数量, "sum": 输出将被求和。默认值: "mean"

forward(tensordict: TensorDictBase = None) TensorDictBase[source]

forward 方法。

如果 use_grad_penalty 设置为 True,则计算判别器损失和梯度惩罚。如果 use_grad_penalty 设置为 True,则还会返回分离的梯度惩罚损失以用于日志记录。要查看输入 tensordict 中期望的键以及作为输出期望的键,请查看该类的 “in_keys”“out_keys” 属性。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源