keras_regression
2017-05-17 本文已影响43人
Ledestin
问题:神经网络可以用来模拟回归问题 (regression),例如给下面一组数据,用一条线来对数据进行拟合,并可以预测新输入 x 的输出值。
Paste_Image.png用 Keras 构建回归神经网络的步骤:
1.导入模块并创建数据
2.建立模型
3.激活模型
4.训练模型
5.验证模型
6.可视化结果
Demo.py
#导入模块
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential #models.Sequential,用来一层一层一层的去建立神经层;
from keras.layers import Dense #layers.Dense 意思是这个神经层是全连接层
# 创建数据
X = np.linspace(-1, 1, 200)
np.random.shuffle(X)# 数据随机化
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200,))# 创建数据及参数, 并加入噪声
# 绘制数据
plt.scatter(X, Y)
plt.show()
# 分为训练数据和测试数据
X_train, Y_train = X[:160], Y[:160] # train 前 160 data points
X_test, Y_test = X[160:], Y[160:] # test 后 40 data points
# 使用keras创建神经网络
# Sequential是指一层层堆叠的神经网络
# Dense是指全连接层
#建立模型
model = Sequential()# 用 Sequential 建立 model
model.add(Dense(units = 1, input_dim = 1))#model.add 添加神经层,添加的是 Dense 全连接神经层。
#Dense参数有两个,一个是输入数据和输出数据的维度,本代码的例子中 x 和 y 是一维的。
#如果需要添加下一个神经层的时候,不用再定义输入的纬度,因为它默认就把前一层的输出作为当前层的输入。
#在这个例子里,只需要一层就够了。
# 激活模型
#选择损失函数和优化方法
model.compile(loss = 'mse', optimizer = 'sgd')#误差函数用的是 mse 均方误差;优化器用的是 sgd 随机梯度下降法
print '----Training----'
# 训练过程
for step in range(501):
# 进行训练, 返回损失(代价)函数
cost = model.train_on_batch(X_train, Y_train)#训练的时候用 model.train_on_batch 一批一批的训练 X_train, Y_train。默认的返回值是 cost
if step % 100 == 0:#每100步输出一下结果
print 'loss: ', cost
#检验模型
print '----Testing----'
# 训练结束进行测试
cost = model.evaluate(X_test, Y_test, batch_size = 40)
print 'test loss: ', cost
# 获取参数
W, b = model.layers[0].get_weights()#weights 和 biases 是取在模型的第一层 model.layers[0] 学习到的参数
print 'Weights: ',W
print 'Biases: ', b
#可视化结果
# plotting the prediction
Y_pred = model.predict(X_test)
plt.scatter(X_test, Y_test)
plt.plot(X_test, Y_pred)
plt.show()
结果:
Paste_Image.png----Training----
----Training----
loss: 4.05043315887
loss: 0.0760689899325
loss: 0.00436494173482
loss: 0.00265302229673
loss: 0.00251104100607
loss: 0.00248079258017
----Testing----
40/40 [==============================] - 0s
test loss: 0.00255159125663
Weights: [[ 0.49018186]]
Biases: [ 2.00758481]
Paste_Image.png