快捷方式

SafeProbabilisticTensorDictSequential

class torchrl.modules.tensordict_module.SafeProbabilisticTensorDictSequential(*args, **kwargs)[源代码]

tensordict.nn.ProbabilisticTensorDictSequential 的子类,接受 `TensorSpec` 参数来控制输出域。

TensorDictSequential 类似,但强制要求序列中的最后一个模块是 ProbabilisticTensorDictModule,并暴露 `get_dist` 方法以便从 ProbabilisticTensorDictModule 中恢复分布对象。

参数::
  • modules (TensorDictModules 可迭代对象) – TensorDictModule 实例的有序序列,以 ProbabilisticTensorDictModule 结尾,按顺序运行。

  • partial_tolerant (bool, 可选) – 如果为 `True`,输入 tensordict 可以缺少一些输入键。在这种情况下,只有那些在现有键下可以执行的模块会被执行。此外,如果输入 tensordict 是 tensordict 的惰性堆栈 *并且* partial_tolerant 为 `True` *并且* 堆栈没有所需的键,那么 TensorDictSequential 将扫描子 tensordict,查找包含所需键的那些(如果存在)。

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源