label smooth
2021-12-18 本文已影响0人
三方斜阳
标签平滑:
Label Smoothing(标签平滑)是一个经典的正则化方法,机器学习的样本中通常会存在少量错误标签,这些错误标签会影响到预测的效果。标签平滑采用如下思路解决这个问题:在训练时即假设标签可能存在错误,避免“过分”相信训练样本的标签。当目标函数为交叉熵时,这一思想有非常简单的实现:
# coding: UTF-8
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class CELoss(nn.Module):
''' Cross Entropy Loss with label smoothing '''
def __init__(self, label_smooth=None, class_num=2):
super().__init__()
self.label_smooth = label_smooth
self.class_num = class_num
def forward(self, pred, target):
'''
Args:
pred: prediction of model output [N, M]
target: ground truth of sampler [N]
'''
eps = 1e-12
if self.label_smooth is not None:
# cross entropy loss with label smoothing
logprobs = F.log_softmax(pred, dim=1) # softmax + log
target = F.one_hot(target, self.class_num) # 转换成one-hot
# label smoothing
# 实现 1
# target = (1.0-self.label_smooth)*target + self.label_smooth/self.class_num
# 实现 2
# implement 2
target = torch.clamp(target.float(), min=self.label_smooth/(self.class_num-1), max=1.0-self.label_smooth)
loss = -1*torch.sum(target*logprobs, 1)
else:
# standard cross entropy loss
loss = -1.*pred.gather(1, target.unsqueeze(-1)) + torch.log(torch.exp(pred+eps).sum(dim=1))
return loss.mean()
在训练过程中调用:
from CELoss import CELoss
loss2 = CELoss(label_smooth=0.05, class_num=2) # 标签平滑
with torch.no_grad():
for texts, labels in data_iter:
outputs = model(texts) # [batch_size, num_class=2]
loss = F.cross_entropy(outputs, labels)
# loss=loss2(outputs, labels)