torcheval.metrics.ReciprocalRank¶
- class torcheval.metrics.ReciprocalRank(*, k: int | None = None, device: device | None = None)¶
计算前 k 个预测类中正确类的倒数排名。它的函数版本是
torcheval.metrics.functional.reciprocal_rank()
.- 参数:
k (int, 可选) – 要考虑的前 k 个类概率的数量。
示例
>>> import torch >>> from torcheval.metrics import ReciprocalRank >>> metric = ReciprocalRank() >>> metric.update(torch.tensor([[0.3, 0.1, 0.6], [0.5, 0.2, 0.3]]), torch.tensor([2, 1])) >>> metric.update(torch.tensor([[0.2, 0.1, 0.7], [0.3, 0.3, 0.4]]), torch.tensor([1, 0])) >>> metric.compute() tensor([1.0000, 0.3333, 0.3333, 0.5000]) >>> metric = ReciprocalRank(k=2) >>> metric.update(torch.tensor([[0.3, 0.1, 0.6], [0.5, 0.2, 0.3]]), torch.tensor([2, 1])) >>> metric.update(torch.tensor([[0.2, 0.1, 0.7], [0.3, 0.3, 0.4]]), torch.tensor([1, 0])) >>> metric.compute() tensor([1.0000, 0.0000, 0.0000, 0.5000])
- __init__(*, k: int | None = None, device: device | None = None) None ¶
初始化一个指标对象及其内部状态。
使用
self._add_state()
初始化指标类的状态变量。状态变量应为torch.Tensor
、torch.Tensor
列表、以torch.Tensor
作为值的字典,或torch.Tensor
的双端队列。
方法
__init__
(*[, k, device])初始化一个指标对象及其内部状态。
compute
()返回连接的倒数排名分数。
load_state_dict
(state_dict[, strict])从 state_dict 加载指标状态变量。
merge_state
(metrics)将指标状态与其来自其他指标实例的对应部分合并。
reset
()将指标状态变量重置为其默认值。
state_dict
()在 state_dict 中保存指标状态变量。
to
(device, *args, **kwargs)将指标状态变量中的张量移动到设备。
update
(input, target)使用真实标签和预测更新指标状态。
属性
device
Metric.to()
的最后一个输入设备。