ProbabilisticTensorDictSequential¶
- class tensordict.nn.ProbabilisticTensorDictSequential(*args, **kwargs)¶
一个包含至少一个
ProbabilisticTensorDictModule
的TensorDictModules
序列。此类扩展了
TensorDictSequential
,通常配置为一个模块序列,其中最后一个模块是ProbabilisticTensorDictModule
的实例。然而,它也支持一个或多个中间模块是ProbabilisticTensorDictModule
的实例,而最后一个模块可能不是概率性的配置。在所有情况下,它都暴露了get_dist()
方法,以从序列中的ProbabilisticTensorDictModule
实例中恢复分布对象。多个概率性模块可以共存于一个
ProbabilisticTensorDictSequential
中。如果 return_composite 为False
(默认),则只有最后一个模块会产生分布,而其他模块将作为常规的TensorDictModule
实例执行。然而,如果一个 ProbabilisticTensorDictModule 不是序列中的最后一个模块,并且 return_composite=False,则在尝试查询该模块时将引发 ValueError。如果 return_composite=True,所有中间的 ProbabilisticTensorDictModule 实例将共同组成一个单独的CompositeDistribution
实例。如果样本相互依赖,则结果对数概率将是条件概率:当
\[Z = F(X, Y)\]则 Z 的对数概率将是
\[log(p(z | x, y))\]- 参数:
*modules (sequence or OrderedDict of TensorDictModuleBase or ProbabilisticTensorDictModule) – 一个有序的
TensorDictModule
实例序列,通常以ProbabilisticTensorDictModule
结束,用于顺序运行。模块可以是 TensorDictModuleBase 的实例,也可以是任何符合此签名的其他函数。请注意,如果使用了非 TensorDictModuleBase 的可调用对象,其输入和输出键将不会被跟踪,因此不会影响 TensorDictSequential 的 in_keys 和 out_keys 属性。- 关键字参数:
partial_tolerant (bool, optional) – 如果为
True
,输入 tensordict 可以缺少部分输入键。在这种情况下,只有那些根据现有键可以执行的模块会被执行。此外,如果输入的 tensordict 是 tensordicts 的惰性堆叠(lazy stack),并且 partial_tolerant 为True
,并且该堆叠不包含所需的键,则 TensorDictSequential 将扫描子 tensordicts,查找包含所需键(如果有的话)的 tensordicts。默认为False
。return_composite (bool, optional) –
如果为 True,并且找到了多个
ProbabilisticTensorDictModule
或ProbabilisticTensorDictSequential
实例,则将使用一个CompositeDistribution
实例。否则,将仅使用最后一个模块来构建分布。默认为False
。警告
`return_composite` 的行为将在 v0.9 中改变,并从那时起默认为 True。
- 抛出:
ValueError – 如果输入的模块序列为空。
TypeError – 如果最后一个模块不是
ProbabilisticTensorDictModule
或ProbabilisticTensorDictSequential
的实例。
示例
>>> from tensordict.nn import ProbabilisticTensorDictModule as Prob, ProbabilisticTensorDictSequential as Seq >>> import torch >>> # Typical usage: a single distribution is computed last in the sequence >>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn import ProbabilisticTensorDictModule as Prob, ProbabilisticTensorDictSequential as Seq, ... TensorDictModule as Mod >>> torch.manual_seed(0) >>> >>> module = Seq( ... Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]), ... Prob(in_keys=["loc"], out_keys=["sample"], distribution_class=torch.distributions.Normal, ... distribution_kwargs={"scale": 1}), ... ) >>> input = TensorDict(x=torch.ones(3)) >>> td = module(input.copy()) >>> print(td) TensorDict( fields={ loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), sample: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> print(module.get_dist(input)) Normal(loc: torch.Size([3]), scale: torch.Size([3])) >>> print(module.log_prob(td)) tensor([-0.9189, -0.9189, -0.9189]) >>> # Intermediate distributions are ignored when return_composite=False >>> module = Seq( ... Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]), ... Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal, ... distribution_kwargs={"scale": 1}), ... Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["loc2"]), ... Prob(in_keys={"loc": "loc2"}, out_keys=["sample1"], distribution_class=torch.distributions.Normal, ... distribution_kwargs={"scale": 1}), ... return_composite=False, ... ) >>> td = module(TensorDict(x=torch.ones(3))) >>> print(td) TensorDict( fields={ loc2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), sample1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> print(module.get_dist(input)) Normal(loc: torch.Size([3]), scale: torch.Size([3])) >>> print(module.log_prob(td)) tensor([-0.9189, -0.9189, -0.9189]) >>> # Intermediate distributions produce a CompositeDistribution when return_composite=True >>> module = Seq( ... Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]), ... Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal, ... distribution_kwargs={"scale": 1}), ... Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["loc2"]), ... Prob(in_keys={"loc": "loc2"}, out_keys=["sample1"], distribution_class=torch.distributions.Normal, ... distribution_kwargs={"scale": 1}), ... return_composite=True, ... ) >>> input = TensorDict(x=torch.ones(3)) >>> td = module(input.copy()) >>> print(td) TensorDict( fields={ loc2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), sample1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> print(module.get_dist(input)) CompositeDistribution({'sample0': Normal(loc: torch.Size([3]), scale: torch.Size([3])), 'sample1': Normal(loc: torch.Size([3]), scale: torch.Size([3]))}) >>> print(module.log_prob(td)) TensorDict( fields={ sample0_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), sample1_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> # Even a single intermediate distribution is wrapped in a CompositeDistribution when >>> # return_composite=True >>> module = Seq( ... Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]), ... Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal, ... distribution_kwargs={"scale": 1}), ... Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["y"]), ... return_composite=True, ... ) >>> td = module(TensorDict(x=torch.ones(3))) >>> print(td) TensorDict( fields={ loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), y: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> print(module.get_dist(input)) CompositeDistribution({'sample0': Normal(loc: torch.Size([3]), scale: torch.Size([3]))}) >>> print(module.log_prob(td)) TensorDict( fields={ sample0_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
- build_dist_from_params(tensordict: TensorDictBase) Distribution ¶
根据输入参数构建分布,而不评估序列中的其他模块。
此方法在序列中查找最后一个
ProbabilisticTensorDictModule
,并使用它来构建分布。- 参数:
tensordict (TensorDictBase) – 包含分布参数的输入 tensordict。
- 返回:
构建的分布对象。
- 返回类型:
D.Distribution
- 抛出:
RuntimeError – 如果在序列中未找到
ProbabilisticTensorDictModule
。
- property default_interaction_type¶
使用迭代启发式方法返回模块的 default_interaction_type。
此属性以反向顺序迭代所有模块,尝试从任何子模块中检索 default_interaction_type 属性。遇到的第一个非 None 值将被返回。如果未找到此类值,则返回默认的 interaction_type()。
- forward(tensordict: TensorDictBase = None, tensordict_out: tensordict.base.TensorDictBase | None = None, **kwargs) TensorDictBase ¶
当未设置 tensordict 参数时,使用 kwargs 创建 TensorDict 实例。
- get_dist(tensordict: TensorDictBase, tensordict_out: Optional[TensorDictBase] = None, **kwargs) Distribution ¶
返回将输入 tensordict 通过序列传递后得到的分布。
如果 return_composite 为
False
(默认),此方法将仅考虑序列中的最后一个概率性模块。否则,它将返回一个包含所有概率性模块分布的
CompositeDistribution
实例。- 参数:
tensordict (TensorDictBase) – 输入 tensordict。
tensordict_out (TensorDictBase, optional) – 输出 tensordict。如果为
None
,将创建一个新的 tensordict。默认为None
。
- 关键字参数:
**kwargs – 传递给底层模块的额外关键字参数。
- 返回:
结果分布对象。
- 返回类型:
D.Distribution
- 抛出:
RuntimeError – 如果在序列中未找到概率性模块。
注意
当 return_composite 为
True
时,分布是以前一个序列中的样本为条件的。这意味着如果一个模块依赖于前一个概率性模块的输出,其分布将是条件分布。
- get_dist_params(tensordict: TensorDictBase, tensordict_out: Optional[TensorDictBase] = None, **kwargs) tuple[torch.distributions.distribution.Distribution, tensordict.base.TensorDictBase] ¶
返回分布参数和输出 tensordict。
此方法运行 ProbabilisticTensorDictSequential 模块的确定性部分以获取分布参数。交互类型被设置为当前全局交互类型(如果可用),否则默认为最后一个模块的交互类型。
- 参数:
tensordict (TensorDictBase) – 输入 tensordict。
tensordict_out (TensorDictBase, optional) – 输出 tensordict。如果为
None
,将创建一个新的 tensordict。默认为None
。
- 关键字参数:
**kwargs – 传递给模块确定性部分的额外关键字参数。
- 返回:
一个包含分布对象和输出 tensordict 的元组。
- 返回类型:
tuple[D.Distribution, TensorDictBase]
注意
在此方法的执行期间,交互类型被临时设置为指定的值。
- log_prob(tensordict, tensordict_out: Optional[TensorDictBase] = None, *, dist: Optional[Distribution] = None, **kwargs)¶
返回输入 tensordict 的对数概率。
如果 self.return_composite 为
True
且分布是一个CompositeDistribution
,则此方法将返回整个复合分布的对数概率。否则,它将仅考虑序列中的最后一个概率性模块。
- 参数:
tensordict (TensorDictBase) – 输入 tensordict。
tensordict_out (TensorDictBase, optional) – 输出 tensordict。如果为
None
,将创建一个新的 tensordict。默认为None
。
- 关键字参数:
dist (torch.distributions.Distribution, optional) – 分布对象。如果为
None
,将使用 get_dist 计算。默认为None
。- 返回:
输入 tensordict 的对数概率。
- 返回类型:
警告
在未来的版本(v0.9)中,aggregate_probabilities、inplace 和 include_sum 的默认值将发生变化。为避免警告,建议显式地将这些参数传递给 log_prob 方法或在构造函数中设置它们。