torch.nn.functional.cosine_similarity¶
- torch.nn.functional.cosine_similarity(x1, x2, dim=1, eps=1e-8) Tensor ¶
返回
x1
和x2
之间的余弦相似度,沿 dim 计算。x1
和x2
必须可广播到公共形状。dim
指的是此公共形状中的维度。 输出的维度dim
会被压缩(参见torch.squeeze()
),从而使输出 Tensor 减少 1 个维度。支持 类型提升。
- 参数
示例
>>> input1 = torch.randn(100, 128) >>> input2 = torch.randn(100, 128) >>> output = F.cosine_similarity(input1, input2) >>> print(output)