快捷方式

torcheval.metrics.functional.binary_recall

torcheval.metrics.functional.binary_recall(input: Tensor, target: Tensor, *, threshold: float = 0.5) Tensor

计算二元分类类的召回分数,计算方法是真阳性 (TP) 数量与实际阳性总数 (TP + FN) 之间的比率。其类版本是 torcheval.metrics.BinaryRecall

参数::
  • input (Tensor) – 预测标签/logits/概率的张量,形状为 (n_sample, )。

  • target (Tensor) – 形状为 (n_sample, ) 的真实标签张量。

  • threshold (float, 默认值 0.5) – 将输入转换为每个样本的预测标签的阈值。 torch.where(input < threshold, 0, 1) 将应用于 input

示例

>>> import torch
>>> from torcheval.metrics.functional.classification import binary_recall
>>> input = torch.tensor([0, 0, 1, 1])
>>> target = torch.tensor([0, 1, 1, 1])
>>> binary_recall(input, target)
tensor(0.6667)  # 2 / 3
>>> input = torch.tensor([0, 0.2, 0.4, 0.7])
>>> target = torch.tensor([1, 0, 1, 1])
>>> binary_recall(input, target)
tensor(0.3333)  # 1 / 3
>>> binary_recall(input, target, threshold=0.4)
tensor(0.5000)  # 1 / 2

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源