• 文档 >
  • 在 TorchEval 中使用指标
快捷方式

在 TorchEval 中使用指标

PyTorch 评估指标是 TorchEval 的核心功能之一。对于大多数指标,我们提供两种接口:一种是有状态的基于类的接口,它只累积必要的数据,直到被告知计算指标;另一种是纯函数式接口。

类指标

类指标跟踪指标状态,这使得它们能够通过跨多个进程的累积和同步来计算值。基类是 torcheval.metrics.Metric

类指标的核心 API 是 update()compute()reset()

  • update():使用输入数据更新指标状态。这通常用于需要添加新数据以进行指标计算的情况。

  • compute():从指标状态计算指标值,指标状态由之前的 update() 调用更新。计算频率可以低于更新频率。

  • reset():将指标状态变量重置为其默认值。通常,这会在每个 epoch 结束时调用,以清除指标状态。

注意

类指标跟踪内部状态,这些状态由传递给 update() 调用的输入数据更新。这意味着指标状态应该移动到与输入数据相同的设备。您可以在初始化时直接传入设备,也可以使用 to(device) API。 .device 属性显示指标状态的设备。

下面是在一个简单的训练脚本中使用类指标的示例。

import torch
from torcheval.metrics import MulticlassAccuracy

device = "cuda" if torch.cuda.is_available() else "cpu"
metric = MulticlassAccuracy(device=device)
num_epochs, num_batches, batch_size = 4, 8, 10
num_classes = 3

# number of batches between metric computations
compute_frequency = 2

for epoch in range(num_epochs):
    for batch_idx in range(num_batches):
        input = torch.randint(high=num_classes, size=(batch_size,), device=device)
        target = torch.randint(high=num_classes, size=(batch_size,), device=device)

        # metric.update() updates the metric state with new data
        metric.update(input, target)

        if (batch_idx + 1) % compute_frequency == 0:
                print(
                    "Epoch {}/{}, Batch {}/{} --- acc: {:.4f}".format(
                        epoch + 1,
                        num_epochs,
                        batch_idx + 1,
                        num_batches,
                        # metric.compute() returns metric value from all seen data
                        metric.compute(),
                    )
                )

    # metric.reset() reset metric states. It's typically called after the epoch completes.
    metric.reset()

保存和加载指标

类指标还实现了状态协议,.state_dict().load_state_dict()。这些函数可用于保存和加载指标。

import torch
from torcheval.metrics import MulticlassAccuracy

metric = MulticlassAccuracy()
input = torch.tensor([0, 2, 1, 3])
target = torch.tensor([0, 1, 2, 3])
metric.update(input, target)

state_dict = metric.state_dict()
loaded_metric = MulticlassAccuracy()
loaded_metric.load_state_dict(state_dict)

# returns torch.tensor(0.5)
loaded_metric.compute()

函数式指标

函数式指标是简单的 Python 函数,用于根据输入数据计算指标值。它们是轻量级的,并且相对更快,因为它们不需要保留和操作指标状态。下面的示例显示了使用函数式版本计算指标值。

import torch
from torcheval.metrics.functional import multiclass_accuracy

input = torch.tensor([0, 2, 1, 3])
target = torch.tensor([0, 1, 2, 3])
# returns torch.tensor(0.5)
multiclass_accuracy(input, target)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源