torch.nn.functional.gumbel_softmax¶
- torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1)[source][source]¶
从 Gumbel-Softmax 分布(链接 1 链接 2)中采样,并可选择进行离散化。
- 参数
- 返回
从 Gumbel-Softmax 分布中采样的张量,形状与 logits 相同。如果
hard=True
,返回的样本将是 one-hot 向量;否则,它们将是在 dim 维度上求和为 1 的概率分布。- 返回类型
注意
此函数出于历史原因保留,未来可能从 nn.Functional 中移除。
注意
hard
的主要技巧是执行y_hard - y_soft.detach() + y_soft
这实现了两件事: - 使输出值精确为 one-hot(因为我们先加后减了 y_soft 的值) - 使梯度等于 y_soft 的梯度(因为我们去除了所有其他梯度)
- 示例:
>>> logits = torch.randn(20, 32) >>> # Sample soft categorical using reparametrization trick: >>> F.gumbel_softmax(logits, tau=1, hard=False) >>> # Sample hard categorical using "Straight-through" trick: >>> F.gumbel_softmax(logits, tau=1, hard=True)