torcheval.metrics.Cat¶
- class torcheval.metrics.Cat(*, dim: int = 0, device: device | None = None)¶
沿着维度 dim 连接所有输入张量。其函数版本为
torch.cat(input)
。传递给
Cat.update()
的所有输入张量必须具有相同的形状(连接维度除外)或为空。零维张量不是
Cat.update()
的有效输入。torch.flatten()
可用于将零维张量展平为一维张量,然后再传递给Cat.update()
。示例
>>> import torch >>> from torcheval.metrics import Cat >>> metric = Cat(dim=1) >>> metric.update(torch.tensor([[1, 2], [3, 4]])) >>> metric.compute() tensor([[1, 2], [3, 4]])) >>> metric.update(torch.tensor([[5, 6], [7, 8]]))).compute() tensor([[1, 2, 5, 6], [3, 4, 7, 8]])) >>> metric.reset() >>> metric.update(torch.tensor([0])).compute() tensor([0])
- __init__(*, dim: int = 0, device: device | None = None) None ¶
初始化 Cat 指标对象。
- 参数:
dim – 连接的维度,与
torch.cat()
中相同。
方法
__init__
(*[, dim, device])初始化 Cat 指标对象。
compute
()返回连接后的输入。
load_state_dict
(state_dict[, strict])从 state_dict 加载指标状态变量。
merge_state
(metrics)实现此方法以更新当前指标的状态变量,使其成为当前指标和输入指标的合并状态。
reset
()将指标状态变量重置为其默认值。
state_dict
()在 state_dict 中保存指标状态变量。
to
(device, *args, **kwargs)将指标状态变量中的张量移动到设备。
update
(input)实现此方法以更新指标类的状态变量。
属性
device
Metric.to()
的最后一个输入设备。