快捷方式

distance_loss

torchrl.objectives.distance_loss(v1: Tensor, v2: Tensor, loss_function: str, strict_shape: bool = True)[源]

计算两个张量之间的距离损失。

参数:
  • v1 (张量) – 与 v2 形状兼容的张量

  • v2 (张量) – 与 v1 形状兼容的张量

  • loss_function (str) – “l2”、“l1”或“smooth_l1”之一,表示要使用的损失函数。

  • strict_shape (bool) – 如果为 False,则允许 v1 和 v2 具有不同的形状。默认值为 True

返回:

一个形状为 v1.view_as(v2) 或 v2.view_as(v1) 的张量,其值等于两个张量之间的距离损失。

文档

查阅 PyTorch 的完整开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源