torch.nn.CosineSimilarity
2022-07-06
菌子甚毒
https://pytorch.org/docs/stable/generated/torch.nn.CosineSimilarity.html
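For reference, here is the per-row quantity that `nn.CosineSimilarity(dim=1)` computes for two inputs of shape (N, D), written term by term to match the manual `frac1` / `frac2` computation in the code. The symbols $x^{(1)}$, $x^{(2)}$, $N$ and $D$ are notation introduced here for illustration. PyTorch also guards the denominator with a small `eps` (default `1e-8`), so the module and this formula only differ when a row has a (near-)zero norm.

```latex
\mathrm{sim}_i = \frac{\sum_{j=1}^{D} x^{(1)}_{ij}\, x^{(2)}_{ij}}
                      {\sqrt{\sum_{j=1}^{D} \bigl(x^{(1)}_{ij}\bigr)^2}\;\sqrt{\sum_{j=1}^{D} \bigl(x^{(2)}_{ij}\bigr)^2}},
\qquad i = 1, \dots, N
```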

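import torch
import torch.nn as nn
input1 = torch.randn(100, 128)
input2 = torch.randn(100, 128)
# Cosine similarity between corresponding rows (dim=1); result has shape (100,)
sim = nn.CosineSimilarity(dim=1)(input1, input2)
# Equivalent manual computation: row-wise dot product divided by the product of the L2 norms
# (nn.CosineSimilarity also uses eps, default 1e-8, to guard the denominator against division by zero)
frac1 = torch.sum(input1 * input2, dim=1)
frac2 = torch.sum(input1 ** 2, dim=1) ** 0.5 * torch.sum(input2 ** 2, dim=1) ** 0.5
sim_manual = frac1 / frac2
print(torch.allclose(sim, sim_manual, atol=1e-6))  # True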
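A minimal sketch, assuming a recent PyTorch release, that cross-checks the module against two other ways of getting the same numbers: `torch.nn.functional.cosine_similarity`, and L2-normalizing each row with `F.normalize` before taking a row-wise dot product. It also shows where the manual formula and the module disagree: for an all-zero row the `eps` guard yields 0 while the naive division yields NaN. The names `x1`, `x2`, `zero` and `ones` are illustrative only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # fixed seed so the comparison is reproducible
x1 = torch.randn(100, 128)
x2 = torch.randn(100, 128)

# Module form and functional form; dim=1 compares row i of x1 with row i of x2
sim_module = torch.nn.CosineSimilarity(dim=1)(x1, x2)
sim_functional = F.cosine_similarity(x1, x2, dim=1)

# Alternative: L2-normalize each row first, then take the row-wise dot product
sim_normalized = (F.normalize(x1, dim=1) * F.normalize(x2, dim=1)).sum(dim=1)

print(sim_module.shape)  # torch.Size([100])
print(torch.allclose(sim_module, sim_functional, atol=1e-6))
print(torch.allclose(sim_module, sim_normalized, atol=1e-6))

# Edge case: an all-zero row. The eps guard keeps the result finite (0.0),
# whereas the naive formula divides 0 by 0 and produces NaN.
zero = torch.zeros(1, 4)
ones = torch.ones(1, 4)
print(F.cosine_similarity(zero, ones, dim=1).item())  # 0.0
naive = (zero * ones).sum(dim=1) / (zero.norm(dim=1) * ones.norm(dim=1))
print(naive.item())  # nan
```

In practice the manual formula is fine for ordinary embeddings, but the built-in module (or `F.cosine_similarity`) is the safer choice when some rows may be zero, for example padded positions.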