tensordict.nn.distributions.CompositeDistribution¶
- class tensordict.nn.distributions.CompositeDistribution(params: TensorDictBase, distribution_map: dict, *, name_map: Optional[dict] = None, extra_kwargs=None, aggregate_probabilities: Optional[bool] = None, log_prob_key: NestedKey = 'sample_log_prob', entropy_key: NestedKey = 'entropy')¶
分布的组合。
将分布与 TensorDict 接口组合在一起。方法(
log_prob_composite
、entropy_composite
、cdf
、icdf
、rsample
、sample
等)将返回一个 tensordict,如果输入是 tensordict,则可能会就地修改。- 参数:
params (TensorDictBase) – 一个嵌套的键-张量映射,其中根条目指向样本名称,叶子是分布参数。条目名称必须与
distribution_map
的名称匹配。distribution_map (Dict[NestedKey, Type[torch.distribution.Distribution]]) – 指示要使用的分布类型。分布的名称将与 tensordict 中样本的名称匹配。
- 关键词参数:
name_map (Dict[NestedKey, NestedKey]]) – 一个字典,表示每个样本应写入的位置。如果未提供,则将使用
distribution_map
中的键名称。extra_kwargs (Dict[NestedKey, Dict]) – 一个可能不完整的字典,其中包含要构建的分布的额外关键词参数。
aggregate_probabilities (bool) – 如果为
True
,则log_prob()
和entropy()
方法将对各个分布的概率和熵求和,并返回单个张量。如果为False
,则单个对数概率将注册在输入 tensordict 中(对于log_prob()
),或者作为输出 tensordict 的叶子返回(对于entropy()
)。可以通过将aggregate_probabilities
参数传递给log_prob
和entropy
来在运行时覆盖此参数。默认为False
。log_prob_key (NestedKey, optional) – 写入 log_prob 的键。默认为 ‘sample_log_prob’。
entropy_key (NestedKey, optional) – 写入熵的键。默认为 ‘entropy’。
注意
在此分布类中,包含参数 (
params
) 的输入 tensordict 的批大小指示分布的批形状。例如,调用log_prob
产生的"sample_log_prob"
条目的形状将与参数的形状相同(+ 任何补充批维度)。示例
>>> params = TensorDict({ ... "cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)}, ... ("nested", "disc"): {"logits": torch.randn(3, 10)} ... }, [3]) >>> dist = CompositeDistribution(params, ... distribution_map={"cont": d.Normal, ("nested", "disc"): d.Categorical}) >>> sample = dist.sample((4,)) >>> sample = dist.log_prob(sample) >>> print(sample) TensorDict( fields={ cont: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), cont_log_prob: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), nested: TensorDict( fields={ disc: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.int64, is_shared=False), disc_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([4]), device=None, is_shared=False)}, batch_size=torch.Size([4]), device=None, is_shared=False)