大数据,机器学习,人工智能自然语言处理 (让机器更懂你)pytorch

4 损失函数-庖丁解牛之pytorch

2018-10-21  本文已影响12人  readilen

基类定义

pytorch损失类也是模块的派生,损失类的基类是_Loss,定义如下

class _Loss(Module):
    def __init__(self, size_average=None, reduce=None, reduction='elementwise_mean'):
        super(_Loss, self).__init__()
        if size_average is not None or reduce is not None:
            self.reduction = _Reduction.legacy_get_string(size_average, reduce)
        else:
            self.reduction = reduction

看这个类,有两点我们知道:

子类介绍

从_Loss派生的类有

名称 说明 公式
_WeightedLoss 这个类只是申请了一个权重空间,功能和_Loss一样
L1Loss X、Y可以是任意形状的输入,X与Y的 shape相同
PoissonNLLLoss 适合多目标分类
KLDivLoss 适用于连续分布的距离计算
MSELoss 均方差
BCEWithLogitsLoss 多目标不需要经过sigmoid
HingeEmbeddingLoss Y中的元素只能为1或-1 适用于学习非线性embedding、半监督学习。用于计算两个输入是否相似
MultiLabelMarginLoss 适用于多目标分类
SmoothL1Loss
SoftMarginLoss
CosineEmbeddingLoss
MarginRankingLoss
TripletMarginLoss

从_WeightedLoss继续派生的函数有

名称 说明
NLLLoss
BCELoss
CrossEntropyLoss
MultiLabelSoftMarginLoss
MultiMarginLoss
上一篇 下一篇

猜你喜欢

热点阅读