快捷方式

torch.nn.functional.one_hot

torch.nn.functional.one_hot(tensor, num_classes=-1) LongTensor

接收形状为 (*) 的 LongTensor,其中包含索引值,并返回一个形状为 (*, num_classes) 的 tensor。此 tensor 除最后一维的索引与输入 tensor 对应值匹配的位置为 1 外,其余位置均为 0。

另请参阅 维基百科上的 One-hot(独热码)

参数
  • tensor (LongTensor) – 任何形状的类别值。

  • num_classes (int, optional) – 类别的总数量。如果设置为 -1,则类别数量将推断为输入 tensor 中最大类别值加一。默认值:-1

返回

一个 LongTensor,它在最后一维的输入指定索引位置的值为 1,其余位置为 0,且该维度比输入多一维。

示例

>>> F.one_hot(torch.arange(0, 5) % 3)
tensor([[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 0],
        [0, 1, 0]])
>>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5)
tensor([[1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0]])
>>> F.one_hot(torch.arange(0, 6).view(3,2) % 3)
tensor([[[1, 0, 0],
         [0, 1, 0]],
        [[0, 0, 1],
         [1, 0, 0]],
        [[0, 1, 0],
         [0, 0, 1]]])

文档

访问 PyTorch 全面的开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源