MaskedOneHotCategorical¶
- class torchrl.modules.MaskedOneHotCategorical(logits: Optional[Tensor] = None, probs: Optional[Tensor] = None, mask: Optional[Tensor] = None, indices: Optional[Tensor] = None, neg_inf: float = - inf, padding_value: Optional[int] = None, grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough)[source]¶
带掩码的分类分布。
参考:https://tensorflowcn.cn/agents/api_docs/python/tf_agents/distributions/masked/MaskedCategorical
- 参数:
logits (torch.Tensor) – 事件对数概率(未归一化)
probs (torch.Tensor) – 事件概率。如果提供,则对应于被屏蔽项目的概率将被清零,并且概率将沿其最后一个维度重新归一化。
- 关键字参数:
mask (torch.Tensor) – 与
logits
/probs
形状相同的布尔掩码,其中False
条目指的是要被屏蔽的条目。或者,如果sparse_mask
为 True,则它表示分布中的有效索引列表。与indices
互斥。indices (torch.Tensor) – 表示必须考虑哪些动作的密集索引张量。与
mask
互斥。neg_inf (float, 可选) – 分配给无效(超出掩码)索引的对数概率值。默认为 -inf。
padding_value – 当 sparse_mask == True 时,掩码张量中的填充值,padding_value 将被忽略。
grad_method (ReparamGradientStrategy, 可选) –
收集重新参数化样本的策略。
ReparamGradientStrategy.PassThrough
将通过使用 softmax 值对数概率作为样本梯度的代理来计算样本梯度。ReparamGradientStrategy.RelaxedOneHot
将使用torch.distributions.RelaxedOneHot
从分布中采样。torch.manual_seed (>>>) –
torch.randn (>>> logits =) –
torch.tensor (>>> mask =) –
MaskedOneHotCategorical (>>> dist =) –
dist.sample (>>> sample =) –
print (>>>) –
0], (tensor([[0, 0, 1,) – [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0]])
print –
-1.0831, (tensor([-1.1203, -1.0928, -1.0831, -1.1203, -1.1203, -1.0831, -1.1203,) – -1.1203, -1.1203])
torch.zeros_like (>>> sample_non_valid =) –
1 (>>> sample_non_valid[..., 1] =) –
print –
tensor ([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]) –
probabilities (>>> # 使用) –
torch.ones (>>> prob =) –
prob.sum() (>>> prob = prob /) –
torch.tensor –
MaskedOneHotCategorical –
torch.arange (>>> s =) –
torch.nn.functional.one_hot (>>> s =) –
print –
-2.1972, (tensor([ -inf, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972,) – -2.1972, -2.1972])