快捷方式

tensordict.nn.distributions.CompositeDistribution

class tensordict.nn.distributions.CompositeDistribution(params: TensorDictBase, distribution_map: dict, *, name_map: Optional[dict] = None, extra_kwargs=None, aggregate_probabilities: Optional[bool] = None, log_prob_key: Optional[NestedKey] = None, entropy_key: Optional[NestedKey] = None)

一个使用 TensorDict 接口将多个分布组合在一起的复合分布。

此类允许对分布集合执行诸如 log_prob_compositeentropy_compositecdficdfrsamplesample 等操作,并返回一个 TensorDict。输入的 TensorDict 可能会被就地修改。

参数
  • params (TensorDictBase) – 一个嵌套的键-张量映射,其中根条目对应于样本名称,叶子条目是分布参数。条目名称必须与 distribution_map 中指定的名称匹配。

  • distribution_map (Dict[NestedKey, Type[torch.distribution.Distribution]]) – 指定要使用的分布类型。分布的名称应与 TensorDict 中的样本名称匹配。

关键字参数
  • name_map (Dict[NestedKey, NestedKey], optional) – 一个映射,指定每个样本应写入的位置。如果未提供,将使用 distribution_map 中的键名称。

  • extra_kwargs (Dict[NestedKey, Dict], optional) – 用于构造分布的额外关键字参数字典。

  • aggregate_probabilities (bool, optional) –

    如果为 Truelog_probentropy 方法将对各个分布的概率和熵求和并返回单个张量。如果为 False,单个对数概率将存储在输入的 TensorDict 中(对于 log_prob),或作为输出 TensorDict 的叶子返回(对于 entropy)。这可以在运行时通过将 aggregate_probabilities 参数传递给 log_probentropy 来覆盖。默认为 False

    警告

    此参数将在 v0.9 中弃用,届时 tensordict.nn.probabilistic.composite_lp_aggregate() 将默认为 False

  • log_prob_key (NestedKey, optional) –

    存储聚合对数概率的键。默认为 ‘sample_log_prob’

    注意

    如果 tensordict.nn.probabilistic.composite_lp_aggregate() 返回 False,则对数概率将写入 (“path”, “to”, “leaf”, “<sample_name>_log_prob”) 下,其中 (“path”, “to”, “leaf”, “<sample_name>”) 是对应于正在采样的叶子张量的 NestedKey。在这种情况下,log_prob_key 参数将被忽略。

  • entropy_key (NestedKey, optional) –

    存储熵的键。默认为 ‘entropy’

    注意

    如果 tensordict.nn.probabilistic.composite_lp_aggregate() 返回 False,则熵将写入 (“path”, “to”, “leaf”, “<sample_name>_entropy”) 下,其中 (“path”, “to”, “leaf”, “<sample_name>”) 是对应于正在采样的叶子张量的 NestedKey。在这种情况下,entropy_key 参数将被忽略。

注意

包含参数(params)的输入 TensorDict 的批次大小决定了分布的批次形状。例如,调用 log_prob 产生的 “sample_log_prob” 条目将具有参数的形状加上任何额外的批次维度。

另请参阅

ProbabilisticTensorDictModuleProbabilisticTensorDictSequential,了解如何将此类用作模型的一部分。

另请参阅

set_composite_lp_aggregate,控制对数概率的聚合。

示例

>>> params = TensorDict({
...     "cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)},
...     ("nested", "disc"): {"logits": torch.randn(3, 10)}
... }, [3])
>>> dist = CompositeDistribution(params,
...     distribution_map={"cont": d.Normal, ("nested", "disc"): d.Categorical})
>>> sample = dist.sample((4,))
>>> with set_composite_lp_aggregate(False):
...     sample = dist.log_prob(sample)
...     print(sample)
TensorDict(
    fields={
        cont: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        cont_log_prob: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                disc: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.int64, is_shared=False),
                disc_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([4]),
    device=None,
    is_shared=False)

文档

访问 PyTorch 的综合开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源