快捷方式

torch.nn.functional.gumbel_softmax

torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1)[source][source]

从 Gumbel-Softmax 分布(链接 1 链接 2)中采样,并可选择进行离散化。

参数
  • logits (Tensor) – […, num_features] 未归一化的对数概率

  • tau (float) – 非负标量温度

  • hard (bool) – 如果为 True,返回的样本将被离散化为 one-hot 向量,但在 autograd 中会按软样本进行微分

  • dim (int) – 将计算 softmax 的维度。默认为 -1。

返回

从 Gumbel-Softmax 分布中采样的张量,形状与 logits 相同。如果 hard=True,返回的样本将是 one-hot 向量;否则,它们将是在 dim 维度上求和为 1 的概率分布。

返回类型

Tensor

注意

此函数出于历史原因保留,未来可能从 nn.Functional 中移除。

注意

hard 的主要技巧是执行 y_hard - y_soft.detach() + y_soft

这实现了两件事: - 使输出值精确为 one-hot(因为我们先加后减了 y_soft 的值) - 使梯度等于 y_soft 的梯度(因为我们去除了所有其他梯度)

示例:
>>> logits = torch.randn(20, 32)
>>> # Sample soft categorical using reparametrization trick:
>>> F.gumbel_softmax(logits, tau=1, hard=False)
>>> # Sample hard categorical using "Straight-through" trick:
>>> F.gumbel_softmax(logits, tau=1, hard=True)

文档

访问 PyTorch 的综合开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并解答您的问题

查看资源