快捷方式

torcheval.metrics.Cat

class torcheval.metrics.Cat(*, dim: int = 0, device: device | None = None)

沿着维度 dim 连接所有输入张量。其函数版本为 torch.cat(input)

传递给 Cat.update() 的所有输入张量必须具有相同的形状(连接维度除外)或为空。

零维张量不是 Cat.update() 的有效输入。 torch.flatten() 可用于将零维张量展平为一维张量,然后再传递给 Cat.update()

示例

>>> import torch
>>> from torcheval.metrics import Cat
>>> metric = Cat(dim=1)
>>> metric.update(torch.tensor([[1, 2], [3, 4]]))
>>> metric.compute()
tensor([[1, 2],
        [3, 4]]))

>>> metric.update(torch.tensor([[5, 6], [7, 8]]))).compute()
tensor([[1, 2, 5, 6],
        [3, 4, 7, 8]]))

>>> metric.reset()
>>> metric.update(torch.tensor([0])).compute()
tensor([0])
__init__(*, dim: int = 0, device: device | None = None) None

初始化 Cat 指标对象。

参数:

dim – 连接的维度,与 torch.cat() 中相同。

方法

__init__(*[, dim, device])

初始化 Cat 指标对象。

compute()

返回连接后的输入。

load_state_dict(state_dict[, strict])

从 state_dict 加载指标状态变量。

merge_state(metrics)

实现此方法以更新当前指标的状态变量,使其成为当前指标和输入指标的合并状态。

reset()

将指标状态变量重置为其默认值。

state_dict()

在 state_dict 中保存指标状态变量。

to(device, *args, **kwargs)

将指标状态变量中的张量移动到设备。

update(input)

实现此方法以更新指标类的状态变量。

属性

device

Metric.to() 的最后一个输入设备。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适合初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源