快捷方式

torch.nn.functional.gumbel_softmax

torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1)[源代码]

从 Gumbel-Softmax 分布中抽样(链接 1 链接 2)并可选地离散化。

参数
  • logits (张量) – […, num_features] 未标准化的对数概率

  • tau (浮点数) – 非负标量温度

  • hard (布尔值) – 如果为 True,则返回的样本将被离散化为独热向量,但在自动梯度中将被区分,就好像它是软样本一样

  • dim (整数) – 将计算 softmax 的维度。默认值:-1。

返回值

从 Gumbel-Softmax 分布中抽样的张量,与 logits 的形状相同。如果 hard=True,则返回的样本将是独热的,否则它们将是概率分布,这些概率分布在 dim 上求和为 1。

返回类型

张量

注意

此函数出于遗留原因在此处,将来可能会从 nn.Functional 中删除。

注意

hard 的主要技巧是执行 y_hard - y_soft.detach() + y_soft

它实现了两个目标:- 使输出值完全成为独热向量(因为我们添加然后减去 y_soft 值)- 使梯度等于 y_soft 梯度(因为我们剥离了所有其他梯度)

示例:
>>> logits = torch.randn(20, 32)
>>> # Sample soft categorical using reparametrization trick:
>>> F.gumbel_softmax(logits, tau=1, hard=False)
>>> # Sample hard categorical using "Straight-through" trick:
>>> F.gumbel_softmax(logits, tau=1, hard=True)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源