在新数据上生成预测结果
2019-03-15 本文已影响0人
庵下桃花仙
在新数据上生成预测结果
predictions = model.predict(x_test)
print(predictions[0].shape)
print(np.sum(predictions[0]))
print(np.argmax(predictions[0]))
(46,)
1.0000001
3
小结:
- 如果要对 N 个类别的数据点进行分类,网络的最后一层应该是大小为 N 的 Dense 层;
- 对应单标签、多分类问题,网络的最后一层应该使用 softmax 激活,使得输出在 N 个输出类别上的概率分布;
- 这种问题的损失函数几乎总是分类交叉熵。它将网络输出的概率分布与目标的真实分布之间距离最小化;
- 处理多分类问题的标签有两种方法:
1、通过分类编码(one-hot 编码)对标签进行编码,然后使用 categorical_crossentropy 作为损失函数;
2、将标签编码为整数,然后使用sparse_categorical_crossentropy 作为损失函数 - 如果要将数据划分到许多类别中,应该避免使用太小的中间层,以免在网络中造成信息瓶颈。