torch.nn.functional.cosine_similarity¶
- torch.nn.functional.cosine_similarity(x1, x2, dim=1, eps=1e-8) Tensor¶
返回
x1和x2之間的餘弦相似度,沿 dim 計算。x1和x2必須可廣播到公共形狀。dim指的是此公共形狀中的維度。 輸出的維度dim會被壓縮(參見torch.squeeze()),從而使輸出 Tensor 減少 1 個維度。支援 型別提升。
- 引數
示例
>>> input1 = torch.randn(100, 128) >>> input2 = torch.randn(100, 128) >>> output = F.cosine_similarity(input1, input2) >>> print(output)