快捷方式

SafeProbabilisticTensorDictSequential

class torchrl.modules.tensordict_module.SafeProbabilisticTensorDictSequential(*args, **kwargs)[source]

tensordict.nn.ProbabilisticTensorDictSequential 的子类,它接受 TensorSpec 作为参数来控制输出域。

类似于 TensorDictSequential,但强制要求序列中的最后一个模块为 ProbabilisticTensorDictModule,并且还公开了 get_dist 方法以从 ProbabilisticTensorDictModule 中恢复分布对象。

参数:
  • modules (TensorDictModules 的可迭代对象) – TensorDictModule 实例的有序序列,以 ProbabilisticTensorDictModule 结尾,将依次运行。

  • partial_tolerant (bool, 可选) – 如果为 True,则输入 tensordict 可能会缺少一些输入键。如果是这样,则只会执行给定存在的键可以执行的模块。此外,如果输入 tensordict 是 tensordict 的延迟堆栈,并且如果 partial_tolerant 为 True,并且如果堆栈没有所需的键,则 TensorDictSequential 将扫描子 tensordict 以查找具有所需键的 tensordict(如果有)。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源