torch.nn.functional.gumbel_softmax¶
- torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1)[源代码]¶
从 Gumbel-Softmax 分布中抽样(链接 1 链接 2)并可选地离散化。
- 参数
- 返回值
从 Gumbel-Softmax 分布中抽样的张量,与 logits 的形状相同。如果
hard=True
,则返回的样本将是独热的,否则它们将是概率分布,这些概率分布在 dim 上求和为 1。- 返回类型
注意
此函数出于遗留原因在此处,将来可能会从 nn.Functional 中删除。
注意
hard 的主要技巧是执行 y_hard - y_soft.detach() + y_soft
它实现了两个目标:- 使输出值完全成为独热向量(因为我们添加然后减去 y_soft 值)- 使梯度等于 y_soft 梯度(因为我们剥离了所有其他梯度)
- 示例:
>>> logits = torch.randn(20, 32) >>> # Sample soft categorical using reparametrization trick: >>> F.gumbel_softmax(logits, tau=1, hard=False) >>> # Sample hard categorical using "Straight-through" trick: >>> F.gumbel_softmax(logits, tau=1, hard=True)