用PyTorch编写自定义的损失函数(Custom Loss F

2023-01-22  本文已影响0人  LabVIEW_Python

1. 什么是损失函数(Loss Function)

损失函数在训练AI模型的过程中,用于计算模型预测值与真实值之间的误差(Error),又称误差函数(Error function)

损失函数(Loss Function)

2. PyTorch中内建的损失函数

torch.nn中内建了很多常用的损失函数,依据用途,可以分为三类:

3. 使用PyTorch编写自定义损失函数

3.1 何时自己编写损失函数

当任务不再是单一的回归或者分类,PyTorch的内建损失函数无法满足任务需求时,需要开发者自己手动编写损失函数,例如:YOLOv5的total_loss

total_loss = class loss + bbox loss+ object loss

就需要手动编写。

3.2 用PyTorch编写损失函数模板

方案1 -- 编写一个继承nn.Module的Loss函数类,并在forward()方法中实现loss计算:

class MyLoss(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, x, y):
        return torch.mean(torch.pow((x - y), 2))

方案2 -- 编写一个Loss函数类,并在_ call _方法中实现loss计算:

class MyLoss():
    def __init__(self):
        
    def forward(self, x, y):
        return torch.mean(torch.pow((x - y), 2))

以上两种Loss函数类的使用方式都是:

loss_fn = MyLoss() # 实例化Loss函数
loss = loss_fn(preds, targets) # 计算loss值,即预测值与真实值之间的误差
...
loss.backward() #

在实际开发工作中,以上两种实现方案都非常常见,运行效率几乎一致,无性能优劣之分;由于loss函数并没有需要学习的参数,所以使用哪一种主要看开发者习惯的编程风格(coding style)。

参考资料:

  1. [Solved] What is the correct way to implement custom loss function?
  2. A-Collection-of-important-tasks-in-pytorch
上一篇 下一篇

猜你喜欢

热点阅读