torcheval.metrics.HitRate¶
- class torcheval.metrics.HitRate(*, k: int | None = None, device: device | None = None)¶
计算前 k 个预测类别中正确类别的命中率。其函数版本为
torcheval.metrics.functional.hit_rate()
。- 参数:
k (int, 可选) – 要考虑的前 k 个类别概率的数量。如果 k 为 None,则考虑所有类别,并返回 1.0 的命中率。
示例
>>> import torch >>> from torcheval.metrics import HitRate >>> metric = HitRate() >>> metric.update(torch.tensor([[0.3, 0.1, 0.6], [0.5, 0.2, 0.3]]), torch.tensor([2, 1])) >>> metric.update(torch.tensor([[0.2, 0.1, 0.7], [0.3, 0.3, 0.4]]), torch.tensor([1, 0])) >>> metric.compute() tensor([1., 1., 1., 1.]) >>> metric = HitRate(k=2) >>> metric.update(torch.tensor([[0.3, 0.1, 0.6], [0.5, 0.2, 0.3]]), torch.tensor([2, 1])) >>> metric.update(torch.tensor([[0.2, 0.1, 0.7], [0.3, 0.3, 0.4]]), torch.tensor([1, 0])) >>> metric.compute() tensor([1., 0., 0., 1.])
- __init__(*, k: int | None = None, device: device | None = None) None ¶
初始化度量对象及其内部状态。
使用
self._add_state()
初始化度量类别的状态变量。状态变量应为torch.Tensor
、torch.Tensor
列表、以torch.Tensor
作为值的字典或torch.Tensor
的 deque。
方法
__init__
(*[, k, device])初始化度量对象及其内部状态。
计算
()返回连接的命中率分数。
load_state_dict
(state_dict[, strict])从 state_dict 加载度量状态变量。
merge_state
(metrics)将度量状态与其来自其他度量实例的对应部分合并。
重置
()将度量状态变量重置为其默认值。
state_dict
()在 state_dict 中保存度量状态变量。
to
(device, *args, **kwargs)将度量状态变量中的张量移动到设备。
update
(input, target)使用地面真实标签和预测更新度量状态。
属性
设备
Metric.to()
的最后一个输入设备。