使用Keras处理回归问题——以预测房价为例
2019-10-01 本文已影响0人
进击的码农设计师
1.波士顿房价数据集:
本次使用的波士顿房价数据集包含506个样本,其中404个训练样本和102个测试样本。每个样本包含13个特征,需要注意的是每个特征都有不同的取值范围,包括0-1、0-100、1-12等不同的取值范围。
2.分步骤实现:
- 1.加载波士顿房价数据集
- 2.数据预处理
- 3.构建网络
- 4.使用K折验证
- 5.绘制验证分数
3.全流程代码:
import numpy as np
from keras.datasets import boston_housing
from keras import models
from keras import layers
import matplotlib.pyplot as plt
# 加载数据集
(train_data,train_targets),(test_data,test_targets) = boston_housing.load_data()
# 数据预处理
mean = train_data.mean(axis=0)
train_data -= mean
std = train_data.std(axis=0)
train_data /= std
test_data -= mean
test_data /= std
# 构建网络
def build_model():
model = models.Sequential()
model.add(layers.Dense(64,activation='relu',input_shape=(train_data.shape[1],)))
model.add(layers.Dense(64,activation='relu'))
model.add(layers.Dense(1))
model.compile(optimizer='rmsprop',loss='mse',metrics=['mae'])
return model
# 使用K折验证
k = 4
num_val_samples = len(train_data) // k
num_epochs = 100
all_mae_histories = []
for i in range(k):
# print('processing fold #', i)
# 准备验证数据:第k个分区的数据
val_data = train_data[i * num_val_samples:(i + 1) * num_val_samples]
val_targets = train_targets[i * num_val_samples:(i + 1) * num_val_samples]
# 准备训练数据:其他所有分区的数据
partial_train_data = np.concatenate([train_data[:i * num_val_samples], train_data[(i + 1) * num_val_samples:]],
axis=0)
partial_train_targets = np.concatenate([train_targets[:i * num_val_samples], train_targets[(i + 1) * num_val_samples:]],
axis=0)
# 构建keras模型(已编译)
model = build_model()
# 训练模型,verbose=0表示静默模式
history = model.fit(partial_train_data, partial_train_targets, validation_data=(val_data, val_targets),
epochs=num_epochs, batch_size=1, verbose=0)
# print(history.history.keys())
mae_history = history.history['val_mae']
all_mae_histories.append(mae_history)
# 计算所有轮次中的k折验证分数平均值
average_mae_history = [np.mean([x[i] for x in all_mae_histories]) for i in range(num_epochs)]
print(average_mae_history)
# 绘制验证分数
plt.plot(range(1,len(average_mae_history)+1),average_mae_history)
plt.xlabel('Epochs')
plt.ylabel('Validation MAE')
plt.show()
Reference:
《Deep Learning with Python》