set_composite_lp_aggregate
- class tensordict.nn.set_composite_lp_aggregate(mode: bool = True)
Controls whether the log-probabilities and entropies of a CompositeDistribution should be aggregated into a single tensor. When composite_lp_aggregate() returns True, the log-probabilities/entropies of a CompositeDistribution are summed into a single tensor with the shape of the root tensordict. This behavior is being deprecated in favor of non-aggregated log-probs, which offer greater flexibility and a somewhat more natural API (tensordict samples, tensordict log-probs, tensordict entropies). The value of composite_lp_aggregate can also be controlled through the COMPOSITE_LP_AGGREGATE environment variable.
Examples
>>> import torch
>>> from torch import distributions as d
>>> from tensordict import TensorDict
>>> from tensordict.nn import CompositeDistribution, set_composite_lp_aggregate
>>> _ = torch.manual_seed(0)
>>> params = TensorDict({
...     "cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)},
...     ("nested", "disc"): {"logits": torch.randn(3, 10)}
... }, [3])
>>> dist = CompositeDistribution(params,
...     distribution_map={"cont": d.Normal, ("nested", "disc"): d.Categorical})
>>> sample = dist.sample((4,))
>>> with set_composite_lp_aggregate(False):
...     lp = dist.log_prob(sample)
...     print(lp)
TensorDict(
    fields={
        cont_log_prob: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                disc_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([4, 3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([4, 3]),
    device=None,
    is_shared=False)
>>> with set_composite_lp_aggregate(True):
...     lp = dist.log_prob(sample)
...     print(lp)
tensor([[-2.0886, -1.2155, -0.0414],
        [-2.8973, -5.5165,  2.4402],
        [-0.2806, -1.2799,  3.1733],
        [-3.0407, -4.3593,  0.5763]])
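Note how, with aggregation disabled, the result keeps one log-prob entry per component distribution ("cont_log_prob" of shape [4, 3, 4] and ("nested", "disc_log_prob") of shape [4, 3]), whereas with aggregation enabled they are summed into a single [4, 3] tensor matching the sample's batch shape.

Beyond the context-manager form shown above, the mode can also be fixed for the rest of the process. A minimal sketch, assuming set_composite_lp_aggregate exposes the same .set() method as tensordict's other context-manager switches; setting COMPOSITE_LP_AGGREGATE=0 in the environment before the process starts should have the same effect:

>>> from tensordict.nn import composite_lp_aggregate, set_composite_lp_aggregate
>>> set_composite_lp_aggregate(False).set()  # assumption: .set() applies the mode process-wide
>>> composite_lp_aggregate()
False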