快捷方式

MaskedOneHotCategorical

class torchrl.modules.MaskedOneHotCategorical(logits: Optional[Tensor] = None, probs: Optional[Tensor] = None, mask: Optional[Tensor] = None, indices: Optional[Tensor] = None, neg_inf: float = - inf, padding_value: Optional[int] = None, grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough)[source]

MaskedCategorical 分布。

参考: https://tensorflowcn.cn/agents/api_docs/python/tf_agents/distributions/masked/MaskedCategorical

参数:
  • logits (torch.Tensor) – 事件对数概率(未归一化)

  • probs (torch.Tensor) – 事件概率。如果提供,则对应于掩码项的概率将归零,并且概率沿其最后一个维度重新归一化。

关键字参数:
  • mask (torch.Tensor) – 与 logits/probs 形状相同的布尔掩码,其中 False 条目是要掩码的条目。或者,如果 sparse_mask 为 True,则它表示分布中的有效索引列表。与 indices 互斥。

  • indices (torch.Tensor) – 表示必须考虑哪些动作的密集索引张量。与 mask 互斥。

  • neg_inf (float, optional) – 分配给无效(超出掩码)索引的对数概率值。默认为 -inf。

  • padding_value – 当 sparse_mask == True 时,掩码张量中的填充值,padding_value 将被忽略。

  • grad_method (ReparamGradientStrategy, optional) –

    收集重参数化样本的策略。 ReparamGradientStrategy.PassThrough 将使用 softmax 值对数概率作为样本梯度的代理来计算样本梯度。

    通过使用 softmax 评估的对数概率作为样本梯度的代理来计算样本梯度。

    ReparamGradientStrategy.RelaxedOneHot 将使用 torch.distributions.RelaxedOneHot 从分布中采样。

  • torch.manual_seed (>>>) –

  • torch.randn (>>> logits =) –

  • torch.tensor (>>> mask =) –

  • MaskedOneHotCategorical (>>> dist =) –

  • dist.sample (>>> sample =) –

  • print (>>>) –

  • 0], (tensor([[0, 0, 1,) – [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0]])

  • print

  • -1.0831, (tensor([-1.1203, -1.0928, -1.0831, -1.1203, -1.1203, -1.0831, -1.1203,) – -1.1203, -1.1203])

  • torch.zeros_like (>>> sample_non_valid =) –

  • 1 (>>> sample_non_valid[..., 1] =) –

  • print

  • tensor ([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]) –

  • probabilities (>>> # with) –

  • torch.ones (>>> prob =) –

  • prob.sum() (>>> prob = prob /) –

  • torch.tensor

  • MaskedOneHotCategorical

  • torch.arange (>>> s =) –

  • torch.nn.functional.one_hot (>>> s =) –

  • print

  • -2.1972, (tensor([ -inf, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972,) – -2.1972, -2.1972])

log_prob(value: Tensor) Tensor[source]

返回在 value 处评估的概率密度/质量函数的对数。

参数:

value (Tensor) –

property mode: Tensor

返回分布的众数。

rsample(sample_shape: Optional[Union[Size, Sequence]] = None) Tensor[source]

生成 sample_shape 形状的重参数化样本,或者如果分布参数是批量的,则生成 sample_shape 形状的重参数化样本批次。

sample(sample_shape: Optional[Union[Size, Sequence[int]]] = None) Tensor[source]

生成 sample_shape 形状的样本,或者如果分布参数是批量的,则生成 sample_shape 形状的样本批次。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源