label smooth

2021-12-18  本文已影响0人  三方斜阳

标签平滑:
Label Smoothing(标签平滑)是一个经典的正则化方法,机器学习的样本中通常会存在少量错误标签,这些错误标签会影响到预测的效果。标签平滑采用如下思路解决这个问题:在训练时即假设标签可能存在错误,避免“过分”相信训练样本的标签。当目标函数为交叉熵时,这一思想有非常简单的实现:

# coding: UTF-8
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class CELoss(nn.Module):
    ''' Cross Entropy Loss with label smoothing '''
    def __init__(self, label_smooth=None, class_num=2):
        super().__init__()
        self.label_smooth = label_smooth
        self.class_num = class_num
 
    def forward(self, pred, target):
        ''' 
        Args:
            pred: prediction of model output    [N, M]
            target: ground truth of sampler [N]
        '''
        eps = 1e-12
        
        if self.label_smooth is not None:
            # cross entropy loss with label smoothing
            logprobs = F.log_softmax(pred, dim=1)   # softmax + log
            target = F.one_hot(target, self.class_num)  # 转换成one-hot
            
            # label smoothing
            # 实现 1
            # target = (1.0-self.label_smooth)*target + self.label_smooth/self.class_num    
            # 实现 2
            # implement 2
            target = torch.clamp(target.float(), min=self.label_smooth/(self.class_num-1), max=1.0-self.label_smooth)
            loss = -1*torch.sum(target*logprobs, 1)
        else:
            # standard cross entropy loss
            loss = -1.*pred.gather(1, target.unsqueeze(-1)) + torch.log(torch.exp(pred+eps).sum(dim=1))
        return loss.mean()

在训练过程中调用:

from CELoss import CELoss
loss2 = CELoss(label_smooth=0.05, class_num=2)  # 标签平滑
with torch.no_grad():
    for texts, labels in data_iter:
        outputs = model(texts) # [batch_size, num_class=2]
        loss = F.cross_entropy(outputs, labels)
        # loss=loss2(outputs, labels)
上一篇 下一篇

猜你喜欢

热点阅读