torcheval.metrics.WordErrorRate¶
- class torcheval.metrics.WordErrorRate(*, device: device | None = None)¶
计算预测词序列(s)与参考词序列(s)的词错误率。其函数式版本是
torcheval.metrics.functional.word_error_rate()
.示例
>>> import torch >>> from torcheval.metrics import WordErrorRate
>>> metric = WordErrorRate() >>> metric.update(["this is the prediction", "there is an other sample"], ["this is the reference", "there is another one"]) >>> metric.compute() tensor(0.5)
>>> metric = WordErrorRate() >>> metric.update(["this is the prediction", "there is an other sample"], ["this is the reference", "there is another one"]) >>> metric.update(["hello world", "welcome to the facebook"], ["hello metaverse", "welcome to meta"]) >>> metric.compute() tensor(0.53846)
- __init__(*, device: device | None = None) None ¶
初始化度量对象及其内部状态。
使用
self._add_state()
初始化度量类状态变量。状态变量应该是torch.Tensor
、一个torch.Tensor
列表、一个以torch.Tensor
为值的字典,或者是一个torch.Tensor
的双端队列。
方法
__init__
(*[, device])初始化度量对象及其内部状态。
compute
()返回词错误率分数
load_state_dict
(state_dict[, strict])从 state_dict 中加载度量状态变量。
merge_state
(metrics)将度量状态与其来自其他度量实例的对应部分合并。
reset
()将度量状态变量重置为其默认值。
state_dict
()将度量状态变量保存到 state_dict 中。
to
(device, *args, **kwargs)将度量状态变量中的张量移动到设备。
update
(input, target)使用编辑距离和参考序列的长度更新度量状态。
属性
device
Metric.to()
的最后一个输入设备。