预测燃油效率
2019-06-01 本文已影响0人
萍水间人
先来简要地看一下数据集
是一个类似于CSV文件的
column_names = ['MPG','Cylinders','Displacement','Horsepower','Weight',
'Acceleration', 'Model Year', 'Origin']
raw_dataset = pd.read_csv(dataset_path, names=column_names, na_values='?',
comment='\t', sep=" ", skipinitialspace=True)
dataset = raw_dataset.copy()
dataset.head()
将\t
后面的内容都设置为注释
用空白符分隔
得到一个这样的数据:
进行一点处理
dataset.isna().sum()
dataset = dataset.dropna()
然后
origin = dataset.pop('Origin')
dataset['USA'] = (origin == 1)*1.0
dataset['Europe'] = (origin == 2)*1.0
dataset['Japan'] = (origin == 3)*1.0
dataset.tail()
数据就变成这样子了
然后对数据集抽样
DataFrame.sample(n=None, frac=None, replace=False, weights=None, random_state=None, axis=None)
train_dataset = dataset.sample(frac=0.8,random_state=0)
test_dataset = dataset.drop(train_dataset.index)
n是要抽取的行数。(例如n=20000时,抽取其中的2W行)
frac是抽取的比列。(有一些时候,我们并对具体抽取的行数不关系,我们想抽取其中的百分比,这个时候就可以选择使用frac,例如frac=0.8,就是抽取其中80%)
replace:是否为有放回抽样,取replace=True时为有放回抽样。
weights这个是每个样本的权重,具体可以看官方文档说明。
axis是选择抽取数据的行还是列。axis=0的时是抽取行,axis=1时是抽取列(也就是说axis=1时,在列中随机抽取n列,在axis=0时,在行中随机抽取n行)
抽取了其中百分之八十的数据
画图:
sns.pairplot(train_dataset[["MPG", "Cylinders", "Displacement", "Weight"]], diag_kind="kde")
再做一次转置
train_stats = train_dataset.describe()
train_stats.pop("MPG")
train_stats = train_stats.transpose()
train_stats
训练和验证的数据
train_labels = train_dataset.pop('MPG')
test_labels = test_dataset.pop('MPG')
也就是百分之八十的数据用于训练,百分之二十的数据用于验证
开始建模
def norm(x):
return (x - train_stats['mean']) / train_stats['std']
normed_train_data = norm(train_dataset)
normed_test_data = norm(test_dataset)
def build_model():
model = keras.Sequential([
layers.Dense(64, activation=tf.nn.relu, input_shape=[len(train_dataset.keys())]),
layers.Dense(64, activation=tf.nn.relu),
layers.Dense(1)
])
optimizer = tf.train.RMSPropOptimizer(0.001)
model.compile(loss='mse',
optimizer=optimizer,
metrics=['mae', 'mse'])
return model
构建模型 model = build_model()
打印一下摘要model.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 64) 512
_________________________________________________________________
dense_1 (Dense) (None, 64) 4160
_________________________________________________________________
dense_2 (Dense) (None, 1) 65
=================================================================
Total params: 4,737
Trainable params: 4,737
Non-trainable params: 0
_________________________________________________________________
未完待续