快捷方式

torcheval.metrics.Perplexity

class torcheval.metrics.Perplexity(ignore_index: int | None = None, device: device | None = None)

困惑度衡量模型预测样本数据的程度。它通过以下公式计算:

ppl = exp (负对数似然之和 / 令牌数)

其函数版本为 torcheval.metrics.functional.text.perplexity

参数:

ignore_index (Tensor) – 如果指定,则计算困惑度时将忽略具有“ignore_index”的目标类别。默认值为 None。

示例

>>> import torch
>>> from torcheval.metrics.text import Perplexity
>>> metric=Perplexity()
>>> input = torch.tensor([[[0.3659, 0.7025, 0.3104]], [[0.0097, 0.6577, 0.1947]],[[0.5659, 0.0025, 0.0104]], [[0.9097, 0.0577, 0.7947]]])
>>> target = torch.tensor([[2],  [1], [2],  [1]])
>>> metric.update(input, target)
>>> metric.compute()
tensor(3.5257, dtype=torch.float64)
>>> metric=Perplexity(ignore_index=1)
>>> input = torch.tensor([[[0.3659, 0.7025, 0.3104]], [[0.0097, 0.6577, 0.1947]],[[0.5659, 0.0025, 0.0104]], [[0.9097, 0.0577, 0.7947]]])
>>> target = torch.tensor([[2],  [1], [2],  [1]])
>>> metric.update(input, target)
>>> metric.compute()
tensor(3.6347, dtype=torch.float64)
>>> metric1=Perplexity()
>>> input = torch.tensor([[[0.5659, 0.0025, 0.0104]], [[0.9097, 0.0577, 0.7947]]])
>>> target = torch.tensor([[2],  [1], ])
>>> metric1.update(input, target)
>>> metric1.compute()
tensor(4.5051, dtype=torch.float64)
>>> metric2=Perplexity()
>>> input = torch.tensor([[[0.3659, 0.7025, 0.3104]], [[0.0097, 0.6577, 0.1947]]])
>>> target = torch.tensor([[2],  [1]])
>>> metric2.update(input, target)
>>> metric2.compute())
tensor(2.7593, dtype=torch.float64)
>>> metric1.merge_state([metric2])
>>> metric1.compute())
tensor(3.5257, dtype=torch.float64)
__init__(ignore_index: int | None = None, device: device | None = None) None

初始化度量对象及其内部状态。

使用 self._add_state() 初始化度量类状态变量。状态变量应为 torch.Tensortorch.Tensor 列表、以 torch.Tensor 作为值的字典或 torch.Tensor 的 deque。

方法

__init__([ignore_index, device])

初始化度量对象及其内部状态。

计算()

根据 sum_log_probsnum_total 计算困惑度。

load_state_dict(state_dict[, strict])

从 state_dict 加载度量状态变量。

merge_state(metrics)

将度量状态与其来自其他度量实例的对应部分合并。

重置()

将度量状态变量重置为其默认值。

state_dict()

在 state_dict 中保存度量状态变量。

to(device, *args, **kwargs)

将度量状态变量中的张量移动到设备。

update(input, target)

使用新输入更新度量状态。

属性

设备

Metric.to() 的最后一个输入设备。

文档

获取 PyTorch 全面的开发者文档

查看文档

教程

获取初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源