pytorch中的损失函数

2022-06-08  本文已影响0人  Jlan

1. 多标签分类损失函数

pytorch中能计算多标签分类任务loss的方法有好几个。
binary_cross_entropy和binary_cross_entropy_with_logits都是来自torch.nn.functional的函数,BCELoss和BCEWithLogitsLoss都来自torch.nn,它们的区别:

函数名 解释
binary_cross_entropy Function that measures the Binary Cross Entropy between the target and the output
binary_cross_entropy_with_logits Function that measures Binary Cross Entropy between target and output logits
BCELoss Function that measures the Binary Cross Entropy between the target and the output
BCEWithLogitsLoss Function that measures Binary Cross Entropy between target and output logits

区别只在于这个logits,损失函数(类)名字中带了with_logits,这里的logits指的是该损失函数已经内部自带了计算logit的操作,无需在传入给这个loss函数之前手动使用sigmoid/softmax将之前网络的输入映射到[0,1]之间。
nn.functional.xxx是函数接口,而nn.Xxx是nn.functional.xxx的类封装,并且nn.Xxx都继承于一个共同祖先nn.Module。

In [257]: import torch
In [258]: import torch.nn as nn
In [259]: import torch.nn.functional as F

In [260]: true = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
In [261]: pred = torch.rand((2,3))

In [262]: true
Out[262]:
tensor([[1., 0., 1.],
        [1., 0., 0.]])

In [263]: pred
Out[263]:
tensor([[0.0391, 0.7691, 0.1190],
        [0.8846, 0.1628, 0.2641]])

In [264]: F.binary_cross_entropy(torch.sigmoid(pred), true)
Out[264]: tensor(0.7361)

In [265]: F.binary_cross_entropy_with_logits(pred, true)
Out[265]: tensor(0.7361)


In [267]: lf2 = nn.BCELoss()
In [268]: lf2(torch.sigmoid(pred), true)
Out[268]: tensor(0.7361)

In [269]: lf = nn.BCEWithLogitsLoss()
In [270]: lf(pred, true)
Out[270]: tensor(0.7361)

# -(ylog(p)+(1-y)log(1-p))
In [268]: torch.sum(-(true*torch.log(torch.sigmoid(pred))+(1-true)*torch.log(1-torch.sigmoid(pred))))/6  
Out[268]: tensor(0.7361)
上一篇下一篇

猜你喜欢

热点阅读