pytorch中交叉熵函数用法
2019-09-26 本文已影响0人
习惯了千姿百态
loss = F.cross_entropy(pred, gold,ignore_index=IGNORE_ID,reduction='elementwise_mean')
pred:预测得到的矩阵m*n,表示共有m个样本,类数为n,该矩阵某个元素pred[i][j]表示,第i个样本分到第j类的概率
gold:表示标准的label矩阵,1*m。表示每个样本对应的正确的类别
ignore_index:表示在计算过程中,不考虑gold中值为IGNORE_ID的样本
reduction='elementwise_mean':pred去log-softmax之后的输出与Label对应的那个值拿出来,再去掉负号,再求均值。这是默认值