SafeProbabilisticTensorDictSequential¶
- 类 torchrl.modules.tensordict_module.SafeProbabilisticTensorDictSequential(*args, **kwargs)[source]¶
tensordict.nn.ProbabilisticTensorDictSequential
子类,它接受TensorSpec
作为参数来控制输出域。类似于
TensorDictSequential
,但强制序列中的最后一个模块是ProbabilisticTensorDictModule
,并且还公开了get_dist
方法,以从ProbabilisticTensorDictModule
中恢复分布对象- 参数:
modules (TensorDictModules 的可迭代对象) – TensorDictModule 实例的有序序列,以 ProbabilisticTensorDictModule 结尾,按顺序运行。
partial_tolerant (bool, 可选) – 如果为
True
,则输入 tensordict 可以缺少一些输入键。如果是这样,则只会执行那些在给定现有键的情况下可以执行的模块。此外,如果输入 tensordict 是 tensordict 的惰性堆栈,并且 partial_tolerant 为True
,并且堆栈没有所需的键,则 TensorDictSequential 将扫描子 tensordict,查找具有所需键的子 tensordict(如果有)。