快捷方式

torcheval.metrics.functional.topk_multilabel_accuracy

torcheval.metrics.functional.topk_multilabel_accuracy(input: Tensor, target: Tensor, *, criteria: str = 'exact_match', k: int = 2) Tensor

计算多标签准确率得分,即预测的 top k 标签与目标匹配的频率。它的类版本是 torcheval.metrics.TopKMultilabelAccuracy

参数:
  • input (Tensor) – 形状为 (n_sample, n_class) 的 logits/概率张量。

  • target (Tensor) – 形状为 (n_sample, n_class) 的真实标签张量。

  • criteria

  • [默认] (- 'exact_match') – 预测样本的 top-k 标签集必须与目标中对应的标签集完全匹配。也称为子集准确率。

  • 'hamming' (-) – top-k 正确标签在所有标签总数中的比例。

  • 'overlap' (-) – 预测样本的 top-k 标签集必须与目标中对应的标签集重叠。

  • 'contain' (-) – 预测样本的 top-k 标签集必须包含目标中对应的标签集。

  • 'belong' (-) – 预测样本的 top-k 标签集必须(完全)属于目标中对应的标签集。

  • k – 要考虑的 top 概率数量。K 应为大于或等于 1 的整数。

示例

>>> import torch
>>> from torcheval.metrics.functional import topkmultilabel_accuracy
>>> input = torch.tensor([[0.1, 0.5, 0.2], [0.3, 0.2, 0.1], [0.2, 0.4, 0.5], [0, 0.1, 0.9]])
>>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [1, 1, 1], [0, 1, 0]])
>>> topkmultilabel_accuracy(input, target, k = 2)
tensor(0)  # 0 / 4

>>> input = torch.tensor([[0.1, 0.5, 0.2], [0.3, 0.2, 0.1], [0.2, 0.4, 0.5], [0, 0.1, 0.9]])
>>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [1, 1, 1], [0, 1, 0]])
>>> topkmultilabel_accuracy(input, target, criteria="hamming", k = 2)
tensor(0.583)  # 7 / 12

>>> input = torch.tensor([[0.1, 0.5, 0.2], [0.3, 0.2, 0.1], [0.2, 0.4, 0.5], [0, 0.1, 0.9]])
>>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [1, 1, 1], [0, 1, 0]])
>>> topkmultilabel_accuracy(input, target, criteria="overlap", k = 2)
tensor(1)  # 4 / 4

>>> input = torch.tensor([[0.1, 0.5, 0.2], [0.3, 0.2, 0.1], [0.2, 0.4, 0.5], [0, 0.1, 0.9]])
>>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [1, 1, 1], [0, 1, 0]])
>>> topkmultilabel_accuracy(input, target, criteria="contain", k = 2)
tensor(0.5)  # 2 / 4

>>> input = torch.tensor([[0.1, 0.5, 0.2], [0.3, 0.2, 0.1], [0.2, 0.4, 0.5], [0, 0.1, 0.9]])
>>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [1, 1, 1], [0, 1, 0]])
>>> topkmultilabel_accuracy(input, target, criteria="belong", k = 2)
tensor(0.25)  # 1 / 4

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源