快捷方式

torcheval.metrics.functional.mean_squared_error

torcheval.metrics.functional.mean_squared_error(input: Tensor, target: Tensor, *, sample_weight: Tensor | None = None, multioutput: str = 'uniform_average') Tensor

计算均方误差,它是 inputtarget 之间平方误差的平均值。其类版本为 torcheval.metrics.MeanSquaredError

参数:
  • input (Tensor) – 形状为 (n_sample, n_output) 的预测值张量。

  • target (Tensor) – 形状为 (n_sample, n_output) 的真实值张量。

  • sample_weight (可选) – 形状为 (n_sample, ) 的样本权重张量。默认为 None。

  • multioutput (可选) –

    • 'uniform_average' [默认]

      返回所有输出的分数,并使用统一权重进行平均。

    • 'raw_values':

      返回完整的分数集。

引发:

ValueError

  • 如果 multioutput 的值不存在于 (raw_values, uniform_average) 中。 - 如果 inputtarget 的维度不是 1D 或 2D。 - 如果 inputtarget 的大小不相同。 - 如果 inputtargetsample_weight 的第一维不相同。

示例

>>> import torch
>>> from torcheval.metrics.function import mean_squared_error
>>> input = torch.tensor([0.9, 0.5, 0.3, 0.5])
>>> target = torch.tensor([0.5, 0.8, 0.2, 0.8])
>>> mean_squared_error(input, target)
tensor(0.0875)

>>> input = torch.tensor([[0.9, 0.5], [0.3, 0.5]])
>>> target = torch.tensor([[0.5, 0.8], [0.2, 0.8]])
>>> mean_squared_error(input, target)
tensor(0.0875)

>>> input = torch.tensor([[0.9, 0.5], [0.3, 0.5]])
>>> target = torch.tensor([[0.5, 0.8], [0.2, 0.8]])
>>> mean_squared_error(input, target, multioutput="raw_values")
tensor([0.0850, 0.0900])

>>> input = torch.tensor([[0.9, 0.5], [0.3, 0.5]])
>>> target = torch.tensor([[0.5, 0.8], [0.2, 0.8]])
>>> mean_squared_error(input, target, sample_weight=torch.tensor([0.2, 0.8]))
tensor(0.0650)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源