torcheval.metrics.MulticlassAUROC¶
- class torcheval.metrics.MulticlassAUROC(*, num_classes: int, average: str | None = 'macro', device: device | None = None)¶
计算多类分类中每个类别与其他类别对比的 AUROC,即 ROC 曲线下面积。每个类别与其他类别对比的多类 AUROC 等效于使用 num_classes 个任务运行二分类 AUROC,其中
输入 被转置
目标 从表示正确类别的 1 维张量转换为 2 维张量,其中每一行都是一个列表,包含属于该类别的示例。
有关多类和二分类 AUROC 之间关系的更多详细信息,请参见下面的示例。
此指标的函数版本为
torcheval.metrics.functional.multiclass_auroc()
。- 参数:
num_classes (int) – 类别数量。
average (str, 可选) –
'macro'
[默认]分别计算每个类别的指标,并返回它们的未加权平均值。
None
:分别计算每个类别的指标,并返回每个类别的指标。
示例
>>> import torch >>> from torcheval.metrics import MulticlassAUROC >>> metric = MulticlassAUROC(num_classes=4) >>> input = torch.tensor([[0.1, 0.1, 0.1, 0.1], [0.5, 0.5, 0.5, 0.5], [0.7, 0.7, 0.7, 0.7], [0.8, 0.8, 0.8, 0.8]]) >>> target = torch.tensor([0, 1, 2, 3]) >>> metric.update(input, target) >>> metric.compute() tensor(0.5000) >>> metric = MulticlassAUROC(num_classes=3, average=None) >>> input = torch.tensor([[0.1, 0, 0], [0, 1, 0], [0.1, 0.2, 0.7], [0, 0, 1]]) >>> target = torch.tensor([0, 1, 2, 2]) >>> metric.update(input, target) >>> metric.compute() tensor([0.8333, 1.0000, 1.0000]) the above is equivalent to >>> from torcheval.metrics import BinaryAUROC >>> metric = BinaryAUROC(num_tasks=3) >>> input = torch.tensor([[0.1, 0, 0.1, 0], [0, 1, 0.2, 0], [0, 0, 0.7, 1]]) >>> target = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1]]) >>> metric.update(input, target) >>> metric.compute() tensor([0.8333, 1.0000, 1.0000])
- __init__(*, num_classes: int, average: str | None = 'macro', device: device | None = None) None ¶
初始化度量对象及其内部状态。
使用
self._add_state()
初始化度量类别的状态变量。状态变量应为torch.Tensor
、torch.Tensor
列表、以torch.Tensor
作为值的字典或torch.Tensor
的双端队列。
方法
__init__
(*, num_classes[, average, device])初始化度量对象及其内部状态。
compute
()实现此方法以根据状态变量计算并返回最终的度量值。
load_state_dict
(state_dict[, strict])从 state_dict 加载度量状态变量。
merge_state
(metrics)实现此方法以更新当前度量的状态变量,使其成为当前度量和输入度量的合并状态。
reset
()将度量状态变量重置为其默认值。
state_dict
()在 state_dict 中保存度量状态变量。
to
(device, *args, **kwargs)将度量状态变量中的张量移动到设备。
update
(input, target)使用真实标签和预测更新状态。
属性
device
Metric.to()
的最后一个输入设备。