快捷方式

tensordict.nn.distributions.CompositeDistribution

class tensordict.nn.distributions.CompositeDistribution(params: TensorDictBase, distribution_map: dict, *, name_map: dict | None = None, extra_kwargs=None)

分布的组合。

使用 TensorDict 接口将分布组合在一起。所有方法(log_probcdficdfrsamplesample 等)将返回一个 tensordict,如果输入是 tensordict,则可能在内存中进行修改。

参数:
  • params (TensorDictBase) – 一个嵌套的键-张量映射,其中根条目指向样本名称,叶节点是分布参数。条目名称必须与 distribution_map 中的条目名称匹配。

  • distribution_map (Dict[NestedKey, Type[torch.distribution.Distribution]]) – 指示要使用的分布类型。分布的名称将与 tensordict 中样本的名称匹配。

关键字参数:
  • name_map (Dict[NestedKey, NestedKey]]) – 表示每个样本应该写入位置的字典。如果未提供,将使用 distribution_map 中的键名称。

  • extra_kwargs (Dict[NestedKey, Dict]) – 用于构建分布的可能不完整的额外关键字参数字典。

注意

在这个分布类中,包含 params(params)的输入 tensordict 的批次大小指示分布的批次形状。例如,从调用 log_prob 生成的 "sample_log_prob" 条目将具有 params 的形状(+ 任何补充批次维度)。

示例

>>> params = TensorDict({
...     "cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)},
...     ("nested", "disc"): {"logits": torch.randn(3, 10)}
... }, [3])
>>> dist = CompositeDistribution(params,
...     distribution_map={"cont": d.Normal, ("nested", "disc"): d.Categorical})
>>> sample = dist.sample((4,))
>>> sample = dist.log_prob(sample)
>>> print(sample)
TensorDict(
    fields={
        cont: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        cont_log_prob: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                disc: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.int64, is_shared=False),
                disc_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([4]),
    device=None,
    is_shared=False)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源