快捷方式

tensordict.nn.distributions.CompositeDistribution

class tensordict.nn.distributions.CompositeDistribution(params: TensorDictBase, distribution_map: dict, *, name_map: Optional[dict] = None, extra_kwargs=None, aggregate_probabilities: Optional[bool] = None, log_prob_key: NestedKey = 'sample_log_prob', entropy_key: NestedKey = 'entropy')

分布的组合。

将分布与 TensorDict 接口组合在一起。方法(log_prob_compositeentropy_compositecdficdfrsamplesample 等)将返回一个 tensordict,如果输入是 tensordict,则可能会就地修改。

参数:
  • params (TensorDictBase) – 一个嵌套的键-张量映射,其中根条目指向样本名称,叶子是分布参数。条目名称必须与 distribution_map 的名称匹配。

  • distribution_map (Dict[NestedKey, Type[torch.distribution.Distribution]]) – 指示要使用的分布类型。分布的名称将与 tensordict 中样本的名称匹配。

关键词参数:
  • name_map (Dict[NestedKey, NestedKey]]) – 一个字典,表示每个样本应写入的位置。如果未提供,则将使用 distribution_map 中的键名称。

  • extra_kwargs (Dict[NestedKey, Dict]) – 一个可能不完整的字典,其中包含要构建的分布的额外关键词参数。

  • aggregate_probabilities (bool) – 如果为 True,则 log_prob()entropy() 方法将对各个分布的概率和熵求和,并返回单个张量。如果为 False,则单个对数概率将注册在输入 tensordict 中(对于 log_prob()),或者作为输出 tensordict 的叶子返回(对于 entropy())。可以通过将 aggregate_probabilities 参数传递给 log_probentropy 来在运行时覆盖此参数。默认为 False

  • log_prob_key (NestedKey, optional) – 写入 log_prob 的键。默认为 ‘sample_log_prob’

  • entropy_key (NestedKey, optional) – 写入熵的键。默认为 ‘entropy’

注意

在此分布类中,包含参数 (params) 的输入 tensordict 的批大小指示分布的批形状。例如,调用 log_prob 产生的 "sample_log_prob" 条目的形状将与参数的形状相同(+ 任何补充批维度)。

示例

>>> params = TensorDict({
...     "cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)},
...     ("nested", "disc"): {"logits": torch.randn(3, 10)}
... }, [3])
>>> dist = CompositeDistribution(params,
...     distribution_map={"cont": d.Normal, ("nested", "disc"): d.Categorical})
>>> sample = dist.sample((4,))
>>> sample = dist.log_prob(sample)
>>> print(sample)
TensorDict(
    fields={
        cont: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        cont_log_prob: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                disc: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.int64, is_shared=False),
                disc_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([4]),
    device=None,
    is_shared=False)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源