torcheval.metrics.R2Score¶
- class torcheval.metrics.R2Score(*, multioutput: str = 'uniform_average', num_regressors: int = 0, device: device | None = None)¶
计算 R 平方得分,它是因变量中可以由自变量解释的方差比例。其函数版本为
torcheval.metrics.functional.r2_score()
。- 参数:
multioutput (str, 可选) –
'uniform_average'
[默认值]:返回所有输出的得分的平均值,权重相同。'raw_values'
:返回完整的一组得分。variance_weighted
:返回所有输出的得分的加权平均值,权重为每个输出的方差。
num_regressors (int, 可选) – 使用的自变量的数量,应用于调整后的 R 平方得分。默认为零(标准 R 平方得分)。
- 引发:
ValueError –
如果 multioutput 的值不在 (
raw_values
,uniform_average
,variance_weighted
) 中。 - 如果 num_regressors 的值不是 [0, n_samples - 1] 范围内的整数
。
示例
>>> import torch >>> from torcheval.metrics import R2Score >>> metric = R2Score() >>> input = torch.tensor([0, 2, 1, 3]) >>> target = torch.tensor([0, 1, 2, 3]) >>> metric.update(input, target) >>> metric.compute() tensor(0.6) >>> metric = R2Score() >>> input = torch.tensor([[0, 2], [1, 6]]) >>> target = torch.tensor([[0, 1], [2, 5]]) >>> metric.update(input, target) >>> metric.compute() tensor(0.6250) >>> metric = R2Score(multioutput="raw_values") >>> input = torch.tensor([[0, 2], [1, 6]]) >>> target = torch.tensor([[0, 1], [2, 5]]) >>> metric.update(input, target) >>> metric.compute() tensor([0.5000, 0.7500]) >>> metric = R2Score(multioutput="variance_weighted") >>> input = torch.tensor([[0, 2], [1, 6]]) >>> target = torch.tensor([[0, 1], [2, 5]]) >>> metric.update(input, target) >>> metric.compute() tensor(0.7000) >>> metric = R2Score(multioutput="raw_values", num_regressors=2) >>> input = torch.tensor([1.2, 2.5, 3.6, 4.5, 6]) >>> target = torch.tensor([1, 2, 3, 4, 5]) >>> metric.update(input, target) >>> metric.compute() tensor(0.6200)
- __init__(*, multioutput: str = 'uniform_average', num_regressors: int = 0, device: device | None = None) None ¶
初始化度量对象及其内部状态。
使用
self._add_state()
初始化度量类状态变量。状态变量应为torch.Tensor
、torch.Tensor
列表、以torch.Tensor
为值的字典或torch.Tensor
的 deque。
方法
__init__
(*[, multioutput, num_regressors, ...])初始化度量对象及其内部状态。
计算
()返回 R 平方得分。
load_state_dict
(state_dict[, strict])从 state_dict 加载度量状态变量。
merge_state
(metrics)实现此方法以更新当前度量的状态变量,使其成为当前度量和输入度量的合并状态。
重置
()将度量状态变量重置为其默认值。
state_dict
()在 state_dict 中保存度量状态变量。
to
(device, *args, **kwargs)将度量状态变量中的张量移动到设备。
update
(input, target)使用真值和预测值更新状态。
属性
设备
Metric.to()
的最后一个输入设备。