在 TorchEval 中使用指标¶
PyTorch 评估指标是 TorchEval 的核心功能之一。对于大多数指标,我们提供两种接口:一种是有状态的基于类的接口,它只累积必要的数据,直到被告知计算指标;另一种是纯函数式接口。
类指标¶
类指标跟踪指标状态,这使得它们能够通过跨多个进程的累积和同步来计算值。基类是 torcheval.metrics.Metric
。
类指标的核心 API 是 update()
、compute()
和 reset()
。
update()
:使用输入数据更新指标状态。这通常用于需要添加新数据以进行指标计算的情况。compute()
:从指标状态计算指标值,指标状态由之前的update()
调用更新。计算频率可以低于更新频率。reset()
:将指标状态变量重置为其默认值。通常,这会在每个 epoch 结束时调用,以清除指标状态。
注意
类指标跟踪内部状态,这些状态由传递给 update()
调用的输入数据更新。这意味着指标状态应该移动到与输入数据相同的设备。您可以在初始化时直接传入设备,也可以使用 to(device)
API。 .device
属性显示指标状态的设备。
下面是在一个简单的训练脚本中使用类指标的示例。
import torch
from torcheval.metrics import MulticlassAccuracy
device = "cuda" if torch.cuda.is_available() else "cpu"
metric = MulticlassAccuracy(device=device)
num_epochs, num_batches, batch_size = 4, 8, 10
num_classes = 3
# number of batches between metric computations
compute_frequency = 2
for epoch in range(num_epochs):
for batch_idx in range(num_batches):
input = torch.randint(high=num_classes, size=(batch_size,), device=device)
target = torch.randint(high=num_classes, size=(batch_size,), device=device)
# metric.update() updates the metric state with new data
metric.update(input, target)
if (batch_idx + 1) % compute_frequency == 0:
print(
"Epoch {}/{}, Batch {}/{} --- acc: {:.4f}".format(
epoch + 1,
num_epochs,
batch_idx + 1,
num_batches,
# metric.compute() returns metric value from all seen data
metric.compute(),
)
)
# metric.reset() reset metric states. It's typically called after the epoch completes.
metric.reset()
保存和加载指标¶
类指标还实现了状态协议,.state_dict()
和 .load_state_dict()
。这些函数可用于保存和加载指标。
import torch
from torcheval.metrics import MulticlassAccuracy
metric = MulticlassAccuracy()
input = torch.tensor([0, 2, 1, 3])
target = torch.tensor([0, 1, 2, 3])
metric.update(input, target)
state_dict = metric.state_dict()
loaded_metric = MulticlassAccuracy()
loaded_metric.load_state_dict(state_dict)
# returns torch.tensor(0.5)
loaded_metric.compute()
函数式指标¶
函数式指标是简单的 Python 函数,用于根据输入数据计算指标值。它们是轻量级的,并且相对更快,因为它们不需要保留和操作指标状态。下面的示例显示了使用函数式版本计算指标值。
import torch
from torcheval.metrics.functional import multiclass_accuracy
input = torch.tensor([0, 2, 1, 3])
target = torch.tensor([0, 1, 2, 3])
# returns torch.tensor(0.5)
multiclass_accuracy(input, target)