Deep Learning

scikit-learn中classification_repo

2018-01-23  本文已影响1513人  王二牛牛
from sklearn.metrics import classification_report, accuracy_score
classification_report(y_test, y_pre, target_names=target_names)

scikit-learn中的classification_report是强大的函数,可以计算查全率,查准率,F1参数,keras中没有相关的函数,并且keraslabel为one-hot,输出的为[0.3.0.2,0.5]这样的softmax数据,如何转化为[4,5,5]这样的标签数据用于适配classification_report函数。

1、one-hot转化为整数label

代码如下:

import numpy as np
def onehot_to_category(onehot):
      b = np.array([[0], [1], [2],[3],[4],[5]])
      return np.dot(onehot,b).flatten()

原理很简单,矩阵的乘法
输入时n*6的矩阵,n个样本,一共6类,6代表onehot编码如[0,0,0,0,0,1]

b为列向量,shape为6*1

np.dot(a,b)代表a矩阵与b矩阵的乘法,输出矩阵为n*1,即转化为整数的lable形式

2、softmax输出转化为整数label

代码如下
def softmax_to_category(a):
max2 = []
for item in a:
i=np.argmax(item)
max2.append(i)
return max2
比较low的方法,诸葛数据取最大值,取最大值的索引,就是整数形式的label

上一篇 下一篇

猜你喜欢

热点阅读