PyTorch: NLLLoss

2019-06-20  Andy512

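In PyTorch's NLLLoss, a label equal to ignore_index (default -100) is skipped and contributes nothing to the loss. Only that specific value is ignored, not negative labels in general; any other negative label is not a valid class index. In the example below, the loss for the ignored label comes out as zero: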

import torch
import torch.nn as nn


class testCrossEntropy(nn.Module):
    """Thin wrapper around nn.CrossEntropyLoss (default ignore_index=-100)."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, inputs, target):
        return self.criterion(inputs, target)


if __name__ == '__main__':
    # Probabilities for one sample over five classes; they sum to 1.
    predict = torch.tensor([[0.0, 0.2, 0.7, 0.1, 0.0]])
    crit = testCrossEntropy()
    # Target -100 equals the default ignore_index, so this sample is skipped.
    v1 = crit(predict.log(), torch.tensor([-100]))
    # Target 1 selects class 1, so the loss is -log(0.2).
    v2 = crit(predict.log(), torch.tensor([1]))
    print(v1, v2)
Result:

v1 = 0.
v2 = 1.6094

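To check which targets are actually skipped, here is a minimal sketch that calls nn.NLLLoss directly. It assumes the default ignore_index of -100, and the tensor names and values are illustrative only. It uses reduction='sum' for the all-ignored case because the default 'mean' divides by the number of non-ignored targets, which is zero there (recent PyTorch releases may return nan in that situation rather than 0).

```python
import torch
import torch.nn as nn

# Log-probabilities for two samples over three classes (illustrative values).
log_probs = torch.tensor([[0.2, 0.7, 0.1],
                          [0.5, 0.25, 0.25]]).log()

# reduction='sum' avoids dividing by the number of non-ignored targets.
nll_sum = nn.NLLLoss(reduction='sum')  # ignore_index defaults to -100

# Both targets equal ignore_index, so nothing contributes to the sum.
all_ignored = nll_sum(log_probs, torch.tensor([-100, -100]))

# Only the second sample counts: loss = -log(0.25).
one_ignored = nll_sum(log_probs, torch.tensor([-100, 1]))

# With the default reduction='mean', the average is taken over the
# non-ignored targets only, so the value is the same here.
one_ignored_mean = nn.NLLLoss()(log_probs, torch.tensor([-100, 1]))

# A custom ignore_index behaves the same way (-1 is an arbitrary choice).
custom = nn.NLLLoss(ignore_index=-1, reduction='sum')(
    log_probs, torch.tensor([-1, 1]))

print(all_ignored.item(), one_ignored.item(),
      one_ignored_mean.item(), custom.item())
```

Choosing ignore_index explicitly is a common way to mask padding positions, since the masked samples also receive no gradient.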
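NLLLoss is the core of cross entropy in PyTorch: nn.CrossEntropyLoss applies log-softmax to its input and then NLLLoss, so the ignore_index handling comes from NLLLoss.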
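The relationship between the two losses can be checked numerically. The sketch below uses made-up logits and compares nn.CrossEntropyLoss with log-softmax followed by nn.NLLLoss. It also illustrates why passing log-probabilities to CrossEntropyLoss still gives -log(p): when the probabilities sum to 1, log-softmax leaves their logarithms unchanged.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Raw, unnormalized scores (logits) for one sample; values are illustrative.
logits = torch.tensor([[1.0, 2.0, 0.5, -1.0, 0.0]])
target = torch.tensor([1])

# CrossEntropyLoss applied directly to the logits.
ce = nn.CrossEntropyLoss()(logits, target)

# The same value built from its two parts: log-softmax, then NLLLoss.
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), target)

# For probabilities that already sum to 1, log_softmax(log(p)) == log(p),
# so feeding log(p) to CrossEntropyLoss yields -log(p[target]).
probs = torch.tensor([[0.2, 0.7, 0.1]])
same_log_probs = torch.allclose(F.log_softmax(probs.log(), dim=1), probs.log())
ce_on_log_probs = nn.CrossEntropyLoss()(probs.log(), torch.tensor([0]))

print(ce.item(), nll.item(), same_log_probs, ce_on_log_probs.item())
```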
