torcheval.metrics.Perplexity¶
- class torcheval.metrics.Perplexity(ignore_index: int | None = None, device: device | None = None)¶
困惑度衡量模型预测样本数据的程度。它通过以下公式计算:
ppl = exp (负对数似然之和 / 令牌数)
其函数版本为
torcheval.metrics.functional.text.perplexity
。- 参数:
ignore_index (Tensor) – 如果指定,则计算困惑度时将忽略具有“ignore_index”的目标类别。默认值为 None。
示例
>>> import torch >>> from torcheval.metrics.text import Perplexity
>>> metric=Perplexity() >>> input = torch.tensor([[[0.3659, 0.7025, 0.3104]], [[0.0097, 0.6577, 0.1947]],[[0.5659, 0.0025, 0.0104]], [[0.9097, 0.0577, 0.7947]]]) >>> target = torch.tensor([[2], [1], [2], [1]]) >>> metric.update(input, target) >>> metric.compute() tensor(3.5257, dtype=torch.float64)
>>> metric=Perplexity(ignore_index=1) >>> input = torch.tensor([[[0.3659, 0.7025, 0.3104]], [[0.0097, 0.6577, 0.1947]],[[0.5659, 0.0025, 0.0104]], [[0.9097, 0.0577, 0.7947]]]) >>> target = torch.tensor([[2], [1], [2], [1]]) >>> metric.update(input, target) >>> metric.compute() tensor(3.6347, dtype=torch.float64)
>>> metric1=Perplexity() >>> input = torch.tensor([[[0.5659, 0.0025, 0.0104]], [[0.9097, 0.0577, 0.7947]]]) >>> target = torch.tensor([[2], [1], ]) >>> metric1.update(input, target) >>> metric1.compute() tensor(4.5051, dtype=torch.float64)
>>> metric2=Perplexity() >>> input = torch.tensor([[[0.3659, 0.7025, 0.3104]], [[0.0097, 0.6577, 0.1947]]]) >>> target = torch.tensor([[2], [1]]) >>> metric2.update(input, target) >>> metric2.compute()) tensor(2.7593, dtype=torch.float64)
>>> metric1.merge_state([metric2]) >>> metric1.compute()) tensor(3.5257, dtype=torch.float64)
- __init__(ignore_index: int | None = None, device: device | None = None) None ¶
初始化度量对象及其内部状态。
使用
self._add_state()
初始化度量类状态变量。状态变量应为torch.Tensor
、torch.Tensor
列表、以torch.Tensor
作为值的字典或torch.Tensor
的 deque。
方法
__init__
([ignore_index, device])初始化度量对象及其内部状态。
计算
()根据 sum_log_probs 和 num_total 计算困惑度。
load_state_dict
(state_dict[, strict])从 state_dict 加载度量状态变量。
merge_state
(metrics)将度量状态与其来自其他度量实例的对应部分合并。
重置
()将度量状态变量重置为其默认值。
state_dict
()在 state_dict 中保存度量状态变量。
to
(device, *args, **kwargs)将度量状态变量中的张量移动到设备。
update
(input, target)使用新输入更新度量状态。
属性
设备
Metric.to()
的最后一个输入设备。