# What does torch.nn.CosineEmbeddingLoss do?
# 2020-07-29 · read by 0 people
# asl_1da7
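# Per-sample loss computed by CosineEmbeddingLoss (averaged over the batch when reduction="mean"):
#   loss_i = 1 - cos(x1_i, x2_i)                if y_i == 1
#   loss_i = max(0, cos(x1_i, x2_i) - margin)   if y_i == -1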
def CustomCosineEmbeddingLoss(x1, x2, target, margin=0.0, reduction="mean"):
    """Hand-written equivalent of ``torch.nn.CosineEmbeddingLoss``.

    With ``c_i`` the cosine similarity of ``x1_i`` and ``x2_i``, taken over
    the last dimension, the per-sample loss is::

        loss_i = 1 - c_i                  if target_i == 1
        loss_i = max(0, c_i - margin)     if target_i == -1

    Entries of ``target`` that are neither 1 nor -1 contribute 0.

    Args:
        x1, x2: Tensors of the same shape. For 2-D ``(N, D)`` input the
            cosine is taken along ``dim=1``, as in the original version.
        target: Tensor of +1 / -1 labels, one per sample.
        margin: Margin applied to dissimilar (-1) pairs. The default 0.0
            matches PyTorch's default.
        reduction: ``"mean"`` (default), ``"sum"`` or ``"none"``.

    Returns:
        A scalar tensor for ``"mean"`` / ``"sum"``, or the per-sample losses
        for ``"none"``.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """
    # Small epsilon on the squared norms, as PyTorch does, so an all-zero
    # vector yields cos = 0 instead of 0/0 = NaN.
    eps = 1e-12
    dot = torch.sum(x1 * x2, dim=-1)
    sq_norm1 = torch.sum(x1 * x1, dim=-1) + eps  # |x1|^2
    sq_norm2 = torch.sum(x2 * x2, dim=-1) + eps  # |x2|^2
    cos_x1_x2 = dot / torch.sqrt(sq_norm1 * sq_norm2)

    zeros = torch.zeros_like(cos_x1_x2)
    # Similar pairs (+1) are pulled together: penalize 1 - cos.
    pos = torch.where(target == 1, 1 - cos_x1_x2, zeros)
    # Dissimilar pairs (-1) are only penalized while cos exceeds the margin.
    # The loss is therefore never negative; the original `target - cos`
    # gave -1 - cos here.
    neg = torch.where(target == -1, torch.clamp(cos_x1_x2 - margin, min=0), zeros)
    losses = pos + neg

    if reduction == "mean":
        return torch.mean(losses)
    if reduction == "sum":
        return torch.sum(losses)
    if reduction == "none":
        return losses
    raise ValueError(f"unsupported reduction: {reduction!r}")
# Demo: compare PyTorch's built-in loss with the hand-written CustomCosineEmbeddingLoss.
import torch  # the original snippet used torch without importing it

cirt = torch.nn.CosineEmbeddingLoss(reduction="mean")
# Inputs are unseeded random tensors, so the printed values vary between runs.
x1 = torch.randn((5, 3))
x2 = torch.randn((5, 3))
# `target` was never defined in the original snippet. Label every pair as
# similar (+1), one label per row of x1 / x2.
target = torch.ones(5)
a1 = cirt(x1, x2, target)
print(a1)
a2 = CustomCosineEmbeddingLoss(x1, x2, target)
print(a2)
# Out[11]:
# tensor(1.0479)
# tensor(1.0479)