快捷方式

torcheval.metrics.MeanSquaredError

class torcheval.metrics.MeanSquaredError(*, multioutput: str = 'uniform_average', device: device | None = None)

计算均方误差 (MSE),即 inputtarget 之间平方误差的平均值。它的函数版本是 torcheval.metrics.functional.mean_squared_error().

参数:

multioutput (str, Optional) –

  • 'uniform_average' [默认]: 返回所有输出的平均得分,权重相同。

  • 'raw_values': 返回完整的一组得分。

异常:

ValueError

  • 如果 multioutput 的值不存在于 (raw_values, uniform_average) 中。 - 如果 inputtarget 的维度不是 1D 或 2D。 - 如果 inputtarget 的大小不同。 - 如果 inputtargetsample_weight 的第一维不同。

示例

>>> import torch
>>> from torcheval.metrics import MeanSquaredError
>>> metric = MeanSquaredError()
>>> input = torch.tensor([0.9, 0.5, 0.3, 0.5])
>>> target = torch.tensor([0.5, 0.8, 0.2, 0.8])
>>> metric.update(input, target)
>>> metric.compute()
tensor(0.0875)

>>> metric = MeanSquaredError()
>>> input = torch.tensor([[0.9, 0.5], [0.3, 0.5]])
>>> target = torch.tensor([[0.5, 0.8], [0.2, 0.8]])
>>> metric.update(input, target)
>>> metric.compute()
tensor(0.0875)

>>> metric = MeanSquaredError(multioutput="raw_values")
>>> input = torch.tensor([[0.9, 0.5], [0.3, 0.5]])
>>> target = torch.tensor([[0.5, 0.8], [0.2, 0.8]])
>>> metric.update(input, target)
>>> metric.compute()
tensor([0.0850, 0.0900])

>>> input = torch.tensor([[0.9, 0.5], [0.3, 0.5]])
>>> target = torch.tensor([[0.5, 0.8], [0.2, 0.8]])
>>> metric.update(input, target, sample_weight=torch.tensor([0.2, 0.8]))
>>> metric.compute()
tensor(0.0650)
__init__(*, multioutput: str = 'uniform_average', device: device | None = None) None

初始化一个指标对象及其内部状态。

使用 self._add_state() 初始化指标类状态变量。状态变量应该是 torch.Tensortorch.Tensor 列表、包含 torch.Tensor 作为值的字典,或者 torch.Tensor 的双端队列。

方法

__init__(*[, multioutput, device])

初始化一个指标对象及其内部状态。

compute()

返回均方误差。

load_state_dict(state_dict[, strict])

从 state_dict 加载指标状态变量。

merge_state(metrics)

实现此方法以更新当前指标的状态变量,使其成为当前指标和输入指标的合并状态。

reset()

将指标状态变量重置为其默认值。

state_dict()

将指标状态变量保存在 state_dict 中。

to(device, *args, **kwargs)

将指标状态变量中的张量移动到设备。

update(input, target, *[, sample_weight])

使用地面真实值和预测值更新状态。

属性

device

Metric.to() 的最后一个输入设备。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适合初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源