torcheval.metrics.MeanSquaredError¶
- class torcheval.metrics.MeanSquaredError(*, multioutput: str = 'uniform_average', device: device | None = None)¶
计算均方误差 (MSE),即 input 和 target 之间平方误差的平均值。它的函数版本是
torcheval.metrics.functional.mean_squared_error()
.- 参数:
multioutput (str, Optional) –
'uniform_average'
[默认]: 返回所有输出的平均得分,权重相同。'raw_values'
: 返回完整的一组得分。
- 异常:
ValueError –
如果 multioutput 的值不存在于 (
raw_values
,uniform_average
) 中。 - 如果 input 或 target 的维度不是 1D 或 2D。 - 如果 input 和 target 的大小不同。 - 如果 input、target 和 sample_weight 的第一维不同。
示例
>>> import torch >>> from torcheval.metrics import MeanSquaredError >>> metric = MeanSquaredError() >>> input = torch.tensor([0.9, 0.5, 0.3, 0.5]) >>> target = torch.tensor([0.5, 0.8, 0.2, 0.8]) >>> metric.update(input, target) >>> metric.compute() tensor(0.0875) >>> metric = MeanSquaredError() >>> input = torch.tensor([[0.9, 0.5], [0.3, 0.5]]) >>> target = torch.tensor([[0.5, 0.8], [0.2, 0.8]]) >>> metric.update(input, target) >>> metric.compute() tensor(0.0875) >>> metric = MeanSquaredError(multioutput="raw_values") >>> input = torch.tensor([[0.9, 0.5], [0.3, 0.5]]) >>> target = torch.tensor([[0.5, 0.8], [0.2, 0.8]]) >>> metric.update(input, target) >>> metric.compute() tensor([0.0850, 0.0900]) >>> input = torch.tensor([[0.9, 0.5], [0.3, 0.5]]) >>> target = torch.tensor([[0.5, 0.8], [0.2, 0.8]]) >>> metric.update(input, target, sample_weight=torch.tensor([0.2, 0.8])) >>> metric.compute() tensor(0.0650)
- __init__(*, multioutput: str = 'uniform_average', device: device | None = None) None ¶
初始化一个指标对象及其内部状态。
使用
self._add_state()
初始化指标类状态变量。状态变量应该是torch.Tensor
、torch.Tensor
列表、包含torch.Tensor
作为值的字典,或者torch.Tensor
的双端队列。
方法
__init__
(*[, multioutput, device])初始化一个指标对象及其内部状态。
compute
()返回均方误差。
load_state_dict
(state_dict[, strict])从 state_dict 加载指标状态变量。
merge_state
(metrics)实现此方法以更新当前指标的状态变量,使其成为当前指标和输入指标的合并状态。
reset
()将指标状态变量重置为其默认值。
state_dict
()将指标状态变量保存在 state_dict 中。
to
(device, *args, **kwargs)将指标状态变量中的张量移动到设备。
update
(input, target, *[, sample_weight])使用地面真实值和预测值更新状态。
属性
device
Metric.to()
的最后一个输入设备。