自适应LogSoftmaxWithLoss¶
- class torch.nn.AdaptiveLogSoftmaxWithLoss(in_features, n_classes, cutoffs, div_value=4.0, head_bias=False, device=None, dtype=None)[源代码]¶
高效的 softmax 近似。
如 Edouard Grave、Armand Joulin、Moustapha Cissé、David Grangier 和 Hervé Jégou 发表的《GPU 上的高效 softmax 近似》 中所述。
自适应 softmax 是一种用于训练具有大型输出空间的模型的近似策略。当标签分布高度不平衡时,它最有效,例如在自然语言建模中,单词频率分布近似遵循 齐夫定律。
自适应 softmax 根据标签的频率将标签划分为几个集群。这些集群中的每个集群可能包含不同数量的目标。此外,包含较不常见标签的集群为这些标签分配低维嵌入,从而加快计算速度。对于每个小批量,仅评估至少存在一个目标的集群。
这样做的原理是,访问频率高的集群(如包含最常见标签的第一个集群)也应该易于计算 - 即包含少量分配的标签。
我们强烈建议您查看原始论文以了解更多详细信息。
cutoffs
应该是一个按升序排序的整数序列。它控制集群数量和目标到集群的划分。例如,设置cutoffs = [10, 100, 1000]
表示前 10 个目标将被分配到自适应 softmax 的“头部”,目标 11、12、...、100 将被分配到第一个集群,目标 101、102、...、1000 将被分配到第二个集群,而目标 1001、1002、...、n_classes - 1 将被分配到最后一个,即第三个集群。div_value
用于计算每个附加集群的大小,该大小由 给出,其中 是集群索引(较不常见词的集群具有较大的索引,索引从 开始)。head_bias
如果设置为 True,则会向自适应 softmax 的“头部”添加偏差项。有关详细信息,请参阅论文。在官方实现中设置为 False。
警告
传递到此模块的输入标签应按其频率排序。这意味着最常见的标签应由索引 0 表示,而最不常见的标签应由索引 n_classes - 1 表示。
注意
此模块返回一个包含
output
和loss
字段的NamedTuple
。有关详细信息,请参阅进一步的文档。注意
要计算所有类别的对数概率,可以使用
log_prob
方法。- 参数
- 返回值
output 是一个大小为
N
的张量,包含为每个示例计算的目标对数概率loss 是一个表示计算出的负对数似然损失的标量
- 返回类型
NamedTuple
,包含output
和loss
字段
- 形状
input: 或
target: 或 ,其中每个值都满足
output1: 或
output2:
Scalar