分类网络中的数据标注操作

2017-12-05  本文已影响53人  chunleiml

很多标准的开源数据集都是标注好的数据,我们下载下来就可以直接当训练集使用,例如MNIST, CIFAR-10。但在实际的深度学习训练任务中,有时我们拿到的数据并没有标签,需要我们自己对数据进行标注。例如拿到一套头颈部的CT数据,我们需要预测cerebellum的z轴范围,这就是一个典型的分类任务,我们需要做的就是给cerebellum一个标签,非cerebellum一个标签,通常是使用one-hot编码方式。

#给训练集标注数据
for i in range(cerebellum_label.shape[0]):
#index是cerebellum所在层的索引
    if i in index:
        y = [1,0]            
    else:
        y = [0,1]
    y_train.append(y)
print(len(y_train))
y_train = np.array(y_train)

训练完成后,我们用训练好的模型对一套没有进行过标注的数据进行预测

#model为训练网络保存的模型
label_test = model.predict(imgs, batch_size=24, verbose=1)
label_test = np.where(label_test>=0.5,1,0)
index = []
for i, label in enumerate(label_test):
    if label[0]==1:
        if label[1]==0:
            #现在的index就是预测的z轴索引范围
            index.append(i)
上一篇 下一篇

猜你喜欢

热点阅读