torcheval.metrics.BinaryAUPRC¶
- class torcheval.metrics.BinaryAUPRC(*, num_tasks: int = 1, device: device | None = None)¶
计算 AUPRC,也称为平均精确率,它是二元分类中精确率-召回率曲线下的面积。
精确率定义为 \(\frac{T_p}{T_p+F_p}\),它是模型预测为正例的样本中真正例的概率。召回率定义为 \(\frac{T_p}{T_p+F_n}\),它是真正例被模型预测为正例的概率。
精确率-召回率曲线在 x 轴上绘制召回率,在 y 轴上绘制精确率,两者都介于 0 到 1 之间。此函数返回该曲线下的面积。如果面积接近 1,则该模型支持一个阈值,该阈值可以正确识别出很大一部分真正例,同时拒绝足够的假例,以便大多数真实预测都是真正例。
二元 auprc 支持多任务,如果输入和目标张量是二维的,则每行将被评估为独立的二元 auprc 实例。
此指标的函数版本是
torcheval.metrics.functional.binary_auprc()
.- 参数:
num_tasks (int) – 需要进行 BinaryAUPRC 计算的任务数量。默认值为 1。每个任务的 Binary AUPRC 将独立计算。结果等效于对每行运行 Binary AUPRC 计算。
示例
>>> import torch >>> from torcheval import BinaryAUPRC >>> metric = BinaryAUPRC() >>> input = torch.tensor([0.1, 0.5, 0.7, 0.8]) >>> target = torch.tensor([1, 0, 1, 1]) >>> metric.update(input, target) >>> metric.compute() tensor(0.9167) # scalar returned with 1D input tensors # with logits >>> metric = BinaryAUPRC() >>> input = torch.tensor([[.5, 2]]) >>> target = torch.tensor([[0, 0]]) >>> metric.update(input, target) >>> metric.compute() tensor([-0.]) >>> input = torch.tensor([[2, 1.5]]) >>> target = torch.tensor([[1, 0]]) >>> metric.update(input, target) >>> metric.compute() tensor([0.5000]) # 1D tensor returned with 2D input tensors # multiple tasks >>> metric = BinaryAUPRC(num_tasks=3) >>> input = torch.tensor([[0.1, 0, 0.1, 0], [0, 1, 0.2, 0], [0, 0, 0.7, 1]]) >>> target = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1]]) >>> metric.update(input, target) >>> metric.compute() tensor([0.5000, 1.0000, 1.0000])
- __init__(*, num_tasks: int = 1, device: device | None = None) None ¶
初始化一个指标对象及其内部状态。
使用
self._add_state()
初始化指标类状态变量。状态变量应为torch.Tensor
、torch.Tensor
列表、以torch.Tensor
作为值的字典,或torch.Tensor
的 deque。
方法
__init__
(*[, num_tasks, device])初始化一个指标对象及其内部状态。
计算
()实现此方法以根据状态变量计算和返回最终指标值。
load_state_dict
(state_dict[, strict])从 state_dict 加载指标状态变量。
merge_state
(metrics)实现此方法以更新当前指标的状态变量,使其成为当前指标和输入指标的合并状态。
重置
()将指标状态变量重置为其默认值。
state_dict
()将指标状态变量保存在 state_dict 中。
to
(device, *args, **kwargs)将指标状态变量中的张量移动到设备。
update
(input, target)使用真值标签和预测更新状态。
属性
设备
Metric.to()
的最后一个输入设备。