torch.nn.functional.one_hot¶
- torch.nn.functional.one_hot(tensor, num_classes=-1) LongTensor ¶
接受形状为
(*)
的索引值的 LongTensor,并返回形状为(*, num_classes)
的张量,该张量在除最后一个维度的索引与输入张量的对应值匹配的位置之外的所有位置都为零,在这种情况下,它将为 1。另请参阅 维基百科上的 One-hot 编码。
- 参数
tensor (LongTensor) – 任意形状的类值。
num_classes (int) – 类的总数。如果设置为 -1,则类的数量将被推断为比输入张量中最大的类值大 1。
- 返回
LongTensor,它具有一个额外的维度,其中在最后一个维度的索引由输入指示的位置的值为 1,其他位置都为 0。
示例
>>> F.one_hot(torch.arange(0, 5) % 3) tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0]]) >>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5) tensor([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [1, 0, 0, 0, 0], [0, 1, 0, 0, 0]]) >>> F.one_hot(torch.arange(0, 6).view(3,2) % 3) tensor([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [1, 0, 0]], [[0, 1, 0], [0, 0, 1]]])