OneHotCategorical¶
- 类 torchrl.modules.OneHotCategorical(logits: Optional[Tensor] = None, probs: Optional[Tensor] = None, grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough, **kwargs)[source]¶
独热分类分布。
此类行为与 torch.distributions.Categorical 完全一致,区别在于它读取并生成离散张量的独热编码。
- 参数:
logits (torch.Tensor) – 事件的对数概率(未归一化)
probs (torch.Tensor) – 事件概率
grad_method (ReparamGradientStrategy, 可选) –
收集重参数化样本的策略。
ReparamGradientStrategy.PassThrough
将通过使用 softmax 值的对数概率作为样本梯度的代理来计算样本梯度。ReparamGradientStrategy.RelaxedOneHot
将使用torch.distributions.RelaxedOneHot
从分布中进行采样。
示例
>>> torch.manual_seed(0) >>> logits = torch.randn(4) >>> dist = OneHotCategorical(logits=logits) >>> print(dist.rsample((3,))) tensor([[1., 0., 0., 0.], [0., 0., 0., 1.], [1., 0., 0., 0.]])